adjust: Simplify VLESS handshake lock

This commit is contained in:
Hellojack 2023-02-27 12:02:44 +08:00 committed by GitHub
parent 76ccebf099
commit ecb2a5f3c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 43 deletions

View file

@ -33,8 +33,8 @@ type Conn struct {
addons *Addons addons *Addons
received bool received bool
handshake chan struct{}
handshakeMutex sync.Mutex handshakeMutex sync.Mutex
needHandshake bool
err error err error
tlsConn net.Conn tlsConn net.Conn
@ -181,19 +181,25 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
} }
func (vc *Conn) Write(p []byte) (int, error) { func (vc *Conn) Write(p []byte) (int, error) {
select { if vc.needHandshake {
case <-vc.handshake: vc.handshakeMutex.Lock()
default: if vc.needHandshake {
vc.needHandshake = false
if vc.sendRequest(p) { if vc.sendRequest(p) {
vc.handshakeMutex.Unlock()
if vc.err != nil { if vc.err != nil {
return 0, vc.err return 0, vc.err
} }
return len(p), vc.err return len(p), vc.err
} }
if vc.err != nil { if vc.err != nil {
vc.handshakeMutex.Unlock()
return 0, vc.err return 0, vc.err
} }
} }
vc.handshakeMutex.Unlock()
}
if vc.writeFilterApplicationData { if vc.writeFilterApplicationData {
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer buf.KeepAlive(_buffer) defer buf.KeepAlive(_buffer)
@ -210,16 +216,22 @@ func (vc *Conn) Write(p []byte) (int, error) {
} }
func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
select { if vc.needHandshake {
case <-vc.handshake: vc.handshakeMutex.Lock()
default: if vc.needHandshake {
vc.needHandshake = false
if vc.sendRequest(buffer.Bytes()) { if vc.sendRequest(buffer.Bytes()) {
vc.handshakeMutex.Unlock()
return vc.err return vc.err
} }
if vc.err != nil { if vc.err != nil {
vc.handshakeMutex.Unlock()
return vc.err return vc.err
} }
} }
vc.handshakeMutex.Unlock()
}
if vc.writeFilterApplicationData { if vc.writeFilterApplicationData {
buffer2 := ReshapeBuffer(buffer) buffer2 := ReshapeBuffer(buffer)
defer buffer2.Release() defer buffer2.Release()
@ -281,18 +293,6 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
} }
func (vc *Conn) sendRequest(p []byte) bool { func (vc *Conn) sendRequest(p []byte) bool {
vc.handshakeMutex.Lock()
defer vc.handshakeMutex.Unlock()
select {
case <-vc.handshake:
// The handshake has been completed before.
// So return false to remind the caller.
return false
default:
}
defer close(vc.handshake)
var addonsBytes []byte var addonsBytes []byte
if vc.addons != nil { if vc.addons != nil {
addonsBytes, vc.err = proto.Marshal(vc.addons) addonsBytes, vc.err = proto.Marshal(vc.addons)
@ -431,12 +431,7 @@ func (vc *Conn) Upstream() any {
} }
func (vc *Conn) NeedHandshake() bool { func (vc *Conn) NeedHandshake() bool {
select { return vc.needHandshake
case <-vc.handshake:
return false
default:
}
return true
} }
func (vc *Conn) IsXTLSVisionEnabled() bool { func (vc *Conn) IsXTLSVisionEnabled() bool {
@ -451,7 +446,7 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
Conn: conn, Conn: conn,
id: client.uuid, id: client.uuid,
dst: dst, dst: dst,
handshake: make(chan struct{}), needHandshake: true,
} }
if !dst.UDP && client.Addons != nil { if !dst.UDP && client.Addons != nil {

View file

@ -33,14 +33,13 @@ func (vc *Conn) FilterTLS(buffer []byte) (index int) {
} }
lenP := len(buffer) lenP := len(buffer)
vc.packetsToFilter-- vc.packetsToFilter--
if index := bytes.Index(buffer, tlsServerHandshakeStart); index != -1 { if index = bytes.Index(buffer, tlsServerHandshakeStart); index != -1 {
if lenP >= index+5 { if lenP >= index+5 {
if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 { if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
vc.isTLS = true vc.isTLS = true
if buffer[5] == tlsHandshakeTypeServerHello { if buffer[5] == tlsHandshakeTypeServerHello {
log.Debugln("isTLS12orAbove") //log.Debugln("isTLS12orAbove")
vc.remainingServerHello = binary.BigEndian.Uint16(buffer[index+3:]) + 5 vc.remainingServerHello = binary.BigEndian.Uint16(buffer[index+3:]) + 5
vc.isTLS12orAbove = true vc.isTLS12orAbove = true
if lenP-index >= 79 && vc.remainingServerHello >= 79 { if lenP-index >= 79 && vc.remainingServerHello >= 79 {
sessionIDLen := int(buffer[index+43]) sessionIDLen := int(buffer[index+43])
@ -49,7 +48,7 @@ func (vc *Conn) FilterTLS(buffer []byte) (index int) {
} }
} }
} }
} else if index := bytes.Index(buffer, tlsClientHandshakeStart); index != -1 { } else if index = bytes.Index(buffer, tlsClientHandshakeStart); index != -1 {
if lenP >= index+5 && buffer[index+5] == tlsHandshakeTypeClientHello { if lenP >= index+5 && buffer[index+5] == tlsHandshakeTypeClientHello {
vc.isTLS = true vc.isTLS = true
} }
@ -74,15 +73,15 @@ func (vc *Conn) FilterTLS(buffer []byte) (index int) {
if ok && cs != "TLS_AES_128_CCM_8_SHA256" { if ok && cs != "TLS_AES_128_CCM_8_SHA256" {
vc.enableXTLS = true vc.enableXTLS = true
} }
log.Debugln("XTLS Vision found TLS 1.3, packetLength= %d CipherSuite= %s", lenP, cs) log.Debugln("XTLS Vision found TLS 1.3, packetLength=%d CipherSuite=%s", lenP, cs)
vc.packetsToFilter = 0 vc.packetsToFilter = 0
return return
} else if vc.remainingServerHello <= 0 { } else if vc.remainingServerHello <= 0 {
log.Debugln("XTLS Vision found TLS 1.2, packetLength= %d", lenP) log.Debugln("XTLS Vision found TLS 1.2, packetLength=%d", lenP)
vc.packetsToFilter = 0 vc.packetsToFilter = 0
return return
} }
log.Debugln("XTLS Vision found inconclusive server hello, packetLength= %d,remainingServerHelloBytes= %d", lenP, vc.remainingServerHello) log.Debugln("XTLS Vision found inconclusive server hello, packetLength=%d, remainingServerHelloBytes=%d", lenP, vc.remainingServerHello)
} }
if vc.packetsToFilter <= 0 { if vc.packetsToFilter <= 0 {
log.Debugln("XTLS Vision stop filtering") log.Debugln("XTLS Vision stop filtering")