fix: VLESS handshake write

This commit is contained in:
gVisor bot 2023-02-11 15:13:17 +08:00
parent 7021dc1878
commit 2e183fb53f

View file

@ -57,9 +57,14 @@ func (vc *Conn) Write(p []byte) (int, error) {
select { select {
case <-vc.handshake: case <-vc.handshake:
default: default:
err := vc.sendRequest(p) if vc.sendRequest(p) {
if err != nil { if vc.err != nil {
return 0, err return 0, vc.err
}
return len(p), vc.err
}
if vc.err != nil {
return 0, vc.err
} }
} }
return vc.ExtendedConn.Write(p) return vc.ExtendedConn.Write(p)
@ -69,21 +74,25 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
select { select {
case <-vc.handshake: case <-vc.handshake:
default: default:
err := vc.sendRequest(buffer.Bytes()) if vc.sendRequest(buffer.Bytes()) {
if err != nil { return vc.err
return err }
if vc.err != nil {
return vc.err
} }
} }
return vc.ExtendedConn.WriteBuffer(buffer) return vc.ExtendedConn.WriteBuffer(buffer)
} }
func (vc *Conn) sendRequest(p []byte) (err error) { func (vc *Conn) sendRequest(p []byte) bool {
vc.handshakeMutex.Lock() vc.handshakeMutex.Lock()
defer vc.handshakeMutex.Unlock() defer vc.handshakeMutex.Unlock()
select { select {
case <-vc.handshake: case <-vc.handshake:
return vc.err // The handshake has been completed before.
// So return false to remind the caller.
return false
default: default:
} }
defer close(vc.handshake) defer close(vc.handshake)
@ -93,9 +102,9 @@ func (vc *Conn) sendRequest(p []byte) (err error) {
requestLen += 1 // addons length requestLen += 1 // addons length
var addonsBytes []byte var addonsBytes []byte
if vc.addons != nil { if vc.addons != nil {
addonsBytes, err = proto.Marshal(vc.addons) addonsBytes, vc.err = proto.Marshal(vc.addons)
if err != nil { if vc.err != nil {
return err return true
} }
} }
requestLen += len(addonsBytes) requestLen += len(addonsBytes)
@ -137,8 +146,8 @@ func (vc *Conn) sendRequest(p []byte) (err error) {
buf.Must(buf.Error(buffer.Write(p))) buf.Must(buf.Error(buffer.Write(p)))
_, err = vc.ExtendedConn.Write(buffer.Bytes()) _, vc.err = vc.ExtendedConn.Write(buffer.Bytes())
return return true
} }
func (vc *Conn) recvResponse() error { func (vc *Conn) recvResponse() error {
@ -203,7 +212,7 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
select { select {
case <-c.handshake: case <-c.handshake:
case <-time.After(200 * time.Millisecond): case <-time.After(200 * time.Millisecond):
_ = c.sendRequest(nil) c.sendRequest(nil)
} }
}() }()
return c, nil return c, nil