From 4502776513bf45c5daf0d2158c84384626da29f1 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Mon, 28 Mar 2022 00:44:13 +0800 Subject: [PATCH 1/8] Refactor: MainResolver --- adapter/outbound/direct.go | 2 + adapter/outbound/util.go | 2 +- component/dialer/dialer.go | 22 ++++----- component/dialer/options.go | 7 +++ component/resolver/resolver.go | 89 +++++++++++++++++++++------------- config/config.go | 52 +++++++++++--------- dns/resolver.go | 15 +++++- hub/executor/executor.go | 14 ++++-- 8 files changed, 128 insertions(+), 75 deletions(-) diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index 4c4305f5..61eb4571 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -14,6 +14,7 @@ type Direct struct { // DialContext implements C.ProxyAdapter func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { + opts = append(opts, dialer.WithDirect()) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) if err != nil { return nil, err @@ -24,6 +25,7 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ... // ListenPacketContext implements C.ProxyAdapter func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { + opts = append(opts, dialer.WithDirect()) pc, err := dialer.ListenPacket(ctx, "udp", "", d.Base.DialOptions(opts...)...) if err != nil { return nil, err diff --git a/adapter/outbound/util.go b/adapter/outbound/util.go index b376522f..9322adc5 100644 --- a/adapter/outbound/util.go +++ b/adapter/outbound/util.go @@ -44,7 +44,7 @@ func resolveUDPAddr(network, address string) (*net.UDPAddr, error) { return nil, err } - ip, err := resolver.ResolveIP(host) + ip, err := resolver.ResolveProxyServerHost(host) if err != nil { return nil, err } diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 09254c58..19a5fdd2 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -32,14 +32,14 @@ func DialContext(ctx context.Context, network, address string, options ...Option var ip net.IP switch network { case "tcp4", "udp4": - if opt.interfaceName != "" { - ip, err = resolver.ResolveIPv4WithMain(host) + if !opt.direct { + ip, err = resolver.ResolveIPv4ProxyServerHost(host) } else { ip, err = resolver.ResolveIPv4(host) } default: - if opt.interfaceName != "" { - ip, err = resolver.ResolveIPv6WithMain(host) + if !opt.direct { + ip, err = resolver.ResolveIPv6ProxyServerHost(host) } else { ip, err = resolver.ResolveIPv6(host) } @@ -121,7 +121,7 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt results := make(chan dialResult) var primary, fallback dialResult - startRacer := func(ctx context.Context, network, host string, ipv6 bool) { + startRacer := func(ctx context.Context, network, host string, direct bool, ipv6 bool) { result := dialResult{ipv6: ipv6, done: true} defer func() { select { @@ -135,14 +135,14 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt var ip net.IP if ipv6 { - if opt.interfaceName != "" { - ip, result.error = resolver.ResolveIPv6WithMain(host) + if !direct { + ip, result.error = resolver.ResolveIPv6ProxyServerHost(host) } else { ip, result.error = resolver.ResolveIPv6(host) } } else { - if opt.interfaceName != "" { - ip, result.error = resolver.ResolveIPv4WithMain(host) + if !direct { + ip, result.error = resolver.ResolveIPv4ProxyServerHost(host) } else { ip, result.error = resolver.ResolveIPv4(host) } @@ -155,8 +155,8 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt result.Conn, result.error = dialContext(ctx, network, ip, port, opt) } - go startRacer(ctx, network+"4", host, false) - go startRacer(ctx, network+"6", host, true) + go startRacer(ctx, network+"4", host, opt.direct, false) + go startRacer(ctx, network+"6", host, opt.direct, true) for res := range results { if res.error == nil { diff --git a/component/dialer/options.go b/component/dialer/options.go index 2d884094..2985dc7b 100644 --- a/component/dialer/options.go +++ b/component/dialer/options.go @@ -12,6 +12,7 @@ type option struct { interfaceName string addrReuse bool routingMark int + direct bool } type Option func(opt *option) @@ -33,3 +34,9 @@ func WithRoutingMark(mark int) Option { opt.routingMark = mark } } + +func WithDirect() Option { + return func(opt *option) { + opt.direct = true + } +} diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index a7300bd7..e1100a31 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -15,8 +15,8 @@ var ( // DefaultResolver aim to resolve ip DefaultResolver Resolver - // MainResolver resolve ip with main domain server - MainResolver Resolver + // ProxyServerHostResolver resolve ip to proxies server host + ProxyServerHostResolver Resolver // DisableIPv6 means don't resolve ipv6 host // default value is true @@ -46,10 +46,6 @@ func ResolveIPv4(host string) (net.IP, error) { return ResolveIPv4WithResolver(host, DefaultResolver) } -func ResolveIPv4WithMain(host string) (net.IP, error) { - return ResolveIPv4WithResolver(host, MainResolver) -} - func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) { if node := DefaultHosts.Search(host); node != nil { if ip := node.Data.(net.IP).To4(); ip != nil { @@ -69,16 +65,20 @@ func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) { return r.ResolveIPv4(host) } - ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) - defer cancel() - ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host) - if err != nil { - return nil, err - } else if len(ipAddrs) == 0 { - return nil, ErrIPNotFound + if DefaultResolver == nil { + ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) + defer cancel() + ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host) + if err != nil { + return nil, err + } else if len(ipAddrs) == 0 { + return nil, ErrIPNotFound + } + + return ipAddrs[rand.Intn(len(ipAddrs))], nil } - return ipAddrs[rand.Intn(len(ipAddrs))], nil + return nil, ErrIPNotFound } // ResolveIPv6 with a host, return ipv6 @@ -86,10 +86,6 @@ func ResolveIPv6(host string) (net.IP, error) { return ResolveIPv6WithResolver(host, DefaultResolver) } -func ResolveIPv6WithMain(host string) (net.IP, error) { - return ResolveIPv6WithResolver(host, MainResolver) -} - func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { if DisableIPv6 { return nil, ErrIPv6Disabled @@ -113,16 +109,20 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { return r.ResolveIPv6(host) } - ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) - defer cancel() - ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host) - if err != nil { - return nil, err - } else if len(ipAddrs) == 0 { - return nil, ErrIPNotFound + if DefaultResolver == nil { + ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) + defer cancel() + ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host) + if err != nil { + return nil, err + } else if len(ipAddrs) == 0 { + return nil, ErrIPNotFound + } + + return ipAddrs[rand.Intn(len(ipAddrs))], nil } - return ipAddrs[rand.Intn(len(ipAddrs))], nil + return nil, ErrIPNotFound } // ResolveIPWithResolver same as ResolveIP, but with a resolver @@ -145,12 +145,16 @@ func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) { return ip, nil } - ipAddr, err := net.ResolveIPAddr("ip", host) - if err != nil { - return nil, err + if DefaultResolver == nil { + ipAddr, err := net.ResolveIPAddr("ip", host) + if err != nil { + return nil, err + } + + return ipAddr.IP, nil } - return ipAddr.IP, nil + return nil, ErrIPNotFound } // ResolveIP with a host, return ip @@ -158,7 +162,26 @@ func ResolveIP(host string) (net.IP, error) { return ResolveIPWithResolver(host, DefaultResolver) } -// ResolveIPWithMainResolver with a host, use main resolver, return ip -func ResolveIPWithMainResolver(host string) (net.IP, error) { - return ResolveIPWithResolver(host, MainResolver) +// ResolveIPv4ProxyServerHost proxies server host only +func ResolveIPv4ProxyServerHost(host string) (net.IP, error) { + if ProxyServerHostResolver != nil { + return ResolveIPv4WithResolver(host, ProxyServerHostResolver) + } + return ResolveIPv4(host) +} + +// ResolveIPv6ProxyServerHost proxies server host only +func ResolveIPv6ProxyServerHost(host string) (net.IP, error) { + if ProxyServerHostResolver != nil { + return ResolveIPv6WithResolver(host, ProxyServerHostResolver) + } + return ResolveIPv6(host) +} + +// ResolveProxyServerHost proxies server host only +func ResolveProxyServerHost(host string) (net.IP, error) { + if ProxyServerHostResolver != nil { + return ResolveIPWithResolver(host, ProxyServerHostResolver) + } + return ResolveIP(host) } diff --git a/config/config.go b/config/config.go index 36c2223a..eedc1959 100644 --- a/config/config.go +++ b/config/config.go @@ -63,17 +63,18 @@ type Controller struct { // DNS config type DNS struct { - Enable bool `yaml:"enable"` - IPv6 bool `yaml:"ipv6"` - NameServer []dns.NameServer `yaml:"nameserver"` - Fallback []dns.NameServer `yaml:"fallback"` - FallbackFilter FallbackFilter `yaml:"fallback-filter"` - Listen string `yaml:"listen"` - EnhancedMode C.DNSMode `yaml:"enhanced-mode"` - DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` - FakeIPRange *fakeip.Pool - Hosts *trie.DomainTrie - NameServerPolicy map[string]dns.NameServer + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + NameServer []dns.NameServer `yaml:"nameserver"` + Fallback []dns.NameServer `yaml:"fallback"` + FallbackFilter FallbackFilter `yaml:"fallback-filter"` + Listen string `yaml:"listen"` + EnhancedMode C.DNSMode `yaml:"enhanced-mode"` + DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` + FakeIPRange *fakeip.Pool + Hosts *trie.DomainTrie + NameServerPolicy map[string]dns.NameServer + ProxyServerNameserver []dns.NameServer } // FallbackFilter config @@ -125,18 +126,19 @@ type Config struct { } type RawDNS struct { - Enable bool `yaml:"enable"` - IPv6 bool `yaml:"ipv6"` - UseHosts bool `yaml:"use-hosts"` - NameServer []string `yaml:"nameserver"` - Fallback []string `yaml:"fallback"` - FallbackFilter RawFallbackFilter `yaml:"fallback-filter"` - Listen string `yaml:"listen"` - EnhancedMode C.DNSMode `yaml:"enhanced-mode"` - FakeIPRange string `yaml:"fake-ip-range"` - FakeIPFilter []string `yaml:"fake-ip-filter"` - DefaultNameserver []string `yaml:"default-nameserver"` - NameServerPolicy map[string]string `yaml:"nameserver-policy"` + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + UseHosts bool `yaml:"use-hosts"` + NameServer []string `yaml:"nameserver"` + Fallback []string `yaml:"fallback"` + FallbackFilter RawFallbackFilter `yaml:"fallback-filter"` + Listen string `yaml:"listen"` + EnhancedMode C.DNSMode `yaml:"enhanced-mode"` + FakeIPRange string `yaml:"fake-ip-range"` + FakeIPFilter []string `yaml:"fake-ip-filter"` + DefaultNameserver []string `yaml:"default-nameserver"` + NameServerPolicy map[string]string `yaml:"nameserver-policy"` + ProxyServerNameserver []string `yaml:"proxy-server-nameserver"` } type RawFallbackFilter struct { @@ -679,6 +681,10 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS, return nil, err } + if dnsCfg.ProxyServerNameserver, err = parseNameServer(cfg.ProxyServerNameserver); err != nil { + return nil, err + } + if len(cfg.DefaultNameserver) == 0 { return nil, errors.New("default nameserver should have at least one nameserver") } diff --git a/dns/resolver.go b/dns/resolver.go index f22902d7..4ff12ee8 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -41,6 +41,7 @@ type Resolver struct { group singleflight.Group lruCache *cache.LruCache policy *trie.DomainTrie + proxyServer []dnsClient } // ResolveIP request with TypeA and TypeAAAA, priority return TypeA @@ -301,6 +302,11 @@ func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D return ch } +// HasProxyServer has proxy server dns client +func (r *Resolver) HasProxyServer() bool { + return len(r.main) > 0 +} + type NameServer struct { Net string Addr string @@ -319,6 +325,7 @@ type FallbackFilter struct { type Config struct { Main, Fallback []NameServer Default []NameServer + ProxyServer []NameServer IPv6 bool EnhancedMode C.DNSMode FallbackFilter FallbackFilter @@ -344,6 +351,10 @@ func NewResolver(config Config) *Resolver { r.fallback = transform(config.Fallback, defaultResolver) } + if len(config.ProxyServer) != 0 { + r.proxyServer = transform(config.ProxyServer, defaultResolver) + } + if len(config.Policy) != 0 { r.policy = trie.New() for domain, nameserver := range config.Policy { @@ -377,10 +388,10 @@ func NewResolver(config Config) *Resolver { return r } -func NewMainResolver(old *Resolver) *Resolver { +func NewProxyServerHostResolver(old *Resolver) *Resolver { r := &Resolver{ ipv6: old.ipv6, - main: old.main, + main: old.proxyServer, lruCache: old.lruCache, hosts: old.hosts, policy: old.policy, diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 24b8693e..29096951 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -130,12 +130,13 @@ func updateDNS(c *config.DNS, t *config.Tun) { Domain: c.FallbackFilter.Domain, GeoSite: c.FallbackFilter.GeoSite, }, - Default: c.DefaultNameserver, - Policy: c.NameServerPolicy, + Default: c.DefaultNameserver, + Policy: c.NameServerPolicy, + ProxyServer: c.ProxyServerNameserver, } r := dns.NewResolver(cfg) - mr := dns.NewMainResolver(r) + pr := dns.NewProxyServerHostResolver(r) m := dns.NewEnhancer(cfg) // reuse cache of old host mapper @@ -144,9 +145,12 @@ func updateDNS(c *config.DNS, t *config.Tun) { } resolver.DefaultResolver = r - resolver.MainResolver = mr resolver.DefaultHostMapper = m + if pr.HasProxyServer() { + resolver.ProxyServerHostResolver = pr + } + if t.Enable { resolver.DefaultLocalServer = dns.NewLocalServer(r, m) } @@ -156,9 +160,9 @@ func updateDNS(c *config.DNS, t *config.Tun) { } else { if !t.Enable { resolver.DefaultResolver = nil - resolver.MainResolver = nil resolver.DefaultHostMapper = nil resolver.DefaultLocalServer = nil + resolver.ProxyServerHostResolver = nil } dns.ReCreateServer("", nil, nil) } From 7e2c6e51887bcf51fa0565ceba07771ec7e85754 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Mon, 28 Mar 2022 00:46:44 +0800 Subject: [PATCH 2/8] Chore: adjust HealthCheck at first check --- adapter/provider/healthcheck.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/adapter/provider/healthcheck.go b/adapter/provider/healthcheck.go index 7a89ffa3..bfbaf6b0 100644 --- a/adapter/provider/healthcheck.go +++ b/adapter/provider/healthcheck.go @@ -31,7 +31,13 @@ type HealthCheck struct { func (hc *HealthCheck) process() { ticker := time.NewTicker(time.Duration(hc.interval) * time.Second) - go hc.check() + go func() { + t := time.NewTicker(30 * time.Second) + <-t.C + t.Stop() + hc.check() + }() + for { select { case <-ticker.C: From fe76cbf31c20fccce1f4b84723acb39a8d9e08ce Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Mon, 28 Mar 2022 03:18:51 +0800 Subject: [PATCH 3/8] Chore: code style --- tunnel/statistic/tracker.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 9f231547..1f5f1f9c 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -77,9 +77,6 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R if rule != nil { t.trackerInfo.Rule = rule.RuleType().String() t.trackerInfo.RulePayload = rule.Payload() - //if rule.RuleType() == C.GEOSITE || rule.RuleType() == C.GEOIP { - // t.trackerInfo.Rule = t.trackerInfo.Rule + " (" + rule.Payload() + ")" - //} } manager.Join(t) @@ -137,9 +134,6 @@ func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, ru if rule != nil { ut.trackerInfo.Rule = rule.RuleType().String() ut.trackerInfo.RulePayload = rule.Payload() - //if rule.RuleType() == C.GEOSITE || rule.RuleType() == C.GEOIP { - // ut.trackerInfo.Rule = ut.trackerInfo.Rule + " (" + rule.Payload() + ")" - //} } manager.Join(ut) From 8df8f8cb08e8de289377b5b3f21f5c739853371d Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Mon, 28 Mar 2022 03:25:55 +0800 Subject: [PATCH 4/8] Chore: adjust gVisor stack --- hub/executor/executor.go | 4 +- hub/hub.go | 4 -- listener/tun/device/device.go | 3 + listener/tun/device/iobased/endpoint.go | 2 +- listener/tun/device/tun/tun_wireguard.go | 3 + .../tun/ipstack/gvisor/adapter/handler.go | 10 ++- listener/tun/ipstack/gvisor/handler.go | 4 +- listener/tun/ipstack/gvisor/nic.go | 27 ++++---- .../gvisor/{opts.go => option/option.go} | 27 ++++---- .../tun/ipstack/gvisor/{icmp.go => route.go} | 13 ++-- listener/tun/ipstack/gvisor/stack.go | 66 ++++++++++--------- listener/tun/ipstack/gvisor/tcp.go | 11 ++-- listener/tun/ipstack/gvisor/udp.go | 11 ++-- listener/tun/ipstack/system/stack.go | 14 ---- listener/tun/tun_adapter.go | 3 +- main.go | 8 +-- 16 files changed, 106 insertions(+), 104 deletions(-) rename listener/tun/ipstack/gvisor/{opts.go => option/option.go} (93%) rename listener/tun/ipstack/gvisor/{icmp.go => route.go} (51%) diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 29096951..887c473a 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -330,7 +330,9 @@ func updateIPTables(cfg *config.Config) { log.Infoln("[IPTABLES] Setting iptables completed") } -func Cleanup() { +func Shutdown() { P.Cleanup() tproxy.CleanupTProxyIPTables() + + log.Warnln("Clash shutting down") } diff --git a/hub/hub.go b/hub/hub.go index cde0bb57..471fdb5e 100644 --- a/hub/hub.go +++ b/hub/hub.go @@ -48,7 +48,3 @@ func Parse(options ...Option) error { executor.ApplyConfig(cfg, true) return nil } - -func Cleanup() { - executor.Cleanup() -} diff --git a/listener/tun/device/device.go b/listener/tun/device/device.go index 70115cbd..73b03dee 100644 --- a/listener/tun/device/device.go +++ b/listener/tun/device/device.go @@ -29,4 +29,7 @@ type Device interface { // UseIOBased work for other ip stack UseIOBased() error + + // Wait waits for the device to close. + Wait() } diff --git a/listener/tun/device/iobased/endpoint.go b/listener/tun/device/iobased/endpoint.go index a187491e..c0942d10 100644 --- a/listener/tun/device/iobased/endpoint.go +++ b/listener/tun/device/iobased/endpoint.go @@ -103,7 +103,7 @@ func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) { case header.IPv6Version: e.InjectInbound(header.IPv6ProtocolNumber, pkt) } - pkt.DecRef() /* release */ + pkt.DecRef() } } diff --git a/listener/tun/device/tun/tun_wireguard.go b/listener/tun/device/tun/tun_wireguard.go index 50db6511..30398b55 100644 --- a/listener/tun/device/tun/tun_wireguard.go +++ b/listener/tun/device/tun/tun_wireguard.go @@ -106,6 +106,9 @@ func (t *TUN) Write(packet []byte) (int, error) { } func (t *TUN) Close() error { + if t.Endpoint != nil { + t.Endpoint.Close() + } return t.nt.Close() } diff --git a/listener/tun/ipstack/gvisor/adapter/handler.go b/listener/tun/ipstack/gvisor/adapter/handler.go index 2878b713..715f6636 100644 --- a/listener/tun/ipstack/gvisor/adapter/handler.go +++ b/listener/tun/ipstack/gvisor/adapter/handler.go @@ -3,6 +3,12 @@ package adapter // Handler is a TCP/UDP connection handler that implements // HandleTCPConn and HandleUDPConn methods. type Handler interface { - HandleTCPConn(TCPConn) - HandleUDPConn(UDPConn) + HandleTCP(TCPConn) + HandleUDP(UDPConn) } + +// TCPHandleFunc handles incoming TCP connection. +type TCPHandleFunc func(TCPConn) + +// UDPHandleFunc handles incoming UDP connection. +type UDPHandleFunc func(UDPConn) diff --git a/listener/tun/ipstack/gvisor/handler.go b/listener/tun/ipstack/gvisor/handler.go index 76aab2f5..6365234e 100644 --- a/listener/tun/ipstack/gvisor/handler.go +++ b/listener/tun/ipstack/gvisor/handler.go @@ -24,7 +24,7 @@ type GVHandler struct { UDPIn chan<- *inbound.PacketAdapter } -func (gh *GVHandler) HandleTCPConn(tunConn adapter.TCPConn) { +func (gh *GVHandler) HandleTCP(tunConn adapter.TCPConn) { id := tunConn.ID() rAddr := &net.UDPAddr{ @@ -77,7 +77,7 @@ func (gh *GVHandler) HandleTCPConn(tunConn adapter.TCPConn) { gh.TCPIn <- inbound.NewSocket(socks5.ParseAddrToSocksAddr(rAddr), tunConn, C.TUN) } -func (gh *GVHandler) HandleUDPConn(tunConn adapter.UDPConn) { +func (gh *GVHandler) HandleUDP(tunConn adapter.UDPConn) { id := tunConn.ID() rAddr := &net.UDPAddr{ diff --git a/listener/tun/ipstack/gvisor/nic.go b/listener/tun/ipstack/gvisor/nic.go index fb8ac1a2..0ca96778 100644 --- a/listener/tun/ipstack/gvisor/nic.go +++ b/listener/tun/ipstack/gvisor/nic.go @@ -3,14 +3,13 @@ package gvisor import ( "fmt" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( - // defaultNICID is the ID of default NIC used by DefaultStack. - defaultNICID tcpip.NICID = 0x01 - // nicPromiscuousModeEnabled is the value used by stack to enable // or disable NIC's promiscuous mode. nicPromiscuousModeEnabled = true @@ -21,9 +20,9 @@ const ( ) // withCreatingNIC creates NIC for stack. -func withCreatingNIC(ep stack.LinkEndpoint) Option { - return func(s *gvStack) error { - if err := s.CreateNICWithOptions(s.nicID, ep, +func withCreatingNIC(nicID tcpip.NICID, ep stack.LinkEndpoint) option.Option { + return func(s *stack.Stack) error { + if err := s.CreateNICWithOptions(nicID, ep, stack.NICOptions{ Disabled: false, // If no queueing discipline was specified @@ -37,21 +36,21 @@ func withCreatingNIC(ep stack.LinkEndpoint) Option { } } -// withPromiscuousMode sets promiscuous mode in the given NIC. -func withPromiscuousMode(v bool) Option { - return func(s *gvStack) error { - if err := s.SetPromiscuousMode(s.nicID, v); err != nil { +// withPromiscuousMode sets promiscuous mode in the given NICs. +func withPromiscuousMode(nicID tcpip.NICID, v bool) option.Option { + return func(s *stack.Stack) error { + if err := s.SetPromiscuousMode(nicID, v); err != nil { return fmt.Errorf("set promiscuous mode: %s", err) } return nil } } -// withSpoofing sets address spoofing in the given NIC, allowing +// withSpoofing sets address spoofing in the given NICs, allowing // endpoints to bind to any address in the NIC. -func withSpoofing(v bool) Option { - return func(s *gvStack) error { - if err := s.SetSpoofing(s.nicID, v); err != nil { +func withSpoofing(nicID tcpip.NICID, v bool) option.Option { + return func(s *stack.Stack) error { + if err := s.SetSpoofing(nicID, v); err != nil { return fmt.Errorf("set spoofing: %s", err) } return nil diff --git a/listener/tun/ipstack/gvisor/opts.go b/listener/tun/ipstack/gvisor/option/option.go similarity index 93% rename from listener/tun/ipstack/gvisor/opts.go rename to listener/tun/ipstack/gvisor/option/option.go index 7fd5a65b..34508e0d 100644 --- a/listener/tun/ipstack/gvisor/opts.go +++ b/listener/tun/ipstack/gvisor/option/option.go @@ -1,4 +1,4 @@ -package gvisor +package option import ( "fmt" @@ -7,6 +7,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) @@ -56,11 +57,11 @@ const ( tcpDefaultBufferSize = 212 << 10 // 212 KiB ) -type Option func(*gvStack) error +type Option func(*stack.Stack) error // WithDefault sets all default values for stack. func WithDefault() Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opts := []Option{ WithDefaultTTL(defaultTimeToLive), WithForwarding(ipForwardingEnabled), @@ -110,7 +111,7 @@ func WithDefault() Option { // WithDefaultTTL sets the default TTL used by stack. func WithDefaultTTL(ttl uint8) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.DefaultTTLOption(ttl) if err := s.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set ipv4 default TTL: %s", err) @@ -124,7 +125,7 @@ func WithDefaultTTL(ttl uint8) Option { // WithForwarding sets packet forwarding between NICs for IPv4 & IPv6. func WithForwarding(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, v); err != nil { return fmt.Errorf("set ipv4 forwarding: %s", err) } @@ -138,7 +139,7 @@ func WithForwarding(v bool) Option { // WithICMPBurst sets the number of ICMP messages that can be sent // in a single burst. func WithICMPBurst(burst int) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { s.SetICMPBurst(burst) return nil } @@ -147,7 +148,7 @@ func WithICMPBurst(burst int) Option { // WithICMPLimit sets the maximum number of ICMP messages permitted // by rate limiter. func WithICMPLimit(limit rate.Limit) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { s.SetICMPLimit(limit) return nil } @@ -155,7 +156,7 @@ func WithICMPLimit(limit rate.Limit) Option { // WithTCPBufferSizeRange sets the receive and send buffer size range for TCP. func WithTCPBufferSizeRange(a, b, c int) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c} if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil { return fmt.Errorf("set TCP receive buffer size range: %s", err) @@ -170,7 +171,7 @@ func WithTCPBufferSizeRange(a, b, c int) Option { // WithTCPCongestionControl sets the current congestion control algorithm. func WithTCPCongestionControl(cc string) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.CongestionControlOption(cc) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP congestion control algorithm: %s", err) @@ -181,7 +182,7 @@ func WithTCPCongestionControl(cc string) Option { // WithTCPDelay enables or disables Nagle's algorithm in TCP. func WithTCPDelay(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.TCPDelayEnabled(v) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP delay: %s", err) @@ -192,7 +193,7 @@ func WithTCPDelay(v bool) Option { // WithTCPModerateReceiveBuffer sets receive buffer moderation for TCP. func WithTCPModerateReceiveBuffer(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.TCPModerateReceiveBufferOption(v) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP moderate receive buffer: %s", err) @@ -203,7 +204,7 @@ func WithTCPModerateReceiveBuffer(v bool) Option { // WithTCPSACKEnabled sets the SACK option for TCP. func WithTCPSACKEnabled(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.TCPSACKEnabled(v) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP SACK: %s", err) @@ -214,7 +215,7 @@ func WithTCPSACKEnabled(v bool) Option { // WithTCPRecovery sets the recovery option for TCP. func WithTCPRecovery(v tcpip.TCPRecovery) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { return fmt.Errorf("set TCP Recovery: %s", err) } diff --git a/listener/tun/ipstack/gvisor/icmp.go b/listener/tun/ipstack/gvisor/route.go similarity index 51% rename from listener/tun/ipstack/gvisor/icmp.go rename to listener/tun/ipstack/gvisor/route.go index 8b56d397..5a3d3bf4 100644 --- a/listener/tun/ipstack/gvisor/icmp.go +++ b/listener/tun/ipstack/gvisor/route.go @@ -1,22 +1,23 @@ package gvisor import ( + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -func withICMPHandler() Option { - return func(s *gvStack) error { - // Add default route table for IPv4 and IPv6. - // This will handle all incoming ICMP packets. +func withRouteTable(nicID tcpip.NICID) option.Option { + return func(s *stack.Stack) error { s.SetRouteTable([]tcpip.Route{ { Destination: header.IPv4EmptySubnet, - NIC: s.nicID, + NIC: nicID, }, { Destination: header.IPv6EmptySubnet, - NIC: s.nicID, + NIC: nicID, }, }) return nil diff --git a/listener/tun/ipstack/gvisor/stack.go b/listener/tun/ipstack/gvisor/stack.go index 104c0300..9061995d 100644 --- a/listener/tun/ipstack/gvisor/stack.go +++ b/listener/tun/ipstack/gvisor/stack.go @@ -5,6 +5,7 @@ import ( "github.com/Dreamacro/clash/listener/tun/device" "github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -18,23 +19,23 @@ import ( type gvStack struct { *stack.Stack device device.Device - - handler adapter.Handler - nicID tcpip.NICID } func (s *gvStack) Close() error { + var err error + if s.device != nil { + err = s.device.Close() + s.device.Wait() + } if s.Stack != nil { s.Stack.Close() + s.Stack.Wait() } - if s.device != nil { - _ = s.device.Close() - } - return nil + return err } // New allocates a new *gvStack with given options. -func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack.Stack, error) { +func New(device device.Device, handler adapter.Handler, opts ...option.Option) (ipstack.Stack, error) { s := &gvStack{ Stack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -49,19 +50,15 @@ func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack }, }), - device: device, - handler: handler, - nicID: defaultNICID, + device: device, } - opts = append(opts, - // Important: We must initiate transport protocol handlers - // before creating NIC, otherwise NIC would dispatch packets - // to stack and cause race condition. - withICMPHandler(), withTCPHandler(), withUDPHandler(), + // Generate unique NIC id. + nicID := tcpip.NICID(s.Stack.UniqueID()) - // Create stack NIC and then bind link endpoint. - withCreatingNIC(device.(stack.LinkEndpoint)), + opts = append(opts, + // Create stack NIC and then bind link endpoint to it. + withCreatingNIC(nicID, device), // In the past we did s.AddAddressRange to assign 0.0.0.0/0 // onto the interface. We need that to be able to terminate @@ -70,27 +67,34 @@ func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack // Promiscuous mode. https://github.com/google/gvisor/issues/3876 // // Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go - withPromiscuousMode(nicPromiscuousModeEnabled), + withPromiscuousMode(nicID, nicPromiscuousModeEnabled), - // Enable spoofing if a stack may send packets from unowned addresses. - // This change required changes to some netgophers since previously, - // promiscuous mode was enough to let the netstack respond to all - // incoming packets regardless of the packet's destination address. Now - // that a stack.Route is not held for each incoming packet, finding a route - // may fail with local addresses we don't own but accepted packets for - // while in promiscuous mode. Since we also want to be able to send from - // any address (in response the received promiscuous mode packets), we need - // to enable spoofing. + // Enable spoofing if a stack may send packets from unowned + // addresses. This change required changes to some netgophers + // since previously, promiscuous mode was enough to let the + // netstack respond to all incoming packets regardless of the + // packet's destination address. Now that a stack.Route is not + // held for each incoming packet, finding a route may fail with + // local addresses we don't own but accepted packets for while + // in promiscuous mode. Since we also want to be able to send + // from any address (in response the received promiscuous mode + // packets), we need to enable spoofing. // // Ref: https://github.com/google/gvisor/commit/8c0701462a84ff77e602f1626aec49479c308127 - withSpoofing(nicSpoofingEnabled), + withSpoofing(nicID, nicSpoofingEnabled), + + // Add default route table for IPv4 and IPv6. This will handle + // all incoming ICMP packets. + withRouteTable(nicID), + + // Initiate transport protocol (TCP/UDP) with given handler. + withTCPHandler(handler.HandleTCP), withUDPHandler(handler.HandleUDP), ) for _, opt := range opts { - if err := opt(s); err != nil { + if err := opt(s.Stack); err != nil { return nil, err } } - return s, nil } diff --git a/listener/tun/ipstack/gvisor/tcp.go b/listener/tun/ipstack/gvisor/tcp.go index 8b8277e0..8bffb932 100644 --- a/listener/tun/ipstack/gvisor/tcp.go +++ b/listener/tun/ipstack/gvisor/tcp.go @@ -4,6 +4,9 @@ import ( "fmt" "time" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -37,9 +40,9 @@ const ( tcpKeepaliveInterval = 30 * time.Second ) -func withTCPHandler() Option { - return func(s *gvStack) error { - tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { +func withTCPHandler(handle adapter.TCPHandleFunc) option.Option { + return func(s *stack.Stack) error { + tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { @@ -55,7 +58,7 @@ func withTCPHandler() Option { TCPConn: gonet.NewTCPConn(&wq, ep), id: r.ID(), } - s.handler.HandleTCPConn(conn) + handle(conn) }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) return nil diff --git a/listener/tun/ipstack/gvisor/udp.go b/listener/tun/ipstack/gvisor/udp.go index 6efbd204..688583a0 100644 --- a/listener/tun/ipstack/gvisor/udp.go +++ b/listener/tun/ipstack/gvisor/udp.go @@ -5,6 +5,7 @@ import ( "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -12,9 +13,9 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -func withUDPHandler() Option { - return func(s *gvStack) error { - udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) { +func withUDPHandler(handle adapter.UDPHandleFunc) option.Option { + return func(s *stack.Stack) error { + udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { @@ -23,10 +24,10 @@ func withUDPHandler() Option { } conn := &udpConn{ - UDPConn: gonet.NewUDPConn(s.Stack, &wq, ep), + UDPConn: gonet.NewUDPConn(s, &wq, ep), id: r.ID(), } - s.handler.HandleUDPConn(conn) + handle(conn) }) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) return nil diff --git a/listener/tun/ipstack/system/stack.go b/listener/tun/ipstack/system/stack.go index ad3f465d..d8b250ba 100644 --- a/listener/tun/ipstack/system/stack.go +++ b/listener/tun/ipstack/system/stack.go @@ -36,8 +36,6 @@ func (s sysStack) Close() error { return nil } -var ipv4LoopBack = netip.MustParsePrefix("127.0.0.0/8") - func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) { var ( gateway = tunAddress.Masked().Addr().Next() @@ -71,12 +69,6 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref rAddrIp, _ := netip.AddrFromSlice(rAddr.IP) rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port)) - if ipv4LoopBack.Contains(rAddrIp) { - conn.Close() - - continue - } - if D.ShouldHijackDns(dnsAddr, rAddrPort) { go func() { log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String()) @@ -149,12 +141,6 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref rAddrIp, _ := netip.AddrFromSlice(rAddr.IP) rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port)) - if ipv4LoopBack.Contains(rAddrIp) { - pool.Put(buf) - - continue - } - if D.ShouldHijackDns(dnsAddr, rAddrPort) { go func() { defer pool.Put(buf) diff --git a/listener/tun/tun_adapter.go b/listener/tun/tun_adapter.go index 333a85bd..75d568e1 100644 --- a/listener/tun/tun_adapter.go +++ b/listener/tun/tun_adapter.go @@ -18,6 +18,7 @@ import ( "github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/listener/tun/ipstack/commons" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "github.com/Dreamacro/clash/listener/tun/ipstack/system" "github.com/Dreamacro/clash/log" ) @@ -67,7 +68,7 @@ func New(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.ConnContex DNSAdds: tunConf.DNSHijack, TCPIn: tcpIn, UDPIn: udpIn, }, - gvisor.WithDefault(), + option.WithDefault(), ) if err != nil { diff --git a/main.go b/main.go index bc8c7a55..104a4964 100644 --- a/main.go +++ b/main.go @@ -100,13 +100,9 @@ func main() { log.Fatalln("Parse config error: %s", err.Error()) } + defer executor.Shutdown() + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh - - // cleanup - log.Warnln("Clash cleanup") - hub.Cleanup() - - log.Warnln("Clash shutting down") } From 56e2c172e1d8af6eb19e65d8da1c04f009fd9203 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Tue, 29 Mar 2022 07:06:41 +0800 Subject: [PATCH 5/8] Chore: adjust tun_wireguard cache buffer --- listener/tun/device/tun/tun_wireguard.go | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/listener/tun/device/tun/tun_wireguard.go b/listener/tun/device/tun/tun_wireguard.go index 30398b55..6e0ad15f 100644 --- a/listener/tun/device/tun/tun_wireguard.go +++ b/listener/tun/device/tun/tun_wireguard.go @@ -8,7 +8,6 @@ import ( "os" "runtime" - "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/listener/tun/device" "github.com/Dreamacro/clash/listener/tun/device/iobased" @@ -22,6 +21,8 @@ type TUN struct { mtu uint32 name string offset int + + cache []byte } func Open(name string, mtu uint32) (_ device.Device, err error) { @@ -70,6 +71,10 @@ func Open(name string, mtu uint32) (_ device.Device, err error) { } t.mtu = uint32(tunMTU) + if t.offset > 0 { + t.cache = make([]byte, 65535) + } + return t, nil } @@ -78,19 +83,9 @@ func (t *TUN) Read(packet []byte) (int, error) { return t.nt.Read(packet, t.offset) } - buff := pool.Get(t.offset + cap(packet)) - defer func() { - _ = pool.Put(buff) - }() + n, err := t.nt.Read(t.cache, t.offset) - n, err := t.nt.Read(buff, t.offset) - if err != nil { - return 0, err - } - - _ = buff[:t.offset] - - copy(packet, buff[t.offset:t.offset+n]) + copy(packet, t.cache[t.offset:t.offset+n]) return n, err } @@ -100,7 +95,8 @@ func (t *TUN) Write(packet []byte) (int, error) { return t.nt.Write(packet, t.offset) } - packet = append(make([]byte, t.offset), packet...) + _ = t.cache[:t.offset] + packet = append(t.cache[:t.offset], packet...) return t.nt.Write(packet, t.offset) } From 131e9d38b6313e465e28845be131567409629207 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Tue, 29 Mar 2022 07:18:09 +0800 Subject: [PATCH 6/8] Fix: Vless UDP --- README.md | 1 + adapter/outbound/vless.go | 17 +++++++++++++---- test/vless_test.go | 18 +++++++++--------- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 5df2420b..d7af8a41 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,7 @@ proxies: network: tcp servername: example.com # flow: xtls-rprx-direct # xtls-rprx-origin # enable XTLS + # udp: true # skip-cert-verify: true ``` diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index df7ce4a1..876613b2 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -3,6 +3,7 @@ package outbound import ( "context" "crypto/tls" + "encoding/binary" "errors" "fmt" "net" @@ -128,8 +129,9 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig) } default: + // default tcp network // handle TLS And XTLS - c, err = v.streamTLSOrXTLSConn(c, true) + c, err = v.streamTLSOrXTLSConn(c, false) } if err != nil { @@ -213,7 +215,7 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d // ListenPacketContext implements C.ProxyAdapter func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { - // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr + // vless use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { ip, err := resolver.ResolveIP(metadata.Host) if err != nil { @@ -269,7 +271,7 @@ func parseVlessAddr(metadata *C.Metadata) *vless.DstAddr { copy(addr[1:], []byte(metadata.Host)) } - port, _ := strconv.Atoi(metadata.DstPort) + port, _ := strconv.ParseUint(metadata.DstPort, 10, 16) return &vless.DstAddr{ UDP: metadata.NetWork == C.UDP, AddrType: addrType, @@ -281,14 +283,21 @@ func parseVlessAddr(metadata *C.Metadata) *vless.DstAddr { type vlessPacketConn struct { net.Conn rAddr net.Addr + cache [2]byte } func (uc *vlessPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + binary.BigEndian.PutUint16(uc.cache[:], uint16(len(b))) + _, _ = uc.Conn.Write(uc.cache[:]) return uc.Conn.Write(b) } func (uc *vlessPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, err := uc.Conn.Read(b) + n, err := uc.Conn.Read(uc.cache[:]) + if err != nil { + return n, uc.rAddr, err + } + n, err = uc.Conn.Read(b) return n, uc.rAddr, err } diff --git a/test/vless_test.go b/test/vless_test.go index 8bce51d6..b5fbec86 100644 --- a/test/vless_test.go +++ b/test/vless_test.go @@ -40,7 +40,7 @@ func TestClash_VlessTLS(t *testing.T) { TLS: true, SkipCertVerify: true, ServerName: "example.org", - UDP: false, + UDP: true, }) if err != nil { assert.FailNow(t, err.Error()) @@ -71,16 +71,16 @@ func TestClash_VlessXTLS(t *testing.T) { defer cleanContainer(id) proxy, err := outbound.NewVless(outbound.VlessOption{ - Name: "vless", - Server: localIP.String(), - Port: 10002, - UUID: "b831381d-6324-4d53-ad4f-8cda48b30811", - TLS: true, - Flow: "xtls-rprx-direct", - //FlowShow: true, + Name: "vless", + Server: localIP.String(), + Port: 10002, + UUID: "b831381d-6324-4d53-ad4f-8cda48b30811", + TLS: true, SkipCertVerify: true, ServerName: "example.org", - UDP: false, + UDP: true, + Flow: "xtls-rprx-direct", + FlowShow: true, }) if err != nil { assert.FailNow(t, err.Error()) From b3ea2ff8b6ab5dcb93cda20aa7679fb5b7deee12 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Tue, 29 Mar 2022 23:50:41 +0800 Subject: [PATCH 7/8] Chore: adjust VLESS --- README.md | 14 ++- adapter/outbound/vless.go | 141 +++++++++++++++++------ listener/tun/device/tun/tun_wireguard.go | 1 - test/config/vless-ws.json | 35 ++++++ test/config/vless-xtls.json | 4 +- test/vless_test.go | 40 ++++++- transport/vless/vless.go | 9 -- 7 files changed, 193 insertions(+), 51 deletions(-) create mode 100644 test/config/vless-ws.json diff --git a/README.md b/README.md index d7af8a41..e9821aaa 100644 --- a/README.md +++ b/README.md @@ -132,14 +132,24 @@ Support outbound transport protocol `VLESS`. The XTLS only support TCP transport by the XRAY-CORE. ```yaml proxies: - - name: "vless-tcp" + - name: "vless-tls" type: vless server: server port: 443 uuid: uuid network: tcp servername: example.com - # flow: xtls-rprx-direct # xtls-rprx-origin # enable XTLS + udp: true + # skip-cert-verify: true + - name: "vless-xtls" + type: vless + server: server + port: 443 + uuid: uuid + network: tcp + servername: example.com + flow: xtls-rprx-direct # or xtls-rprx-origin + # flow-show: true # print the XTLS direct log # udp: true # skip-cert-verify: true ``` diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 876613b2..908c4de8 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -6,9 +6,11 @@ import ( "encoding/binary" "errors" "fmt" + "io" "net" "net/http" "strconv" + "sync" "github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/resolver" @@ -20,6 +22,11 @@ import ( "golang.org/x/net/http2" ) +const ( + // max packet length + maxLength = 1024 << 3 +) + type Vless struct { *Base client *vless.Client @@ -39,7 +46,6 @@ type VlessOption struct { UUID string `proxy:"uuid"` Flow string `proxy:"flow,omitempty"` FlowShow bool `proxy:"flow-show,omitempty"` - TLS bool `proxy:"tls,omitempty"` UDP bool `proxy:"udp,omitempty"` Network string `proxy:"network,omitempty"` HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"` @@ -80,19 +86,19 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { wsOpts.Headers = header } - if v.option.TLS { - wsOpts.TLS = true - wsOpts.TLSConfig = &tls.Config{ - ServerName: host, - InsecureSkipVerify: v.option.SkipCertVerify, - NextProtos: []string{"http/1.1"}, - } - if v.option.ServerName != "" { - wsOpts.TLSConfig.ServerName = v.option.ServerName - } else if host := wsOpts.Headers.Get("Host"); host != "" { - wsOpts.TLSConfig.ServerName = host - } + wsOpts.TLS = true + wsOpts.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: host, + InsecureSkipVerify: v.option.SkipCertVerify, + NextProtos: []string{"http/1.1"}, } + if v.option.ServerName != "" { + wsOpts.TLSConfig.ServerName = v.option.ServerName + } else if host := wsOpts.Headers.Get("Host"); host != "" { + wsOpts.TLSConfig.ServerName = host + } + c, err = vmess.StreamWebsocketConn(c, wsOpts) case "http": // readability first, so just copy default TLS logic @@ -160,7 +166,7 @@ func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error) return vless.StreamXTLSConn(conn, &xtlsOpts) - } else if v.option.TLS { + } else { tlsOpts := vmess.TLSConfig{ Host: host, SkipCertVerify: v.option.SkipCertVerify, @@ -176,8 +182,6 @@ func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error) return vmess.StreamTLSConn(conn, &tlsOpts) } - - return conn, nil } func (v *Vless) isXTLSEnabled() bool { @@ -282,30 +286,97 @@ func parseVlessAddr(metadata *C.Metadata) *vless.DstAddr { type vlessPacketConn struct { net.Conn - rAddr net.Addr - cache [2]byte + rAddr net.Addr + cache [2]byte + remain int + mux sync.Mutex } -func (uc *vlessPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { - binary.BigEndian.PutUint16(uc.cache[:], uint16(len(b))) - _, _ = uc.Conn.Write(uc.cache[:]) - return uc.Conn.Write(b) -} - -func (uc *vlessPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, err := uc.Conn.Read(uc.cache[:]) - if err != nil { - return n, uc.rAddr, err +func (vc *vlessPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + total := len(b) + if total == 0 { + return 0, nil } - n, err = uc.Conn.Read(b) - return n, uc.rAddr, err + + if total < maxLength { + return vc.writePacket(b) + } + + offset := 0 + for { + cursor := offset + maxLength + if cursor > total { + cursor = total + } + + n, err := vc.writePacket(b[offset:cursor]) + if err != nil { + return offset + n, err + } + + offset = cursor + if offset == total { + break + } + } + + return total, nil +} + +func (vc *vlessPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + vc.mux.Lock() + defer vc.mux.Unlock() + + if vc.remain != 0 { + length := len(b) + if length > vc.remain { + length = vc.remain + } + + n, err := vc.Conn.Read(b[:length]) + if err != nil { + return 0, vc.rAddr, err + } + + vc.remain -= n + + return n, vc.rAddr, nil + } + + if _, err := vc.Conn.Read(b[:2]); err != nil { + return 0, vc.rAddr, err + } + + total := int(binary.BigEndian.Uint16(b[:2])) + if total == 0 { + return 0, vc.rAddr, nil + } + + length := len(b) + if length > total { + length = total + } + + if _, err := io.ReadFull(vc.Conn, b[:length]); err != nil { + return 0, vc.rAddr, errors.New("read packet error") + } + + vc.remain = total - length + + return length, vc.rAddr, nil +} + +func (vc *vlessPacketConn) writePacket(payload []byte) (int, error) { + binary.BigEndian.PutUint16(vc.cache[:], uint16(len(payload))) + + if _, err := vc.Conn.Write(vc.cache[:]); err != nil { + return 0, err + } + + return vc.Conn.Write(payload) } func NewVless(option VlessOption) (*Vless, error) { - if !option.TLS { - return nil, fmt.Errorf("TLS must be true with vless") - } - var addons *vless.Addons if option.Network != "ws" && len(option.Flow) >= 16 { option.Flow = option.Flow[:16] @@ -315,7 +386,7 @@ func NewVless(option VlessOption) (*Vless, error) { Flow: option.Flow, } default: - return nil, fmt.Errorf("unsupported vless flow type: %s", option.Flow) + return nil, fmt.Errorf("unsupported xtls flow type: %s", option.Flow) } } diff --git a/listener/tun/device/tun/tun_wireguard.go b/listener/tun/device/tun/tun_wireguard.go index 6e0ad15f..35008425 100644 --- a/listener/tun/device/tun/tun_wireguard.go +++ b/listener/tun/device/tun/tun_wireguard.go @@ -95,7 +95,6 @@ func (t *TUN) Write(packet []byte) (int, error) { return t.nt.Write(packet, t.offset) } - _ = t.cache[:t.offset] packet = append(t.cache[:t.offset], packet...) return t.nt.Write(packet, t.offset) diff --git a/test/config/vless-ws.json b/test/config/vless-ws.json new file mode 100644 index 00000000..9f3a5db8 --- /dev/null +++ b/test/config/vless-ws.json @@ -0,0 +1,35 @@ +{ + "inbounds": [ + { + "port": 10002, + "listen": "0.0.0.0", + "protocol": "vless", + "settings": { + "clients": [ + { + "id": "b831381d-6324-4d53-ad4f-8cda48b30811", + "level": 0, + "email": "ws@example.com" + } + ] + }, + "streamSettings": { + "network": "ws", + "security": "tls", + "tlsSettings": { + "certificates": [ + { + "certificateFile": "/etc/ssl/v2ray/fullchain.pem", + "keyFile": "/etc/ssl/v2ray/privkey.pem" + } + ] + } + } + } + ], + "outbounds": [ + { + "protocol": "freedom" + } + ] +} \ No newline at end of file diff --git a/test/config/vless-xtls.json b/test/config/vless-xtls.json index b4381d61..1d352c3f 100644 --- a/test/config/vless-xtls.json +++ b/test/config/vless-xtls.json @@ -8,9 +8,9 @@ "clients": [ { "id": "b831381d-6324-4d53-ad4f-8cda48b30811", + "email": "xtls@example.com", "flow": "xtls-rprx-direct", - "level": 0, - "email": "love@example.com" + "level": 0 } ], "decryption": "none" diff --git a/test/vless_test.go b/test/vless_test.go index b5fbec86..3f516925 100644 --- a/test/vless_test.go +++ b/test/vless_test.go @@ -37,7 +37,6 @@ func TestClash_VlessTLS(t *testing.T) { Server: localIP.String(), Port: 10002, UUID: "b831381d-6324-4d53-ad4f-8cda48b30811", - TLS: true, SkipCertVerify: true, ServerName: "example.org", UDP: true, @@ -75,7 +74,6 @@ func TestClash_VlessXTLS(t *testing.T) { Server: localIP.String(), Port: 10002, UUID: "b831381d-6324-4d53-ad4f-8cda48b30811", - TLS: true, SkipCertVerify: true, ServerName: "example.org", UDP: true, @@ -89,3 +87,41 @@ func TestClash_VlessXTLS(t *testing.T) { time.Sleep(waitTime) testSuit(t, proxy) } + +func TestClash_VlessWS(t *testing.T) { + cfg := &container.Config{ + Image: ImageVmess, + ExposedPorts: defaultExposedPorts, + } + hostCfg := &container.HostConfig{ + PortBindings: defaultPortBindings, + Binds: []string{ + fmt.Sprintf("%s:/etc/v2ray/config.json", C.Path.Resolve("vless-ws.json")), + fmt.Sprintf("%s:/etc/ssl/v2ray/fullchain.pem", C.Path.Resolve("example.org.pem")), + fmt.Sprintf("%s:/etc/ssl/v2ray/privkey.pem", C.Path.Resolve("example.org-key.pem")), + }, + } + + id, err := startContainer(cfg, hostCfg, "vless-ws") + if err != nil { + assert.FailNow(t, err.Error()) + } + defer cleanContainer(id) + + proxy, err := outbound.NewVless(outbound.VlessOption{ + Name: "vless", + Server: localIP.String(), + Port: 10002, + UUID: "b831381d-6324-4d53-ad4f-8cda48b30811", + SkipCertVerify: true, + ServerName: "example.org", + Network: "ws", + UDP: true, + }) + if err != nil { + assert.FailNow(t, err.Error()) + } + + time.Sleep(waitTime) + testSuit(t, proxy) +} diff --git a/transport/vless/vless.go b/transport/vless/vless.go index 458f54de..bb0ce881 100644 --- a/transport/vless/vless.go +++ b/transport/vless/vless.go @@ -35,15 +35,6 @@ type DstAddr struct { Port uint } -// Config of vless -type Config struct { - UUID string - AlterID uint16 - Security string - Port string - HostName string -} - // Client is vless connection generator type Client struct { uuid *uuid.UUID From 9ff1f5530e965773ae554529fa5558f4f18c1a63 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Wed, 30 Mar 2022 00:15:39 +0800 Subject: [PATCH 8/8] Feature: Trojan XTLS --- README.md | 22 +++++++-- adapter/outbound/trojan.go | 32 ++++++++++++- test/config/trojan-xtls.json | 39 +++++++++++++++ test/trojan_test.go | 40 ++++++++++++++++ transport/trojan/trojan.go | 93 +++++++++++++++++++++++++++++------- 5 files changed, 205 insertions(+), 21 deletions(-) create mode 100644 test/config/trojan-xtls.json diff --git a/README.md b/README.md index e9821aaa..2eca8bbe 100644 --- a/README.md +++ b/README.md @@ -127,11 +127,14 @@ rules: ``` ### Proxies configuration -Support outbound transport protocol `VLESS`. +Support outbound protocol `VLESS`. -The XTLS only support TCP transport by the XRAY-CORE. +Support `Trojan` with XTLS. + +Currently XTLS only supports TCP transport. ```yaml proxies: + # VLESS - name: "vless-tls" type: vless server: server @@ -149,9 +152,22 @@ proxies: network: tcp servername: example.com flow: xtls-rprx-direct # or xtls-rprx-origin - # flow-show: true # print the XTLS direct log + # flow-show: true # print the XTLS direction log # udp: true # skip-cert-verify: true + + # Trojan + - name: "trojan-xtls" + type: trojan + server: server + port: 443 + password: yourpsk + network: tcp + flow: xtls-rprx-direct # or xtls-rprx-origin + # flow-show: true # print the XTLS direction log + # udp: true + # sni: example.com # aka server name + # skip-cert-verify: true ``` ### IPTABLES configuration diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index 064cd3c2..aa389b34 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -12,6 +12,7 @@ import ( C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/gun" "github.com/Dreamacro/clash/transport/trojan" + "github.com/Dreamacro/clash/transport/vless" "golang.org/x/net/http2" ) @@ -40,6 +41,8 @@ type TrojanOption struct { Network string `proxy:"network,omitempty"` GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"` WSOpts WSOptions `proxy:"ws-opts,omitempty"` + Flow string `proxy:"flow,omitempty"` + FlowShow bool `proxy:"flow-show,omitempty"` } func (t *Trojan) plainStream(c net.Conn) (net.Conn, error) { @@ -82,6 +85,11 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) return nil, fmt.Errorf("%s connect error: %w", t.addr, err) } + c, err = t.instance.PresetXTLSConn(c) + if err != nil { + return nil, err + } + err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)) return c, err } @@ -95,6 +103,12 @@ func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata, opts ... return nil, err } + c, err = t.instance.PresetXTLSConn(c) + if err != nil { + c.Close() + return nil, err + } + if err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)); err != nil { c.Close() return nil, err @@ -160,6 +174,17 @@ func NewTrojan(option TrojanOption) (*Trojan, error) { ALPN: option.ALPN, ServerName: option.Server, SkipCertVerify: option.SkipCertVerify, + FlowShow: option.FlowShow, + } + + if option.Network != "ws" && len(option.Flow) >= 16 { + option.Flow = option.Flow[:16] + switch option.Flow { + case vless.XRO, vless.XRD, vless.XRS: + tOption.Flow = option.Flow + default: + return nil, fmt.Errorf("unsupported xtls flow type: %s", option.Flow) + } } if option.SNI != "" { @@ -196,7 +221,12 @@ func NewTrojan(option TrojanOption) (*Trojan, error) { ServerName: tOption.ServerName, } - t.transport = gun.NewHTTP2Client(dialFn, tlsConfig) + if t.option.Flow != "" { + t.transport = gun.NewHTTP2XTLSClient(dialFn, tlsConfig) + } else { + t.transport = gun.NewHTTP2Client(dialFn, tlsConfig) + } + t.gunTLSConfig = tlsConfig t.gunConfig = &gun.Config{ ServiceName: option.GrpcOpts.GrpcServiceName, diff --git a/test/config/trojan-xtls.json b/test/config/trojan-xtls.json new file mode 100644 index 00000000..c3a72eee --- /dev/null +++ b/test/config/trojan-xtls.json @@ -0,0 +1,39 @@ +{ + "inbounds": [ + { + "port": 10002, + "listen": "0.0.0.0", + "protocol": "trojan", + "settings": { + "clients": [ + { + "password": "example", + "email": "xtls@example.com", + "flow": "xtls-rprx-direct", + "level": 0 + } + ] + }, + "streamSettings": { + "network": "tcp", + "security": "xtls", + "xtlsSettings": { + "certificates": [ + { + "certificateFile": "/etc/ssl/v2ray/fullchain.pem", + "keyFile": "/etc/ssl/v2ray/privkey.pem" + } + ] + } + } + } + ], + "outbounds": [ + { + "protocol": "freedom" + } + ], + "log": { + "loglevel": "debug" + } +} \ No newline at end of file diff --git a/test/trojan_test.go b/test/trojan_test.go index d1ab2a00..8b4e1745 100644 --- a/test/trojan_test.go +++ b/test/trojan_test.go @@ -131,6 +131,46 @@ func TestClash_TrojanWebsocket(t *testing.T) { testSuit(t, proxy) } +func TestClash_TrojanXTLS(t *testing.T) { + cfg := &container.Config{ + Image: ImageXray, + ExposedPorts: defaultExposedPorts, + } + hostCfg := &container.HostConfig{ + PortBindings: defaultPortBindings, + Binds: []string{ + fmt.Sprintf("%s:/etc/xray/config.json", C.Path.Resolve("trojan-xtls.json")), + fmt.Sprintf("%s:/etc/ssl/v2ray/fullchain.pem", C.Path.Resolve("example.org.pem")), + fmt.Sprintf("%s:/etc/ssl/v2ray/privkey.pem", C.Path.Resolve("example.org-key.pem")), + }, + } + + id, err := startContainer(cfg, hostCfg, "trojan-xtls") + if err != nil { + assert.FailNow(t, err.Error()) + } + defer cleanContainer(id) + + proxy, err := outbound.NewTrojan(outbound.TrojanOption{ + Name: "trojan", + Server: localIP.String(), + Port: 10002, + Password: "example", + SNI: "example.org", + SkipCertVerify: true, + UDP: true, + Network: "tcp", + Flow: "xtls-rprx-direct", + FlowShow: true, + }) + if err != nil { + assert.FailNow(t, err.Error()) + } + + time.Sleep(waitTime) + testSuit(t, proxy) +} + func Benchmark_Trojan(b *testing.B) { cfg := &container.Config{ Image: ImageTrojan, diff --git a/transport/trojan/trojan.go b/transport/trojan/trojan.go index ac9f17dd..a0e289f1 100644 --- a/transport/trojan/trojan.go +++ b/transport/trojan/trojan.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "encoding/hex" "errors" + "fmt" "io" "net" "net/http" @@ -15,7 +16,10 @@ import ( "github.com/Dreamacro/clash/common/pool" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/socks5" + "github.com/Dreamacro/clash/transport/vless" "github.com/Dreamacro/clash/transport/vmess" + + xtls "github.com/xtls/go" ) const ( @@ -32,9 +36,13 @@ var ( type Command = byte -var ( +const ( CommandTCP byte = 1 CommandUDP byte = 3 + + // for XTLS + commandXRD byte = 0xf0 // XTLS direct mode + commandXRO byte = 0xf1 // XTLS origin mode ) type Option struct { @@ -42,6 +50,8 @@ type Option struct { ALPN []string ServerName string SkipCertVerify bool + Flow string + FlowShow bool } type WebsocketOption struct { @@ -62,23 +72,42 @@ func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) { alpn = t.option.ALPN } - tlsConfig := &tls.Config{ - NextProtos: alpn, - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: t.option.SkipCertVerify, - ServerName: t.option.ServerName, + if t.option.Flow != "" { + xtlsConfig := &xtls.Config{ + NextProtos: alpn, + MinVersion: xtls.VersionTLS12, + InsecureSkipVerify: t.option.SkipCertVerify, + ServerName: t.option.ServerName, + } + + xtlsConn := xtls.Client(conn, xtlsConfig) + + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + if err := xtlsConn.HandshakeContext(ctx); err != nil { + return nil, err + } + + return xtlsConn, nil + } else { + tlsConfig := &tls.Config{ + NextProtos: alpn, + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: t.option.SkipCertVerify, + ServerName: t.option.ServerName, + } + + tlsConn := tls.Client(conn, tlsConfig) + + // fix tls handshake not timeout + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + if err := tlsConn.HandshakeContext(ctx); err != nil { + return nil, err + } + + return tlsConn, nil } - - tlsConn := tls.Client(conn, tlsConfig) - - // fix tls handshake not timeout - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() - if err := tlsConn.HandshakeContext(ctx); err != nil { - return nil, err - } - - return tlsConn, nil } func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) { @@ -104,7 +133,37 @@ func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) }) } +func (t *Trojan) PresetXTLSConn(conn net.Conn) (net.Conn, error) { + switch t.option.Flow { + case vless.XRO, vless.XRD, vless.XRS: + if xtlsConn, ok := conn.(*xtls.Conn); ok { + xtlsConn.RPRX = true + xtlsConn.SHOW = t.option.FlowShow + xtlsConn.MARK = "XTLS" + if t.option.Flow == vless.XRS { + t.option.Flow = vless.XRD + } + + if t.option.Flow == vless.XRD { + xtlsConn.DirectMode = true + } + } else { + return nil, fmt.Errorf("failed to use %s, maybe \"security\" is not \"xtls\"", t.option.Flow) + } + } + + return conn, nil +} + func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error { + if command == CommandTCP { + if t.option.Flow == vless.XRD { + command = commandXRD + } else if t.option.Flow == vless.XRO { + command = commandXRO + } + } + buf := pool.GetBuffer() defer pool.PutBuffer(buf)