From 573316bcde9583ab42ecd0bfaffaeaec2f6d8c3b Mon Sep 17 00:00:00 2001 From: ShinyGwyn <79344143+ShinyGwyn@users.noreply.github.com> Date: Thu, 18 Mar 2021 19:40:34 +0800 Subject: [PATCH] Feature: add gRPC Transport for vmess/trojan (#1287) Co-authored-by: eMeab <32988354+eMeab@users.noreply.github.com> Co-authored-by: Dreamacro <8615343+Dreamacro@users.noreply.github.com> --- adapters/outbound/trojan.go | 98 ++++++++++++++-- adapters/outbound/vmess.go | 92 ++++++++++++++- component/gun/gun.go | 226 ++++++++++++++++++++++++++++++++++++ component/gun/leb128.go | 58 +++++++++ 4 files changed, 458 insertions(+), 16 deletions(-) create mode 100644 component/gun/gun.go create mode 100644 component/gun/leb128.go diff --git a/adapters/outbound/trojan.go b/adapters/outbound/trojan.go index 7b61b573..0c996f9b 100644 --- a/adapters/outbound/trojan.go +++ b/adapters/outbound/trojan.go @@ -2,34 +2,51 @@ package outbound import ( "context" + "crypto/tls" "encoding/json" "fmt" "net" "strconv" "github.com/Dreamacro/clash/component/dialer" + "github.com/Dreamacro/clash/component/gun" "github.com/Dreamacro/clash/component/trojan" C "github.com/Dreamacro/clash/constant" + + "golang.org/x/net/http2" ) type Trojan struct { *Base instance *trojan.Trojan + + // for gun mux + gunTLSConfig *tls.Config + gunConfig *gun.Config + transport *http2.Transport } type TrojanOption struct { - Name string `proxy:"name"` - Server string `proxy:"server"` - Port int `proxy:"port"` - Password string `proxy:"password"` - ALPN []string `proxy:"alpn,omitempty"` - SNI string `proxy:"sni,omitempty"` - SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` - UDP bool `proxy:"udp,omitempty"` + Name string `proxy:"name"` + Server string `proxy:"server"` + Port int `proxy:"port"` + Password string `proxy:"password"` + ALPN []string `proxy:"alpn,omitempty"` + SNI string `proxy:"sni,omitempty"` + SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` + UDP bool `proxy:"udp,omitempty"` + Network string `proxy:"network,omitempty"` + GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"` } func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { - c, err := t.instance.StreamConn(c) + var err error + if t.transport != nil { + c, err = gun.StreamGunWithConn(c, t.gunTLSConfig, t.gunConfig) + } else { + c, err = t.instance.StreamConn(c) + } + if err != nil { return nil, fmt.Errorf("%s connect error: %w", t.addr, err) } @@ -39,6 +56,21 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) } func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + // gun transport, TODO: Optimize mux dial code + if t.transport != nil { + c, err := gun.StreamGunWithTransport(t.transport, t.gunConfig) + if err != nil { + return nil, err + } + + if err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)); err != nil { + c.Close() + return nil, err + } + + return NewConn(c, t), nil + } + c, err := dialer.DialContext(ctx, "tcp", t.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %w", t.addr, err) @@ -53,6 +85,22 @@ func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, } func (t *Trojan) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { + // gun transport, TODO: Optimize mux dial code + if t.transport != nil { + c, err := gun.StreamGunWithTransport(t.transport, t.gunConfig) + if err != nil { + return nil, err + } + + if err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata)); err != nil { + c.Close() + return nil, err + } + + pc := t.instance.PacketConn(c) + return newPacketConn(pc, t), err + } + ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) defer cancel() c, err := dialer.DialContext(ctx, "tcp", t.addr) @@ -95,7 +143,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) { tOption.ServerName = option.SNI } - return &Trojan{ + t := &Trojan{ Base: &Base{ name: option.Name, addr: addr, @@ -103,5 +151,33 @@ func NewTrojan(option TrojanOption) (*Trojan, error) { udp: option.UDP, }, instance: trojan.New(tOption), - }, nil + } + + if option.Network == "grpc" { + dialFn := func(network, addr string) (net.Conn, error) { + c, err := dialer.DialContext(context.Background(), "tcp", t.addr) + if err != nil { + return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error()) + } + tcpKeepAlive(c) + return c, nil + } + + tlsConfig := &tls.Config{ + NextProtos: option.ALPN, + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: tOption.SkipCertVerify, + ServerName: tOption.ServerName, + ClientSessionCache: getClientSessionCache(), + } + + t.transport = gun.NewHTTP2Client(dialFn, tlsConfig) + t.gunTLSConfig = tlsConfig + t.gunConfig = &gun.Config{ + ServiceName: option.GrpcOpts.GrpcServiceName, + Host: tOption.ServerName, + } + } + + return t, nil } diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index db1e2039..279a3d4b 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -2,6 +2,7 @@ package outbound import ( "context" + "crypto/tls" "errors" "fmt" "net" @@ -10,15 +11,23 @@ import ( "strings" "github.com/Dreamacro/clash/component/dialer" + "github.com/Dreamacro/clash/component/gun" "github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/vmess" C "github.com/Dreamacro/clash/constant" + + "golang.org/x/net/http2" ) type Vmess struct { *Base client *vmess.Client option *VmessOption + + // for gun mux + gunTLSConfig *tls.Config + gunConfig *gun.Config + transport *http2.Transport } type VmessOption struct { @@ -33,6 +42,7 @@ type VmessOption struct { Network string `proxy:"network,omitempty"` HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"` HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"` + GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"` WSPath string `proxy:"ws-path,omitempty"` WSHeaders map[string]string `proxy:"ws-headers,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` @@ -50,6 +60,10 @@ type HTTP2Options struct { Path string `proxy:"path,omitempty"` } +type GrpcOptions struct { + GrpcServiceName string `proxy:"grpc-service-name,omitempty"` +} + func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { var err error switch v.option.Network { @@ -129,6 +143,8 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { } c, err = vmess.StreamH2Conn(c, h2Opts) + case "grpc": + c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig) default: // handle TLS if v.option.TLS { @@ -155,6 +171,21 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { } func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + // gun transport, TODO: Optimize mux dial code + if v.transport != nil { + c, err := gun.StreamGunWithTransport(v.transport, v.gunConfig) + if err != nil { + return nil, err + } + + c, err = v.client.StreamConn(c, parseVmessAddr(metadata)) + if err != nil { + return nil, err + } + + return NewConn(c, v), nil + } + c, err := dialer.DialContext(ctx, "tcp", v.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) @@ -166,7 +197,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, } func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { - // vmess use stream-oriented udp, so clash needs a net.UDPAddr + // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { ip, err := resolver.ResolveIP(metadata.Host) if err != nil { @@ -175,6 +206,21 @@ func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { metadata.DstIP = ip } + // gun transport, TODO: Optimize mux dial code + if v.transport != nil { + c, err := gun.StreamGunWithTransport(v.transport, v.gunConfig) + if err != nil { + return nil, err + } + + c, err = v.client.StreamConn(c, parseVmessAddr(metadata)) + if err != nil { + return nil, err + } + + return newPacketConn(&vmessPacketConn{Conn: c, rAddr: metadata.UDPAddr()}, v), nil + } + ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) defer cancel() c, err := dialer.DialContext(ctx, "tcp", v.addr) @@ -201,11 +247,15 @@ func NewVmess(option VmessOption) (*Vmess, error) { if err != nil { return nil, err } - if option.Network == "h2" && !option.TLS { - return nil, fmt.Errorf("TLS must be true with h2 network") + + switch option.Network { + case "h2", "grpc": + if !option.TLS { + return nil, fmt.Errorf("TLS must be true with h2/grpc network") + } } - return &Vmess{ + v := &Vmess{ Base: &Base{ name: option.Name, addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), @@ -214,7 +264,39 @@ func NewVmess(option VmessOption) (*Vmess, error) { }, client: client, option: &option, - }, nil + } + + if option.Network == "grpc" { + dialFn := func(network, addr string) (net.Conn, error) { + c, err := dialer.DialContext(context.Background(), "tcp", v.addr) + if err != nil { + return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) + } + tcpKeepAlive(c) + return c, nil + } + + gunConfig := &gun.Config{ + ServiceName: v.option.GrpcOpts.GrpcServiceName, + Host: v.option.ServerName, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: v.option.SkipCertVerify, + ServerName: v.option.ServerName, + } + + if v.option.ServerName == "" { + host, _, _ := net.SplitHostPort(v.addr) + tlsConfig.ServerName = host + gunConfig.Host = host + } + + v.gunTLSConfig = tlsConfig + v.gunConfig = gunConfig + v.transport = gun.NewHTTP2Client(dialFn, tlsConfig) + } + + return v, nil } func parseVmessAddr(metadata *C.Metadata) *vmess.DstAddr { diff --git a/component/gun/gun.go b/component/gun/gun.go new file mode 100644 index 00000000..4d2b8fd1 --- /dev/null +++ b/component/gun/gun.go @@ -0,0 +1,226 @@ +// Modified from: https://github.com/Qv2ray/gun-lite +// License: MIT + +package gun + +import ( + "crypto/tls" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "sync" + "time" + + "github.com/Dreamacro/clash/common/pool" + + "go.uber.org/atomic" + "golang.org/x/net/http2" +) + +var ( + ErrInvalidLength = errors.New("invalid length") +) + +var ( + defaultHeader = http.Header{ + "content-type": []string{"application/grpc"}, + "user-agent": []string{"grpc-go/1.36.0"}, + } +) + +type DialFn = func(network, addr string) (net.Conn, error) + +type Conn struct { + response *http.Response + request *http.Request + client *http.Client + writer *io.PipeWriter + once sync.Once + close *atomic.Bool + err error + + buf []byte + offset int +} + +type Config struct { + ServiceName string + Host string +} + +func (g *Conn) initRequest() { + response, err := g.client.Do(g.request) + if err != nil { + g.err = err + g.writer.Close() + return + } + + if !g.close.Load() { + g.response = response + } else { + response.Body.Close() + } +} + +func (g *Conn) Read(b []byte) (n int, err error) { + g.once.Do(g.initRequest) + if g.err != nil { + return 0, g.err + } + + if g.buf != nil { + n = copy(b, g.buf[g.offset:]) + g.offset += n + if g.offset == len(g.buf) { + g.offset = 0 + g.buf = nil + } + return + } + + buf := make([]byte, 5) + _, err = io.ReadFull(g.response.Body, buf) + if err != nil { + return 0, err + } + grpcPayloadLen := binary.BigEndian.Uint32(buf[1:]) + if grpcPayloadLen > pool.RelayBufferSize { + return 0, ErrInvalidLength + } + + buf = pool.Get(int(grpcPayloadLen)) + _, err = io.ReadFull(g.response.Body, buf) + if err != nil { + pool.Put(buf) + return 0, io.ErrUnexpectedEOF + } + protobufPayloadLen, protobufLengthLen := decodeUleb128(buf[1:]) + if protobufLengthLen == 0 { + pool.Put(buf) + return 0, ErrInvalidLength + } + if grpcPayloadLen != uint32(protobufPayloadLen)+uint32(protobufLengthLen)+1 { + pool.Put(buf) + return 0, ErrInvalidLength + } + + if len(b) >= int(grpcPayloadLen)-1-int(protobufLengthLen) { + n = copy(b, buf[1+protobufLengthLen:]) + pool.Put(buf) + return + } + n = copy(b, buf[1+protobufLengthLen:]) + g.offset = n + 1 + int(protobufLengthLen) + g.buf = buf + return +} + +func (g *Conn) Write(b []byte) (n int, err error) { + protobufHeader := appendUleb128([]byte{0x0A}, uint64(len(b))) + grpcHeader := make([]byte, 5) + grpcPayloadLen := uint32(len(protobufHeader) + len(b)) + binary.BigEndian.PutUint32(grpcHeader[1:5], grpcPayloadLen) + + buffers := net.Buffers{grpcHeader, protobufHeader, b} + _, err = buffers.WriteTo(g.writer) + if err == io.ErrClosedPipe && g.err != nil { + err = g.err + } + + return len(b), err +} + +func (g *Conn) Close() error { + g.close.Store(true) + if r := g.response; r != nil { + r.Body.Close() + } + + return g.writer.Close() +} + +func (g *Conn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4zero, Port: 0} } +func (g *Conn) RemoteAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4zero, Port: 0} } +func (g *Conn) SetDeadline(t time.Time) error { return nil } +func (g *Conn) SetReadDeadline(t time.Time) error { return nil } +func (g *Conn) SetWriteDeadline(t time.Time) error { return nil } + +func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config) *http2.Transport { + dialFunc := func(network, addr string, cfg *tls.Config) (net.Conn, error) { + pconn, err := dialFn(network, addr) + if err != nil { + return nil, err + } + + cn := tls.Client(pconn, cfg) + if err := cn.Handshake(); err != nil { + pconn.Close() + return nil, err + } + state := cn.ConnectionState() + if p := state.NegotiatedProtocol; p != http2.NextProtoTLS { + cn.Close() + return nil, errors.New("http2: unexpected ALPN protocol " + p + "; want q" + http2.NextProtoTLS) + } + return cn, nil + } + + return &http2.Transport{ + DialTLS: dialFunc, + TLSClientConfig: tlsConfig, + AllowHTTP: false, + DisableCompression: true, + ReadIdleTimeout: 0, + PingTimeout: 0, + } +} + +func StreamGunWithTransport(transport *http2.Transport, cfg *Config) (net.Conn, error) { + serviceName := "GunService" + if cfg.ServiceName != "" { + serviceName = cfg.ServiceName + } + + client := &http.Client{ + Transport: transport, + } + + reader, writer := io.Pipe() + request := &http.Request{ + Method: http.MethodPost, + Body: reader, + URL: &url.URL{ + Scheme: "https", + Host: cfg.Host, + Path: fmt.Sprintf("/%s/Tun", serviceName), + }, + Proto: "HTTP/2", + ProtoMajor: 2, + ProtoMinor: 0, + Header: defaultHeader, + } + + conn := &Conn{ + request: request, + client: client, + writer: writer, + close: atomic.NewBool(false), + } + + go conn.once.Do(conn.initRequest) + return conn, nil +} + +func StreamGunWithConn(conn net.Conn, tlsConfig *tls.Config, cfg *Config) (net.Conn, error) { + dialFn := func(network, addr string) (net.Conn, error) { + return conn, nil + } + + transport := NewHTTP2Client(dialFn, tlsConfig) + return StreamGunWithTransport(transport, cfg) +} diff --git a/component/gun/leb128.go b/component/gun/leb128.go new file mode 100644 index 00000000..05463458 --- /dev/null +++ b/component/gun/leb128.go @@ -0,0 +1,58 @@ +// Copy from: https://github.com/Equim-chan/leb128 +// License: BSD-3-Clause + +package gun + +var sevenbits = [...]byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, + 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, +} + +func decodeUleb128(b []byte) (u uint64, n uint8) { + l := uint8(len(b) & 0xff) + // The longest LEB128 encoded sequence is 10 byte long (9 0xff's and 1 0x7f) + // so make sure we won't overflow. + if l > 10 { + l = 10 + } + + var i uint8 + for i = 0; i < l; i++ { + u |= uint64(b[i]&0x7f) << (7 * i) + if b[i]&0x80 == 0 { + n = uint8(i + 1) + return + } + } + + return +} + +func appendUleb128(b []byte, v uint64) []byte { + // If it's less than or equal to 7-bit + if v < 0x80 { + return append(b, sevenbits[v]) + } + + for { + c := uint8(v & 0x7f) + v >>= 7 + + if v != 0 { + c |= 0x80 + } + + b = append(b, c) + if c&0x80 == 0 { + break + } + } + + return b +}