fix: Vision filter TLS 1.2

Add magic from sing-box. 5ce3ddee9b/transport/vless/vision.go (L199)
This commit is contained in:
gVisor bot 2023-02-27 01:06:41 +08:00
parent e45b8dc404
commit 6b4e46af72
2 changed files with 37 additions and 28 deletions

View file

@ -248,7 +248,7 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
if vc.writeDirect { if vc.writeDirect {
vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn)
log.Debugln("XTLS Vision direct write start") log.Debugln("XTLS Vision direct write start")
//time.Sleep(10 * time.Millisecond) //time.Sleep(5 * time.Millisecond)
} }
if buffer2 != nil { if buffer2 != nil {
if vc.writeDirect || !vc.isTLS { if vc.writeDirect || !vc.isTLS {
@ -393,22 +393,22 @@ func (vc *Conn) sendRequest(p []byte) bool {
} }
func (vc *Conn) recvResponse() error { func (vc *Conn) recvResponse() error {
var buf [1]byte var buffer [1]byte
_, vc.err = io.ReadFull(vc.ExtendedReader, buf[:]) _, vc.err = io.ReadFull(vc.ExtendedReader, buffer[:])
if vc.err != nil { if vc.err != nil {
return vc.err return vc.err
} }
if buf[0] != Version { if buffer[0] != Version {
return errors.New("unexpected response version") return errors.New("unexpected response version")
} }
_, vc.err = io.ReadFull(vc.ExtendedReader, buf[:]) _, vc.err = io.ReadFull(vc.ExtendedReader, buffer[:])
if vc.err != nil { if vc.err != nil {
return vc.err return vc.err
} }
length := int64(buf[0]) length := int64(buffer[0])
if length != 0 { // addon data length > 0 if length != 0 { // addon data length > 0
io.CopyN(io.Discard, vc.ExtendedReader, length) // just discard io.CopyN(io.Discard, vc.ExtendedReader, length) // just discard
} }
@ -482,19 +482,23 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
var p uintptr var p uintptr
switch underlying := conn.(type) { switch underlying := conn.(type) {
case *gotls.Conn: case *gotls.Conn:
//log.Debugln("type tls")
c.Conn = underlying.NetConn() c.Conn = underlying.NetConn()
c.tlsConn = underlying c.tlsConn = underlying
t = reflect.TypeOf(underlying).Elem() t = reflect.TypeOf(underlying).Elem()
p = uintptr(unsafe.Pointer(underlying)) p = uintptr(unsafe.Pointer(underlying))
case *utls.UConn: case *utls.UConn:
//log.Debugln("type *utls.UConn")
c.Conn = underlying.NetConn() c.Conn = underlying.NetConn()
c.tlsConn = underlying c.tlsConn = underlying
t = reflect.TypeOf(underlying.Conn).Elem() t = reflect.TypeOf(underlying.Conn).Elem()
p = uintptr(unsafe.Pointer(underlying.Conn)) p = uintptr(unsafe.Pointer(underlying.Conn))
case *tlsC.UConn: case *tlsC.UConn:
//log.Debugln("type *tlsC.UConn")
c.Conn = underlying.NetConn() c.Conn = underlying.NetConn()
c.tlsConn = underlying.UConn c.tlsConn = underlying.UConn
t = reflect.TypeOf(underlying.Conn).Elem() t = reflect.TypeOf(underlying.Conn).Elem()
//log.Debugln("t:%v", t)
p = uintptr(unsafe.Pointer(underlying.Conn)) p = uintptr(unsafe.Pointer(underlying.Conn))
default: default:
return nil, fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, client.Addons.Flow) return nil, fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, client.Addons.Flow)
@ -503,9 +507,9 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
r, _ := t.FieldByName("rawInput") r, _ := t.FieldByName("rawInput")
c.input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset)) c.input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset))
c.rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset)) c.rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset))
// if _, ok := c.Conn.(*net.TCPConn); !ok { //if _, ok := c.Conn.(*net.TCPConn); !ok {
// log.Debugln("XTLS underlying conn is not *net.TCPConn, got %T", c.Conn) // log.Debugln("XTLS underlying conn is not *net.TCPConn, got %T", c.Conn)
// } //}
} }
} }

View file

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
log "github.com/sirupsen/logrus" "github.com/Dreamacro/clash/log"
) )
var ( var (
@ -27,24 +27,30 @@ const (
tlsHandshakeTypeServerHello byte = 0x02 tlsHandshakeTypeServerHello byte = 0x02
) )
func (vc *Conn) FilterTLS(p []byte) (index int) { func (vc *Conn) FilterTLS(buffer []byte) (index int) {
if vc.packetsToFilter <= 0 { if vc.packetsToFilter <= 0 {
return 0 return 0
} }
lenP := len(p) lenP := len(buffer)
vc.packetsToFilter -= 1 vc.packetsToFilter--
if index = bytes.Index(p, tlsServerHandshakeStart); index != -1 { if index := bytes.Index(buffer, tlsServerHandshakeStart); index != -1 {
if lenP >= index+5 && p[index+5] == tlsHandshakeTypeServerHello { if lenP >= index+5 {
vc.remainingServerHello = binary.BigEndian.Uint16(p[index+3:]) + 5 if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
vc.isTLS = true vc.isTLS = true
vc.isTLS12orAbove = true if buffer[5] == tlsHandshakeTypeServerHello {
if lenP-index >= 79 && vc.remainingServerHello >= 79 { log.Debugln("isTLS12orAbove")
sessionIDLen := int(p[index+43]) vc.remainingServerHello = binary.BigEndian.Uint16(buffer[index+3:]) + 5
vc.cipher = binary.BigEndian.Uint16(p[index+43+sessionIDLen+1:])
vc.isTLS12orAbove = true
if lenP-index >= 79 && vc.remainingServerHello >= 79 {
sessionIDLen := int(buffer[index+43])
vc.cipher = binary.BigEndian.Uint16(buffer[index+43+sessionIDLen+1:])
}
}
} }
} }
} else if index = bytes.Index(p, tlsClientHandshakeStart); index != -1 { } else if index := bytes.Index(buffer, tlsClientHandshakeStart); index != -1 {
if lenP >= index+5 && p[index+5] == tlsHandshakeTypeClientHello { if lenP >= index+5 && buffer[index+5] == tlsHandshakeTypeClientHello {
vc.isTLS = true vc.isTLS = true
} }
} }
@ -62,22 +68,21 @@ func (vc *Conn) FilterTLS(p []byte) (index int) {
vc.remainingServerHello -= uint16(end) vc.remainingServerHello -= uint16(end)
end += i end += i
} }
if bytes.Contains(p[i:end], tls13SupportedVersions) { if bytes.Contains(buffer[i:end], tls13SupportedVersions) {
// TLS 1.3 Client Hello // TLS 1.3 Client Hello
cs, ok := tls13CipherSuiteMap[vc.cipher] cs, ok := tls13CipherSuiteMap[vc.cipher]
if ok && cs != "TLS_AES_128_CCM_8_SHA256" { if ok && cs != "TLS_AES_128_CCM_8_SHA256" {
vc.enableXTLS = true vc.enableXTLS = true
} }
log.Debugln("XTLS Vision found TLS 1.3, packetLength=", lenP, ", CipherSuite=", cs) log.Debugln("XTLS Vision found TLS 1.3, packetLength= %d CipherSuite= %s", lenP, cs)
vc.packetsToFilter = 0 vc.packetsToFilter = 0
return return
} else if vc.remainingServerHello <= 0 { } else if vc.remainingServerHello <= 0 {
log.Debugln("XTLS Vision found TLS 1.2, packetLength=", lenP) log.Debugln("XTLS Vision found TLS 1.2, packetLength= %d", lenP)
vc.packetsToFilter = 0 vc.packetsToFilter = 0
return return
} }
log.Debugln("XTLS Vision found inconclusive server hello, packetLength=", lenP, log.Debugln("XTLS Vision found inconclusive server hello, packetLength= %d,remainingServerHelloBytes= %d", lenP, vc.remainingServerHello)
", remainingServerHelloBytes=", vc.remainingServerHello)
} }
if vc.packetsToFilter <= 0 { if vc.packetsToFilter <= 0 {
log.Debugln("XTLS Vision stop filtering") log.Debugln("XTLS Vision stop filtering")