adjust: Simplify VLESS handshake lock
This commit is contained in:
parent
76ccebf099
commit
ecb2a5f3c6
2 changed files with 37 additions and 43 deletions
|
@ -33,8 +33,8 @@ type Conn struct {
|
||||||
addons *Addons
|
addons *Addons
|
||||||
received bool
|
received bool
|
||||||
|
|
||||||
handshake chan struct{}
|
|
||||||
handshakeMutex sync.Mutex
|
handshakeMutex sync.Mutex
|
||||||
|
needHandshake bool
|
||||||
err error
|
err error
|
||||||
|
|
||||||
tlsConn net.Conn
|
tlsConn net.Conn
|
||||||
|
@ -181,19 +181,25 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vc *Conn) Write(p []byte) (int, error) {
|
func (vc *Conn) Write(p []byte) (int, error) {
|
||||||
select {
|
if vc.needHandshake {
|
||||||
case <-vc.handshake:
|
vc.handshakeMutex.Lock()
|
||||||
default:
|
if vc.needHandshake {
|
||||||
|
vc.needHandshake = false
|
||||||
if vc.sendRequest(p) {
|
if vc.sendRequest(p) {
|
||||||
|
vc.handshakeMutex.Unlock()
|
||||||
if vc.err != nil {
|
if vc.err != nil {
|
||||||
return 0, vc.err
|
return 0, vc.err
|
||||||
}
|
}
|
||||||
return len(p), vc.err
|
return len(p), vc.err
|
||||||
}
|
}
|
||||||
if vc.err != nil {
|
if vc.err != nil {
|
||||||
|
vc.handshakeMutex.Unlock()
|
||||||
return 0, vc.err
|
return 0, vc.err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
vc.handshakeMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
if vc.writeFilterApplicationData {
|
if vc.writeFilterApplicationData {
|
||||||
_buffer := buf.StackNew()
|
_buffer := buf.StackNew()
|
||||||
defer buf.KeepAlive(_buffer)
|
defer buf.KeepAlive(_buffer)
|
||||||
|
@ -210,16 +216,22 @@ func (vc *Conn) Write(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
|
func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
|
||||||
select {
|
if vc.needHandshake {
|
||||||
case <-vc.handshake:
|
vc.handshakeMutex.Lock()
|
||||||
default:
|
if vc.needHandshake {
|
||||||
|
vc.needHandshake = false
|
||||||
if vc.sendRequest(buffer.Bytes()) {
|
if vc.sendRequest(buffer.Bytes()) {
|
||||||
|
vc.handshakeMutex.Unlock()
|
||||||
return vc.err
|
return vc.err
|
||||||
}
|
}
|
||||||
if vc.err != nil {
|
if vc.err != nil {
|
||||||
|
vc.handshakeMutex.Unlock()
|
||||||
return vc.err
|
return vc.err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
vc.handshakeMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
if vc.writeFilterApplicationData {
|
if vc.writeFilterApplicationData {
|
||||||
buffer2 := ReshapeBuffer(buffer)
|
buffer2 := ReshapeBuffer(buffer)
|
||||||
defer buffer2.Release()
|
defer buffer2.Release()
|
||||||
|
@ -281,18 +293,6 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vc *Conn) sendRequest(p []byte) bool {
|
func (vc *Conn) sendRequest(p []byte) bool {
|
||||||
vc.handshakeMutex.Lock()
|
|
||||||
defer vc.handshakeMutex.Unlock()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-vc.handshake:
|
|
||||||
// The handshake has been completed before.
|
|
||||||
// So return false to remind the caller.
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
defer close(vc.handshake)
|
|
||||||
|
|
||||||
var addonsBytes []byte
|
var addonsBytes []byte
|
||||||
if vc.addons != nil {
|
if vc.addons != nil {
|
||||||
addonsBytes, vc.err = proto.Marshal(vc.addons)
|
addonsBytes, vc.err = proto.Marshal(vc.addons)
|
||||||
|
@ -431,12 +431,7 @@ func (vc *Conn) Upstream() any {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vc *Conn) NeedHandshake() bool {
|
func (vc *Conn) NeedHandshake() bool {
|
||||||
select {
|
return vc.needHandshake
|
||||||
case <-vc.handshake:
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vc *Conn) IsXTLSVisionEnabled() bool {
|
func (vc *Conn) IsXTLSVisionEnabled() bool {
|
||||||
|
@ -451,7 +446,7 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
id: client.uuid,
|
id: client.uuid,
|
||||||
dst: dst,
|
dst: dst,
|
||||||
handshake: make(chan struct{}),
|
needHandshake: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if !dst.UDP && client.Addons != nil {
|
if !dst.UDP && client.Addons != nil {
|
||||||
|
|
|
@ -33,14 +33,13 @@ func (vc *Conn) FilterTLS(buffer []byte) (index int) {
|
||||||
}
|
}
|
||||||
lenP := len(buffer)
|
lenP := len(buffer)
|
||||||
vc.packetsToFilter--
|
vc.packetsToFilter--
|
||||||
if index := bytes.Index(buffer, tlsServerHandshakeStart); index != -1 {
|
if index = bytes.Index(buffer, tlsServerHandshakeStart); index != -1 {
|
||||||
if lenP >= index+5 {
|
if lenP >= index+5 {
|
||||||
if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
|
if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
|
||||||
vc.isTLS = true
|
vc.isTLS = true
|
||||||
if buffer[5] == tlsHandshakeTypeServerHello {
|
if buffer[5] == tlsHandshakeTypeServerHello {
|
||||||
log.Debugln("isTLS12orAbove")
|
//log.Debugln("isTLS12orAbove")
|
||||||
vc.remainingServerHello = binary.BigEndian.Uint16(buffer[index+3:]) + 5
|
vc.remainingServerHello = binary.BigEndian.Uint16(buffer[index+3:]) + 5
|
||||||
|
|
||||||
vc.isTLS12orAbove = true
|
vc.isTLS12orAbove = true
|
||||||
if lenP-index >= 79 && vc.remainingServerHello >= 79 {
|
if lenP-index >= 79 && vc.remainingServerHello >= 79 {
|
||||||
sessionIDLen := int(buffer[index+43])
|
sessionIDLen := int(buffer[index+43])
|
||||||
|
@ -49,7 +48,7 @@ func (vc *Conn) FilterTLS(buffer []byte) (index int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if index := bytes.Index(buffer, tlsClientHandshakeStart); index != -1 {
|
} else if index = bytes.Index(buffer, tlsClientHandshakeStart); index != -1 {
|
||||||
if lenP >= index+5 && buffer[index+5] == tlsHandshakeTypeClientHello {
|
if lenP >= index+5 && buffer[index+5] == tlsHandshakeTypeClientHello {
|
||||||
vc.isTLS = true
|
vc.isTLS = true
|
||||||
}
|
}
|
||||||
|
@ -74,15 +73,15 @@ func (vc *Conn) FilterTLS(buffer []byte) (index int) {
|
||||||
if ok && cs != "TLS_AES_128_CCM_8_SHA256" {
|
if ok && cs != "TLS_AES_128_CCM_8_SHA256" {
|
||||||
vc.enableXTLS = true
|
vc.enableXTLS = true
|
||||||
}
|
}
|
||||||
log.Debugln("XTLS Vision found TLS 1.3, packetLength= %d CipherSuite= %s", lenP, cs)
|
log.Debugln("XTLS Vision found TLS 1.3, packetLength=%d, CipherSuite=%s", lenP, cs)
|
||||||
vc.packetsToFilter = 0
|
vc.packetsToFilter = 0
|
||||||
return
|
return
|
||||||
} else if vc.remainingServerHello <= 0 {
|
} else if vc.remainingServerHello <= 0 {
|
||||||
log.Debugln("XTLS Vision found TLS 1.2, packetLength= %d", lenP)
|
log.Debugln("XTLS Vision found TLS 1.2, packetLength=%d", lenP)
|
||||||
vc.packetsToFilter = 0
|
vc.packetsToFilter = 0
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugln("XTLS Vision found inconclusive server hello, packetLength= %d,remainingServerHelloBytes= %d", lenP, vc.remainingServerHello)
|
log.Debugln("XTLS Vision found inconclusive server hello, packetLength=%d, remainingServerHelloBytes=%d", lenP, vc.remainingServerHello)
|
||||||
}
|
}
|
||||||
if vc.packetsToFilter <= 0 {
|
if vc.packetsToFilter <= 0 {
|
||||||
log.Debugln("XTLS Vision stop filtering")
|
log.Debugln("XTLS Vision stop filtering")
|
||||||
|
|
Loading…
Reference in a new issue