From aaf1ce67d1a2278962f6199a5174754d7919a11c Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Sat, 25 Feb 2023 13:12:19 +0800 Subject: [PATCH] feat: Support VLESS XTLS Vision (#406) --- adapter/outbound/vless.go | 8 +- common/buf/sing.go | 6 + transport/vless/conn.go | 352 +++++++++++++++++++++++++++++++++----- transport/vless/filter.go | 79 +++++++++ transport/vless/vision.go | 69 ++++++++ transport/vless/vless.go | 1 + transport/vless/xtls.go | 5 + 7 files changed, 477 insertions(+), 43 deletions(-) create mode 100644 transport/vless/filter.go create mode 100644 transport/vless/vision.go diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index eef05687..010af23c 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -171,7 +171,7 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error) { host, _, _ := net.SplitHostPort(v.addr) - if v.isXTLSEnabled() && !isH2 { + if v.isLegacyXTLSEnabled() && !isH2 { xtlsOpts := vless.XTLSConfig{ Host: host, SkipCertVerify: v.option.SkipCertVerify, @@ -206,8 +206,8 @@ func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error) return conn, nil } -func (v *Vless) isXTLSEnabled() bool { - return v.client.Addons != nil +func (v *Vless) isLegacyXTLSEnabled() bool { + return v.client.Addons != nil && v.client.Addons.Flow != vless.XRV } // DialContext implements C.ProxyAdapter @@ -479,7 +479,7 @@ func NewVless(option VlessOption) (*Vless, error) { if option.Network != "ws" && len(option.Flow) >= 16 { option.Flow = option.Flow[:16] switch option.Flow { - case vless.XRO, vless.XRD, vless.XRS: + case vless.XRO, vless.XRD, vless.XRS, vless.XRV: addons = &vless.Addons{ Flow: option.Flow, } diff --git a/common/buf/sing.go b/common/buf/sing.go index b5e015f5..f86b5755 100644 --- a/common/buf/sing.go +++ b/common/buf/sing.go @@ -5,9 +5,15 @@ import ( "github.com/sagernet/sing/common/buf" ) +const BufferSize = buf.BufferSize + type Buffer = buf.Buffer +var New = buf.New +var StackNew = buf.StackNew var StackNewSize = buf.StackNewSize +var With = buf.With + var KeepAlive = common.KeepAlive //go:norace diff --git a/transport/vless/conn.go b/transport/vless/conn.go index e063d465..eae8868e 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -1,23 +1,33 @@ package vless import ( + "bytes" + "crypto/subtle" + gotls "crypto/tls" "encoding/binary" "errors" "fmt" "io" "net" + "reflect" "sync" + "unsafe" "github.com/Dreamacro/clash/common/buf" N "github.com/Dreamacro/clash/common/net" + tlsC "github.com/Dreamacro/clash/component/tls" + "github.com/Dreamacro/clash/log" "github.com/gofrs/uuid" + utls "github.com/sagernet/utls" xtls "github.com/xtls/go" "google.golang.org/protobuf/proto" ) type Conn struct { - N.ExtendedConn + N.ExtendedWriter + N.ExtendedReader + net.Conn dst *DstAddr id *uuid.UUID addons *Addons @@ -26,30 +36,143 @@ type Conn struct { handshake chan struct{} handshakeMutex sync.Mutex err error + + tlsConn net.Conn + input *bytes.Reader + rawInput *bytes.Buffer + + packetsToFilter int + isTLS bool + isTLS12orAbove bool + enableXTLS bool + cipher uint16 + remainingServerHello uint16 + readRemainingContent int + readRemainingPadding int + readProcess bool + readFilterUUID bool + readLastCommand byte + writeFilterApplicationData bool + writeDirect bool } func (vc *Conn) Read(b []byte) (int, error) { if vc.received { - return vc.ExtendedConn.Read(b) + if vc.readProcess { + buffer := buf.With(b) + err := vc.ReadBuffer(buffer) + return buffer.Len(), err + } + return vc.ExtendedReader.Read(b) } if err := vc.recvResponse(); err != nil { return 0, err } vc.received = true - return vc.ExtendedConn.Read(b) + return vc.Read(b) } func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { if vc.received { - return vc.ExtendedConn.ReadBuffer(buffer) + toRead := buffer.FreeBytes() + if vc.readRemainingContent > 0 { + if vc.readRemainingContent < buffer.FreeLen() { + toRead = toRead[:vc.readRemainingContent] + } + n, err := vc.ExtendedReader.Read(toRead) + buffer.Truncate(n) + vc.readRemainingContent -= n + vc.FilterTLS(toRead) + return err + } + if vc.readRemainingPadding > 0 { + _, err := io.CopyN(io.Discard, vc.ExtendedReader, int64(vc.readRemainingPadding)) + if err != nil { + return err + } + vc.readRemainingPadding = 0 + } + if vc.readProcess { + switch vc.readLastCommand { + case commandPaddingContinue: + //if vc.isTLS || vc.packetsToFilter > 0 { + headerUUIDLen := 0 + if vc.readFilterUUID { + headerUUIDLen = uuid.Size + } + var header []byte + if need := headerUUIDLen + paddingHeaderLen; buffer.FreeLen() < need { + header = make([]byte, need) + } else { + header = buffer.FreeBytes()[:need] + } + _, err := io.ReadFull(vc.ExtendedReader, header) + if err != nil { + return err + } + pos := 0 + if vc.readFilterUUID { + vc.readFilterUUID = false + pos = uuid.Size + if subtle.ConstantTimeCompare(vc.id.Bytes(), header[:uuid.Size]) != 1 { + err = fmt.Errorf("XTLS Vision server responded unknown UUID: %s", + uuid.FromBytesOrNil(header[:uuid.Size]).String()) + log.Errorln(err.Error()) + return err + } + } + vc.readLastCommand = header[pos] + vc.readRemainingContent = int(binary.BigEndian.Uint16(header[pos+1:])) + vc.readRemainingPadding = int(binary.BigEndian.Uint16(header[pos+3:])) + log.Debugln("XTLS Vision read padding: command=%d, payloadLen=%d, paddingLen=%d", + vc.readLastCommand, vc.readRemainingContent, vc.readRemainingPadding) + return vc.ReadBuffer(buffer) + //} + case commandPaddingEnd: + vc.readProcess = false + return vc.ReadBuffer(buffer) + case commandPaddingDirect: + if vc.input != nil { + _, err := buffer.ReadFrom(vc.input) + if err != nil { + return err + } + if vc.input.Len() == 0 { + vc.input = nil + } + if buffer.IsFull() { + return nil + } + } + if vc.rawInput != nil { + _, err := buffer.ReadFrom(vc.rawInput) + if err != nil { + return err + } + if vc.rawInput.Len() == 0 { + vc.rawInput = nil + } + } + if vc.input == nil && vc.rawInput == nil { + vc.readProcess = false + vc.ExtendedReader = N.NewExtendedReader(vc.Conn) + log.Debugln("XTLS Vision direct read start") + } + default: + err := fmt.Errorf("XTLS Vision read unknown command: %d", vc.readLastCommand) + log.Debugln(err.Error()) + return err + } + } + return vc.ExtendedReader.ReadBuffer(buffer) } if err := vc.recvResponse(); err != nil { return err } vc.received = true - return vc.ExtendedConn.ReadBuffer(buffer) + return vc.ReadBuffer(buffer) } func (vc *Conn) Write(p []byte) (int, error) { @@ -66,7 +189,19 @@ func (vc *Conn) Write(p []byte) (int, error) { return 0, vc.err } } - return vc.ExtendedConn.Write(p) + if vc.writeFilterApplicationData { + _buffer := buf.StackNew() + defer buf.KeepAlive(_buffer) + buffer := buf.Dup(_buffer) + defer buffer.Release() + buffer.Write(p) + err := vc.WriteBuffer(buffer) + if err != nil { + return 0, err + } + return len(p), nil + } + return vc.ExtendedWriter.Write(p) } func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { @@ -80,7 +215,57 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { return vc.err } } - return vc.ExtendedConn.WriteBuffer(buffer) + if vc.writeFilterApplicationData && vc.isTLS { + buffer2 := ReshapeBuffer(buffer) + defer buffer2.Release() + vc.FilterTLS(buffer.Bytes()) + command := commandPaddingContinue + if buffer.Len() > 6 && bytes.Equal(buffer.To(3), tlsApplicationDataStart) || vc.packetsToFilter <= 0 { + command = commandPaddingEnd + if vc.enableXTLS { + command = commandPaddingDirect + vc.writeDirect = true + } + vc.writeFilterApplicationData = false + } + ApplyPadding(buffer, command, nil) + err := vc.ExtendedWriter.WriteBuffer(buffer) + if err != nil { + return err + } + if vc.writeDirect { + vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) + log.Debugln("XTLS Vision direct write start") + //time.Sleep(10 * time.Millisecond) + } + if buffer2 != nil { + if vc.writeDirect { + return vc.ExtendedWriter.WriteBuffer(buffer2) + } + vc.FilterTLS(buffer2.Bytes()) + command = commandPaddingContinue + if buffer2.Len() > 6 && bytes.Equal(buffer2.To(3), tlsApplicationDataStart) || vc.packetsToFilter <= 0 { + command = commandPaddingEnd + if vc.enableXTLS { + command = commandPaddingDirect + vc.writeDirect = true + } + vc.writeFilterApplicationData = false + } + ApplyPadding(buffer2, command, nil) + err = vc.ExtendedWriter.WriteBuffer(buffer2) + if vc.writeDirect { + vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) + log.Debugln("XTLS Vision direct write start") + //time.Sleep(10 * time.Millisecond) + } + } + return err + } + /*if vc.writeDirect { + log.Debugln("XTLS Vision Direct write, payloadLen=%d", buffer.Len()) + }*/ + return vc.ExtendedWriter.WriteBuffer(buffer) } func (vc *Conn) sendRequest(p []byte) bool { @@ -96,9 +281,6 @@ func (vc *Conn) sendRequest(p []byte) bool { } defer close(vc.handshake) - requestLen := 1 // protocol version - requestLen += 16 // UUID - requestLen += 1 // addons length var addonsBytes []byte if vc.addons != nil { addonsBytes, vc.err = proto.Marshal(vc.addons) @@ -106,19 +288,32 @@ func (vc *Conn) sendRequest(p []byte) bool { return true } } - requestLen += len(addonsBytes) - requestLen += 1 // command - if !vc.dst.Mux { - requestLen += 2 // port - requestLen += 1 // addr type - requestLen += len(vc.dst.Addr) - } - requestLen += len(p) + isVision := vc.IsXTLSVisionEnabled() - _buffer := buf.StackNewSize(requestLen) - defer buf.KeepAlive(_buffer) - buffer := buf.Dup(_buffer) - defer buffer.Release() + var buffer *buf.Buffer + if isVision { + _buffer := buf.StackNew() + defer buf.KeepAlive(_buffer) + buffer = buf.Dup(_buffer) + defer buffer.Release() + } else { + requestLen := 1 // protocol version + requestLen += 16 // UUID + requestLen += 1 // addons length + requestLen += len(addonsBytes) + requestLen += 1 // command + if !vc.dst.Mux { + requestLen += 2 // port + requestLen += 1 // addr type + requestLen += len(vc.dst.Addr) + } + requestLen += len(p) + + _buffer := buf.StackNewSize(requestLen) + defer buf.KeepAlive(_buffer) + buffer = buf.Dup(_buffer) + defer buffer.Release() + } buf.Must( buffer.WriteByte(Version), // protocol version @@ -143,15 +338,51 @@ func (vc *Conn) sendRequest(p []byte) bool { ) } - buf.Must(buf.Error(buffer.Write(p))) + if isVision && !vc.dst.UDP && !vc.dst.Mux { + if len(p) == 0 { + vc.packetsToFilter = 0 + vc.writeFilterApplicationData = false + WriteWithPadding(buffer, nil, commandPaddingEnd, vc.id) + } else { + vc.FilterTLS(p) + if vc.isTLS { + WriteWithPadding(buffer, p, commandPaddingContinue, vc.id) + } else { + buf.Must(buf.Error(buffer.Write(p))) + vc.readProcess = false + vc.writeFilterApplicationData = false + vc.packetsToFilter = 0 + } + } + } else { + buf.Must(buf.Error(buffer.Write(p))) + } - _, vc.err = vc.ExtendedConn.Write(buffer.Bytes()) + _, vc.err = vc.ExtendedWriter.Write(buffer.Bytes()) + if vc.err != nil { + return true + } + if isVision { + switch underlying := vc.tlsConn.(type) { + case *gotls.Conn: + if underlying.ConnectionState().Version != gotls.VersionTLS13 { + vc.err = ErrNotTLS13 + } + case *utls.UConn: + if underlying.ConnectionState().Version != utls.VersionTLS13 { + vc.err = ErrNotTLS13 + } + default: + vc.err = fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, vc.addons.Flow) + } + vc.tlsConn = nil + } return true } func (vc *Conn) recvResponse() error { var buf [1]byte - _, vc.err = io.ReadFull(vc.ExtendedConn, buf[:]) + _, vc.err = io.ReadFull(vc.ExtendedReader, buf[:]) if vc.err != nil { return vc.err } @@ -160,30 +391,46 @@ func (vc *Conn) recvResponse() error { return errors.New("unexpected response version") } - _, vc.err = io.ReadFull(vc.ExtendedConn, buf[:]) + _, vc.err = io.ReadFull(vc.ExtendedReader, buf[:]) if vc.err != nil { return vc.err } length := int64(buf[0]) if length != 0 { // addon data length > 0 - io.CopyN(io.Discard, vc.ExtendedConn, length) // just discard + io.CopyN(io.Discard, vc.ExtendedReader, length) // just discard } return nil } +func (vc *Conn) FrontHeadroom() int { + if vc.IsXTLSVisionEnabled() { + return paddingHeaderLen + } + return 0 +} + func (vc *Conn) Upstream() any { - return vc.ExtendedConn + if vc.tlsConn == nil { + return vc.Conn + } + return vc.tlsConn +} + +func (vc *Conn) IsXTLSVisionEnabled() bool { + return vc.addons != nil && vc.addons.Flow == XRV } // newConn return a Conn instance func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { c := &Conn{ - ExtendedConn: N.NewExtendedConn(conn), - id: client.uuid, - dst: dst, - handshake: make(chan struct{}), + ExtendedReader: N.NewExtendedReader(conn), + ExtendedWriter: N.NewExtendedWriter(conn), + Conn: conn, + id: client.uuid, + dst: dst, + handshake: make(chan struct{}), } if !dst.UDP && client.Addons != nil { @@ -204,15 +451,42 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { } else { return nil, fmt.Errorf("failed to use %s, maybe \"security\" is not \"xtls\"", client.Addons.Flow) } + case XRV: + c.packetsToFilter = 6 + c.readProcess = true + c.readFilterUUID = true + c.writeFilterApplicationData = true + c.addons = client.Addons + var t reflect.Type + var p uintptr + switch underlying := conn.(type) { + case *gotls.Conn: + c.Conn = underlying.NetConn() + c.tlsConn = underlying + t = reflect.TypeOf(underlying).Elem() + p = uintptr(unsafe.Pointer(underlying)) + case *utls.UConn: + c.Conn = underlying.NetConn() + c.tlsConn = underlying + t = reflect.TypeOf(underlying.Conn).Elem() + p = uintptr(unsafe.Pointer(underlying.Conn)) + case *tlsC.UConn: + c.Conn = underlying.NetConn() + c.tlsConn = underlying.UConn + t = reflect.TypeOf(underlying.Conn).Elem() + p = uintptr(unsafe.Pointer(underlying.Conn)) + default: + return nil, fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, client.Addons.Flow) + } + i, _ := t.FieldByName("input") + r, _ := t.FieldByName("rawInput") + c.input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset)) + c.rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset)) + if _, ok := c.Conn.(*net.TCPConn); !ok { + log.Debugln("XTLS underlying conn is not *net.TCPConn, got %s", reflect.TypeOf(conn).Name()) + } } } - //go func() { - // select { - // case <-c.handshake: - // case <-time.After(200 * time.Millisecond): - // c.sendRequest(nil) - // } - //}() return c, nil } diff --git a/transport/vless/filter.go b/transport/vless/filter.go new file mode 100644 index 00000000..15d595bc --- /dev/null +++ b/transport/vless/filter.go @@ -0,0 +1,79 @@ +package vless + +import ( + "bytes" + "encoding/binary" + + log "github.com/sirupsen/logrus" +) + +var ( + tls13SupportedVersions = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04} + tlsClientHandshakeStart = []byte{0x16, 0x03} + tlsServerHandshakeStart = []byte{0x16, 0x03, 0x03} + tlsApplicationDataStart = []byte{0x17, 0x03, 0x03} + + tls13CipherSuiteMap = map[uint16]string{ + 0x1301: "TLS_AES_128_GCM_SHA256", + 0x1302: "TLS_AES_256_GCM_SHA384", + 0x1303: "TLS_CHACHA20_POLY1305_SHA256", + 0x1304: "TLS_AES_128_CCM_SHA256", + 0x1305: "TLS_AES_128_CCM_8_SHA256", + } +) + +const ( + tlsHandshakeTypeClientHello byte = 0x01 + tlsHandshakeTypeServerHello byte = 0x02 +) + +func (vc *Conn) FilterTLS(p []byte) (index int) { + if vc.packetsToFilter <= 0 { + return 0 + } + lenP := len(p) + vc.packetsToFilter -= 1 + if index = bytes.Index(p, tlsServerHandshakeStart); index != -1 { + if lenP >= index+5 && p[index+5] == tlsHandshakeTypeServerHello { + vc.remainingServerHello = binary.BigEndian.Uint16(p[index+3:]) + 5 + vc.isTLS = true + vc.isTLS12orAbove = true + if lenP-index >= 79 && vc.remainingServerHello >= 79 { + sessionIDLen := int(p[index+43]) + vc.cipher = binary.BigEndian.Uint16(p[index+43+sessionIDLen+1:]) + } + } + } else if index = bytes.Index(p, tlsClientHandshakeStart); index != -1 { + if lenP >= index+5 && p[index+5] == tlsHandshakeTypeClientHello { + vc.isTLS = true + } + } + + if vc.remainingServerHello > 0 { + end := vc.remainingServerHello + vc.remainingServerHello -= end + if end > uint16(lenP) { + end = uint16(lenP) + } + if bytes.Contains(p[index:end], tls13SupportedVersions) { + // TLS 1.3 Client Hello + cs, ok := tls13CipherSuiteMap[vc.cipher] + if ok && cs != "TLS_AES_128_CCM_8_SHA256" { + vc.enableXTLS = true + } + log.Debugln("XTLS Vision found TLS 1.3, packetLength=", lenP, ", CipherSuite=", cs) + vc.packetsToFilter = 0 + return + } else if vc.remainingServerHello < 0 { + log.Debugln("XTLS Vision found TLS 1.2, packetLength=", lenP) + vc.packetsToFilter = 0 + return + } + log.Debugln("XTLS Vision found inconclusive server hello, packetLength=", lenP, + ", remainingServerHelloBytes=", vc.remainingServerHello) + } + if vc.packetsToFilter <= 0 { + log.Debugln("XTLS Vision stop filtering") + } + return +} diff --git a/transport/vless/vision.go b/transport/vless/vision.go new file mode 100644 index 00000000..f87a6870 --- /dev/null +++ b/transport/vless/vision.go @@ -0,0 +1,69 @@ +package vless + +import ( + "bytes" + "encoding/binary" + "math/rand" + + "github.com/Dreamacro/clash/common/buf" + "github.com/Dreamacro/clash/log" + + "github.com/gofrs/uuid" +) + +const ( + paddingHeaderLen = 1 + 2 + 2 // =5 + + commandPaddingContinue byte = 0x00 + commandPaddingEnd byte = 0x01 + commandPaddingDirect byte = 0x02 +) + +func WriteWithPadding(buffer *buf.Buffer, p []byte, command byte, userUUID *uuid.UUID) { + contentLen := int32(len(p)) + var paddingLen int32 + if contentLen < 900 { + paddingLen = rand.Int31n(500) + 900 - contentLen + } + + if userUUID != nil { // unnecessary, but keep the same with Xray + buffer.Write(userUUID.Bytes()) + } + buffer.WriteByte(command) + binary.BigEndian.PutUint16(buffer.Extend(2), uint16(contentLen)) + binary.BigEndian.PutUint16(buffer.Extend(2), uint16(paddingLen)) + buffer.Write(p) + buffer.Extend(int(paddingLen)) + log.Debugln("XTLS Vision write padding1: command=%v, payloadLen=%v, paddingLen=%v", command, contentLen, paddingLen) +} + +func ApplyPadding(buffer *buf.Buffer, command byte, userUUID *uuid.UUID) { + contentLen := int32(buffer.Len()) + var paddingLen int32 + if contentLen < 900 { + paddingLen = rand.Int31n(500) + 900 - contentLen + } + + binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(paddingLen)) + binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(contentLen)) + buffer.ExtendHeader(1)[0] = command + if userUUID != nil { // unnecessary, but keep the same with Xray + copy(buffer.ExtendHeader(uuid.Size), userUUID.Bytes()) + } + buffer.Extend(int(paddingLen)) + log.Debugln("XTLS Vision write padding2: command=%d, payloadLen=%d, paddingLen=%d", command, contentLen, paddingLen) +} + +func ReshapeBuffer(buffer *buf.Buffer) *buf.Buffer { + if buffer.Len() <= buf.BufferSize-paddingHeaderLen { + return nil + } + cutAt := bytes.LastIndex(buffer.Bytes(), tlsApplicationDataStart) + if cutAt == -1 { + cutAt = buf.BufferSize / 2 + } + buffer2 := buf.New() + buffer2.Write(buffer.From(cutAt)) + buffer.Truncate(cutAt) + return buffer2 +} diff --git a/transport/vless/vless.go b/transport/vless/vless.go index 4b101703..6989374c 100644 --- a/transport/vless/vless.go +++ b/transport/vless/vless.go @@ -12,6 +12,7 @@ const ( XRO = "xtls-rprx-origin" XRD = "xtls-rprx-direct" XRS = "xtls-rprx-splice" + XRV = "xtls-rprx-vision" Version byte = 0 // protocol version. preview version is 0 ) diff --git a/transport/vless/xtls.go b/transport/vless/xtls.go index a1aea44f..3a319568 100644 --- a/transport/vless/xtls.go +++ b/transport/vless/xtls.go @@ -2,6 +2,7 @@ package vless import ( "context" + "errors" "net" tlsC "github.com/Dreamacro/clash/component/tls" @@ -9,6 +10,10 @@ import ( xtls "github.com/xtls/go" ) +var ( + ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection") +) + type XTLSConfig struct { Host string SkipCertVerify bool