diff --git a/transport/vless/conn.go b/transport/vless/conn.go index 0f4d3bf7..aceda463 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -57,9 +57,14 @@ func (vc *Conn) Write(p []byte) (int, error) { select { case <-vc.handshake: default: - err := vc.sendRequest(p) - if err != nil { - return 0, err + if vc.sendRequest(p) { + if vc.err != nil { + return 0, vc.err + } + return len(p), vc.err + } + if vc.err != nil { + return 0, vc.err } } return vc.ExtendedConn.Write(p) @@ -69,21 +74,25 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { select { case <-vc.handshake: default: - err := vc.sendRequest(buffer.Bytes()) - if err != nil { - return err + if vc.sendRequest(buffer.Bytes()) { + return vc.err + } + if vc.err != nil { + return vc.err } } return vc.ExtendedConn.WriteBuffer(buffer) } -func (vc *Conn) sendRequest(p []byte) (err error) { +func (vc *Conn) sendRequest(p []byte) bool { vc.handshakeMutex.Lock() defer vc.handshakeMutex.Unlock() select { case <-vc.handshake: - return vc.err + // The handshake has been completed before. + // So return false to remind the caller. + return false default: } defer close(vc.handshake) @@ -93,9 +102,9 @@ func (vc *Conn) sendRequest(p []byte) (err error) { requestLen += 1 // addons length var addonsBytes []byte if vc.addons != nil { - addonsBytes, err = proto.Marshal(vc.addons) - if err != nil { - return err + addonsBytes, vc.err = proto.Marshal(vc.addons) + if vc.err != nil { + return true } } requestLen += len(addonsBytes) @@ -137,8 +146,8 @@ func (vc *Conn) sendRequest(p []byte) (err error) { buf.Must(buf.Error(buffer.Write(p))) - _, err = vc.ExtendedConn.Write(buffer.Bytes()) - return + _, vc.err = vc.ExtendedConn.Write(buffer.Bytes()) + return true } func (vc *Conn) recvResponse() error { @@ -203,7 +212,7 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { select { case <-c.handshake: case <-time.After(200 * time.Millisecond): - _ = c.sendRequest(nil) + c.sendRequest(nil) } }() return c, nil