From e0b3b4515e22fd3495d09c23ba71e7835c8c5f64 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Mon, 27 Feb 2023 12:02:44 +0800 Subject: [PATCH] adjust: Simplify VLESS handshake lock --- transport/vless/conn.go | 67 ++++++++++++++++++--------------------- transport/vless/filter.go | 13 ++++---- 2 files changed, 37 insertions(+), 43 deletions(-) diff --git a/transport/vless/conn.go b/transport/vless/conn.go index 53dc7c85..c8477090 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -33,8 +33,8 @@ type Conn struct { addons *Addons received bool - handshake chan struct{} handshakeMutex sync.Mutex + needHandshake bool err error tlsConn net.Conn @@ -181,19 +181,25 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { } func (vc *Conn) Write(p []byte) (int, error) { - select { - case <-vc.handshake: - default: - if vc.sendRequest(p) { + if vc.needHandshake { + vc.handshakeMutex.Lock() + if vc.needHandshake { + vc.needHandshake = false + if vc.sendRequest(p) { + vc.handshakeMutex.Unlock() + if vc.err != nil { + return 0, vc.err + } + return len(p), vc.err + } if vc.err != nil { + vc.handshakeMutex.Unlock() return 0, vc.err } - return len(p), vc.err - } - if vc.err != nil { - return 0, vc.err } + vc.handshakeMutex.Unlock() } + if vc.writeFilterApplicationData { _buffer := buf.StackNew() defer buf.KeepAlive(_buffer) @@ -210,16 +216,22 @@ func (vc *Conn) Write(p []byte) (int, error) { } func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { - select { - case <-vc.handshake: - default: - if vc.sendRequest(buffer.Bytes()) { - return vc.err - } - if vc.err != nil { - return vc.err + if vc.needHandshake { + vc.handshakeMutex.Lock() + if vc.needHandshake { + vc.needHandshake = false + if vc.sendRequest(buffer.Bytes()) { + vc.handshakeMutex.Unlock() + return vc.err + } + if vc.err != nil { + vc.handshakeMutex.Unlock() + return vc.err + } } + vc.handshakeMutex.Unlock() } + if vc.writeFilterApplicationData { buffer2 := ReshapeBuffer(buffer) defer buffer2.Release() @@ -281,18 +293,6 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { } func (vc *Conn) sendRequest(p []byte) bool { - vc.handshakeMutex.Lock() - defer vc.handshakeMutex.Unlock() - - select { - case <-vc.handshake: - // The handshake has been completed before. - // So return false to remind the caller. - return false - default: - } - defer close(vc.handshake) - var addonsBytes []byte if vc.addons != nil { addonsBytes, vc.err = proto.Marshal(vc.addons) @@ -431,12 +431,7 @@ func (vc *Conn) Upstream() any { } func (vc *Conn) NeedHandshake() bool { - select { - case <-vc.handshake: - return false - default: - } - return true + return vc.needHandshake } func (vc *Conn) IsXTLSVisionEnabled() bool { @@ -451,7 +446,7 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { Conn: conn, id: client.uuid, dst: dst, - handshake: make(chan struct{}), + needHandshake: true, } if !dst.UDP && client.Addons != nil { diff --git a/transport/vless/filter.go b/transport/vless/filter.go index e16100d5..3ddfb8b9 100644 --- a/transport/vless/filter.go +++ b/transport/vless/filter.go @@ -33,14 +33,13 @@ func (vc *Conn) FilterTLS(buffer []byte) (index int) { } lenP := len(buffer) vc.packetsToFilter-- - if index := bytes.Index(buffer, tlsServerHandshakeStart); index != -1 { + if index = bytes.Index(buffer, tlsServerHandshakeStart); index != -1 { if lenP >= index+5 { if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 { vc.isTLS = true if buffer[5] == tlsHandshakeTypeServerHello { - log.Debugln("isTLS12orAbove") + //log.Debugln("isTLS12orAbove") vc.remainingServerHello = binary.BigEndian.Uint16(buffer[index+3:]) + 5 - vc.isTLS12orAbove = true if lenP-index >= 79 && vc.remainingServerHello >= 79 { sessionIDLen := int(buffer[index+43]) @@ -49,7 +48,7 @@ func (vc *Conn) FilterTLS(buffer []byte) (index int) { } } } - } else if index := bytes.Index(buffer, tlsClientHandshakeStart); index != -1 { + } else if index = bytes.Index(buffer, tlsClientHandshakeStart); index != -1 { if lenP >= index+5 && buffer[index+5] == tlsHandshakeTypeClientHello { vc.isTLS = true } @@ -74,15 +73,15 @@ func (vc *Conn) FilterTLS(buffer []byte) (index int) { if ok && cs != "TLS_AES_128_CCM_8_SHA256" { vc.enableXTLS = true } - log.Debugln("XTLS Vision found TLS 1.3, packetLength= %d CipherSuite= %s", lenP, cs) + log.Debugln("XTLS Vision found TLS 1.3, packetLength=%d, CipherSuite=%s", lenP, cs) vc.packetsToFilter = 0 return } else if vc.remainingServerHello <= 0 { - log.Debugln("XTLS Vision found TLS 1.2, packetLength= %d", lenP) + log.Debugln("XTLS Vision found TLS 1.2, packetLength=%d", lenP) vc.packetsToFilter = 0 return } - log.Debugln("XTLS Vision found inconclusive server hello, packetLength= %d,remainingServerHelloBytes= %d", lenP, vc.remainingServerHello) + log.Debugln("XTLS Vision found inconclusive server hello, packetLength=%d, remainingServerHelloBytes=%d", lenP, vc.remainingServerHello) } if vc.packetsToFilter <= 0 { log.Debugln("XTLS Vision stop filtering")