diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index 4cf16e38..f7c88124 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -14,6 +14,7 @@ type Direct struct { // DialContext implements C.ProxyAdapter func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { + opts = append(opts, dialer.WithDirect()) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) if err != nil { return nil, err @@ -24,6 +25,7 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ... // ListenPacketContext implements C.ProxyAdapter func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { + opts = append(opts, dialer.WithDirect()) pc, err := dialer.ListenPacket(ctx, "udp", "", d.Base.DialOptions(opts...)...) if err != nil { return nil, err diff --git a/adapter/outbound/util.go b/adapter/outbound/util.go index ff6a8e65..7ca795dd 100644 --- a/adapter/outbound/util.go +++ b/adapter/outbound/util.go @@ -67,7 +67,7 @@ func resolveUDPAddr(network, address string) (*net.UDPAddr, error) { return nil, err } - ip, err := resolver.ResolveIP(host) + ip, err := resolver.ResolveProxyServerHost(host) if err != nil { return nil, err } diff --git a/adapter/provider/healthcheck.go b/adapter/provider/healthcheck.go index 7a89ffa3..bfbaf6b0 100644 --- a/adapter/provider/healthcheck.go +++ b/adapter/provider/healthcheck.go @@ -31,7 +31,13 @@ type HealthCheck struct { func (hc *HealthCheck) process() { ticker := time.NewTicker(time.Duration(hc.interval) * time.Second) - go hc.check() + go func() { + t := time.NewTicker(30 * time.Second) + <-t.C + t.Stop() + hc.check() + }() + for { select { case <-ticker.C: diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 09254c58..19a5fdd2 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -32,14 +32,14 @@ func DialContext(ctx context.Context, network, address string, options ...Option var ip net.IP switch network { case "tcp4", "udp4": - if opt.interfaceName != "" { - ip, err = resolver.ResolveIPv4WithMain(host) + if !opt.direct { + ip, err = resolver.ResolveIPv4ProxyServerHost(host) } else { ip, err = resolver.ResolveIPv4(host) } default: - if opt.interfaceName != "" { - ip, err = resolver.ResolveIPv6WithMain(host) + if !opt.direct { + ip, err = resolver.ResolveIPv6ProxyServerHost(host) } else { ip, err = resolver.ResolveIPv6(host) } @@ -121,7 +121,7 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt results := make(chan dialResult) var primary, fallback dialResult - startRacer := func(ctx context.Context, network, host string, ipv6 bool) { + startRacer := func(ctx context.Context, network, host string, direct bool, ipv6 bool) { result := dialResult{ipv6: ipv6, done: true} defer func() { select { @@ -135,14 +135,14 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt var ip net.IP if ipv6 { - if opt.interfaceName != "" { - ip, result.error = resolver.ResolveIPv6WithMain(host) + if !direct { + ip, result.error = resolver.ResolveIPv6ProxyServerHost(host) } else { ip, result.error = resolver.ResolveIPv6(host) } } else { - if opt.interfaceName != "" { - ip, result.error = resolver.ResolveIPv4WithMain(host) + if !direct { + ip, result.error = resolver.ResolveIPv4ProxyServerHost(host) } else { ip, result.error = resolver.ResolveIPv4(host) } @@ -155,8 +155,8 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt result.Conn, result.error = dialContext(ctx, network, ip, port, opt) } - go startRacer(ctx, network+"4", host, false) - go startRacer(ctx, network+"6", host, true) + go startRacer(ctx, network+"4", host, opt.direct, false) + go startRacer(ctx, network+"6", host, opt.direct, true) for res := range results { if res.error == nil { diff --git a/component/dialer/options.go b/component/dialer/options.go index 2d884094..2985dc7b 100644 --- a/component/dialer/options.go +++ b/component/dialer/options.go @@ -12,6 +12,7 @@ type option struct { interfaceName string addrReuse bool routingMark int + direct bool } type Option func(opt *option) @@ -33,3 +34,9 @@ func WithRoutingMark(mark int) Option { opt.routingMark = mark } } + +func WithDirect() Option { + return func(opt *option) { + opt.direct = true + } +} diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index a7300bd7..e1100a31 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -15,8 +15,8 @@ var ( // DefaultResolver aim to resolve ip DefaultResolver Resolver - // MainResolver resolve ip with main domain server - MainResolver Resolver + // ProxyServerHostResolver resolve ip to proxies server host + ProxyServerHostResolver Resolver // DisableIPv6 means don't resolve ipv6 host // default value is true @@ -46,10 +46,6 @@ func ResolveIPv4(host string) (net.IP, error) { return ResolveIPv4WithResolver(host, DefaultResolver) } -func ResolveIPv4WithMain(host string) (net.IP, error) { - return ResolveIPv4WithResolver(host, MainResolver) -} - func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) { if node := DefaultHosts.Search(host); node != nil { if ip := node.Data.(net.IP).To4(); ip != nil { @@ -69,16 +65,20 @@ func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) { return r.ResolveIPv4(host) } - ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) - defer cancel() - ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host) - if err != nil { - return nil, err - } else if len(ipAddrs) == 0 { - return nil, ErrIPNotFound + if DefaultResolver == nil { + ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) + defer cancel() + ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host) + if err != nil { + return nil, err + } else if len(ipAddrs) == 0 { + return nil, ErrIPNotFound + } + + return ipAddrs[rand.Intn(len(ipAddrs))], nil } - return ipAddrs[rand.Intn(len(ipAddrs))], nil + return nil, ErrIPNotFound } // ResolveIPv6 with a host, return ipv6 @@ -86,10 +86,6 @@ func ResolveIPv6(host string) (net.IP, error) { return ResolveIPv6WithResolver(host, DefaultResolver) } -func ResolveIPv6WithMain(host string) (net.IP, error) { - return ResolveIPv6WithResolver(host, MainResolver) -} - func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { if DisableIPv6 { return nil, ErrIPv6Disabled @@ -113,16 +109,20 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { return r.ResolveIPv6(host) } - ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) - defer cancel() - ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host) - if err != nil { - return nil, err - } else if len(ipAddrs) == 0 { - return nil, ErrIPNotFound + if DefaultResolver == nil { + ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) + defer cancel() + ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host) + if err != nil { + return nil, err + } else if len(ipAddrs) == 0 { + return nil, ErrIPNotFound + } + + return ipAddrs[rand.Intn(len(ipAddrs))], nil } - return ipAddrs[rand.Intn(len(ipAddrs))], nil + return nil, ErrIPNotFound } // ResolveIPWithResolver same as ResolveIP, but with a resolver @@ -145,12 +145,16 @@ func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) { return ip, nil } - ipAddr, err := net.ResolveIPAddr("ip", host) - if err != nil { - return nil, err + if DefaultResolver == nil { + ipAddr, err := net.ResolveIPAddr("ip", host) + if err != nil { + return nil, err + } + + return ipAddr.IP, nil } - return ipAddr.IP, nil + return nil, ErrIPNotFound } // ResolveIP with a host, return ip @@ -158,7 +162,26 @@ func ResolveIP(host string) (net.IP, error) { return ResolveIPWithResolver(host, DefaultResolver) } -// ResolveIPWithMainResolver with a host, use main resolver, return ip -func ResolveIPWithMainResolver(host string) (net.IP, error) { - return ResolveIPWithResolver(host, MainResolver) +// ResolveIPv4ProxyServerHost proxies server host only +func ResolveIPv4ProxyServerHost(host string) (net.IP, error) { + if ProxyServerHostResolver != nil { + return ResolveIPv4WithResolver(host, ProxyServerHostResolver) + } + return ResolveIPv4(host) +} + +// ResolveIPv6ProxyServerHost proxies server host only +func ResolveIPv6ProxyServerHost(host string) (net.IP, error) { + if ProxyServerHostResolver != nil { + return ResolveIPv6WithResolver(host, ProxyServerHostResolver) + } + return ResolveIPv6(host) +} + +// ResolveProxyServerHost proxies server host only +func ResolveProxyServerHost(host string) (net.IP, error) { + if ProxyServerHostResolver != nil { + return ResolveIPWithResolver(host, ProxyServerHostResolver) + } + return ResolveIP(host) } diff --git a/config/config.go b/config/config.go index 9e7447f2..1b46c0f0 100644 --- a/config/config.go +++ b/config/config.go @@ -70,17 +70,18 @@ type Controller struct { // DNS config type DNS struct { - Enable bool `yaml:"enable"` - IPv6 bool `yaml:"ipv6"` - NameServer []dns.NameServer `yaml:"nameserver"` - Fallback []dns.NameServer `yaml:"fallback"` - FallbackFilter FallbackFilter `yaml:"fallback-filter"` - Listen string `yaml:"listen"` - EnhancedMode C.DNSMode `yaml:"enhanced-mode"` - DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` - FakeIPRange *fakeip.Pool - Hosts *trie.DomainTrie - NameServerPolicy map[string]dns.NameServer + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + NameServer []dns.NameServer `yaml:"nameserver"` + Fallback []dns.NameServer `yaml:"fallback"` + FallbackFilter FallbackFilter `yaml:"fallback-filter"` + Listen string `yaml:"listen"` + EnhancedMode C.DNSMode `yaml:"enhanced-mode"` + DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` + FakeIPRange *fakeip.Pool + Hosts *trie.DomainTrie + NameServerPolicy map[string]dns.NameServer + ProxyServerNameserver []dns.NameServer } // FallbackFilter config @@ -146,18 +147,19 @@ type Config struct { } type RawDNS struct { - Enable bool `yaml:"enable"` - IPv6 bool `yaml:"ipv6"` - UseHosts bool `yaml:"use-hosts"` - NameServer []string `yaml:"nameserver"` - Fallback []string `yaml:"fallback"` - FallbackFilter RawFallbackFilter `yaml:"fallback-filter"` - Listen string `yaml:"listen"` - EnhancedMode C.DNSMode `yaml:"enhanced-mode"` - FakeIPRange string `yaml:"fake-ip-range"` - FakeIPFilter []string `yaml:"fake-ip-filter"` - DefaultNameserver []string `yaml:"default-nameserver"` - NameServerPolicy map[string]string `yaml:"nameserver-policy"` + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + UseHosts bool `yaml:"use-hosts"` + NameServer []string `yaml:"nameserver"` + Fallback []string `yaml:"fallback"` + FallbackFilter RawFallbackFilter `yaml:"fallback-filter"` + Listen string `yaml:"listen"` + EnhancedMode C.DNSMode `yaml:"enhanced-mode"` + FakeIPRange string `yaml:"fake-ip-range"` + FakeIPFilter []string `yaml:"fake-ip-filter"` + DefaultNameserver []string `yaml:"default-nameserver"` + NameServerPolicy map[string]string `yaml:"nameserver-policy"` + ProxyServerNameserver []string `yaml:"proxy-server-nameserver"` } type RawFallbackFilter struct { @@ -805,6 +807,10 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS, return nil, err } + if dnsCfg.ProxyServerNameserver, err = parseNameServer(cfg.ProxyServerNameserver); err != nil { + return nil, err + } + if len(cfg.DefaultNameserver) == 0 { return nil, errors.New("default nameserver should have at least one nameserver") } diff --git a/dns/resolver.go b/dns/resolver.go index aa854153..c5e32867 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -41,6 +41,7 @@ type Resolver struct { group singleflight.Group lruCache *cache.LruCache policy *trie.DomainTrie + proxyServer []dnsClient } // ResolveIP request with TypeA and TypeAAAA, priority return TypeA @@ -300,6 +301,11 @@ func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D return ch } +// HasProxyServer has proxy server dns client +func (r *Resolver) HasProxyServer() bool { + return len(r.main) > 0 +} + type NameServer struct { Net string Addr string @@ -318,6 +324,7 @@ type FallbackFilter struct { type Config struct { Main, Fallback []NameServer Default []NameServer + ProxyServer []NameServer IPv6 bool EnhancedMode C.DNSMode FallbackFilter FallbackFilter @@ -343,6 +350,10 @@ func NewResolver(config Config) *Resolver { r.fallback = transform(config.Fallback, defaultResolver) } + if len(config.ProxyServer) != 0 { + r.proxyServer = transform(config.ProxyServer, defaultResolver) + } + if len(config.Policy) != 0 { r.policy = trie.New() for domain, nameserver := range config.Policy { @@ -376,10 +387,10 @@ func NewResolver(config Config) *Resolver { return r } -func NewMainResolver(old *Resolver) *Resolver { +func NewProxyServerHostResolver(old *Resolver) *Resolver { r := &Resolver{ ipv6: old.ipv6, - main: old.main, + main: old.proxyServer, lruCache: old.lruCache, hosts: old.hosts, policy: old.policy, diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 9a5c7c2b..bcdafb89 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -131,12 +131,13 @@ func updateDNS(c *config.DNS, t *config.Tun) { Domain: c.FallbackFilter.Domain, GeoSite: c.FallbackFilter.GeoSite, }, - Default: c.DefaultNameserver, - Policy: c.NameServerPolicy, + Default: c.DefaultNameserver, + Policy: c.NameServerPolicy, + ProxyServer: c.ProxyServerNameserver, } r := dns.NewResolver(cfg) - mr := dns.NewMainResolver(r) + pr := dns.NewProxyServerHostResolver(r) m := dns.NewEnhancer(cfg) // reuse cache of old host mapper @@ -145,9 +146,12 @@ func updateDNS(c *config.DNS, t *config.Tun) { } resolver.DefaultResolver = r - resolver.MainResolver = mr resolver.DefaultHostMapper = m + if pr.HasProxyServer() { + resolver.ProxyServerHostResolver = pr + } + if t.Enable { resolver.DefaultLocalServer = dns.NewLocalServer(r, m) } @@ -157,9 +161,9 @@ func updateDNS(c *config.DNS, t *config.Tun) { } else { if !t.Enable { resolver.DefaultResolver = nil - resolver.MainResolver = nil resolver.DefaultHostMapper = nil resolver.DefaultLocalServer = nil + resolver.ProxyServerHostResolver = nil } dns.ReCreateServer("", nil, nil) } @@ -365,7 +369,9 @@ func updateIPTables(cfg *config.Config) { log.Infoln("[IPTABLES] Setting iptables completed") } -func Cleanup() { +func Shutdown() { P.Cleanup() tproxy.CleanupTProxyIPTables() + + log.Warnln("Clash shutting down") } diff --git a/hub/hub.go b/hub/hub.go index cde0bb57..471fdb5e 100644 --- a/hub/hub.go +++ b/hub/hub.go @@ -48,7 +48,3 @@ func Parse(options ...Option) error { executor.ApplyConfig(cfg, true) return nil } - -func Cleanup() { - executor.Cleanup() -} diff --git a/listener/tun/device/device.go b/listener/tun/device/device.go index 70115cbd..73b03dee 100644 --- a/listener/tun/device/device.go +++ b/listener/tun/device/device.go @@ -29,4 +29,7 @@ type Device interface { // UseIOBased work for other ip stack UseIOBased() error + + // Wait waits for the device to close. + Wait() } diff --git a/listener/tun/device/iobased/endpoint.go b/listener/tun/device/iobased/endpoint.go index a187491e..c0942d10 100644 --- a/listener/tun/device/iobased/endpoint.go +++ b/listener/tun/device/iobased/endpoint.go @@ -103,7 +103,7 @@ func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) { case header.IPv6Version: e.InjectInbound(header.IPv6ProtocolNumber, pkt) } - pkt.DecRef() /* release */ + pkt.DecRef() } } diff --git a/listener/tun/device/tun/tun_wireguard.go b/listener/tun/device/tun/tun_wireguard.go index 50db6511..30398b55 100644 --- a/listener/tun/device/tun/tun_wireguard.go +++ b/listener/tun/device/tun/tun_wireguard.go @@ -106,6 +106,9 @@ func (t *TUN) Write(packet []byte) (int, error) { } func (t *TUN) Close() error { + if t.Endpoint != nil { + t.Endpoint.Close() + } return t.nt.Close() } diff --git a/listener/tun/ipstack/gvisor/adapter/handler.go b/listener/tun/ipstack/gvisor/adapter/handler.go index 2878b713..715f6636 100644 --- a/listener/tun/ipstack/gvisor/adapter/handler.go +++ b/listener/tun/ipstack/gvisor/adapter/handler.go @@ -3,6 +3,12 @@ package adapter // Handler is a TCP/UDP connection handler that implements // HandleTCPConn and HandleUDPConn methods. type Handler interface { - HandleTCPConn(TCPConn) - HandleUDPConn(UDPConn) + HandleTCP(TCPConn) + HandleUDP(UDPConn) } + +// TCPHandleFunc handles incoming TCP connection. +type TCPHandleFunc func(TCPConn) + +// UDPHandleFunc handles incoming UDP connection. +type UDPHandleFunc func(UDPConn) diff --git a/listener/tun/ipstack/gvisor/handler.go b/listener/tun/ipstack/gvisor/handler.go index 76aab2f5..6365234e 100644 --- a/listener/tun/ipstack/gvisor/handler.go +++ b/listener/tun/ipstack/gvisor/handler.go @@ -24,7 +24,7 @@ type GVHandler struct { UDPIn chan<- *inbound.PacketAdapter } -func (gh *GVHandler) HandleTCPConn(tunConn adapter.TCPConn) { +func (gh *GVHandler) HandleTCP(tunConn adapter.TCPConn) { id := tunConn.ID() rAddr := &net.UDPAddr{ @@ -77,7 +77,7 @@ func (gh *GVHandler) HandleTCPConn(tunConn adapter.TCPConn) { gh.TCPIn <- inbound.NewSocket(socks5.ParseAddrToSocksAddr(rAddr), tunConn, C.TUN) } -func (gh *GVHandler) HandleUDPConn(tunConn adapter.UDPConn) { +func (gh *GVHandler) HandleUDP(tunConn adapter.UDPConn) { id := tunConn.ID() rAddr := &net.UDPAddr{ diff --git a/listener/tun/ipstack/gvisor/nic.go b/listener/tun/ipstack/gvisor/nic.go index fb8ac1a2..0ca96778 100644 --- a/listener/tun/ipstack/gvisor/nic.go +++ b/listener/tun/ipstack/gvisor/nic.go @@ -3,14 +3,13 @@ package gvisor import ( "fmt" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( - // defaultNICID is the ID of default NIC used by DefaultStack. - defaultNICID tcpip.NICID = 0x01 - // nicPromiscuousModeEnabled is the value used by stack to enable // or disable NIC's promiscuous mode. nicPromiscuousModeEnabled = true @@ -21,9 +20,9 @@ const ( ) // withCreatingNIC creates NIC for stack. -func withCreatingNIC(ep stack.LinkEndpoint) Option { - return func(s *gvStack) error { - if err := s.CreateNICWithOptions(s.nicID, ep, +func withCreatingNIC(nicID tcpip.NICID, ep stack.LinkEndpoint) option.Option { + return func(s *stack.Stack) error { + if err := s.CreateNICWithOptions(nicID, ep, stack.NICOptions{ Disabled: false, // If no queueing discipline was specified @@ -37,21 +36,21 @@ func withCreatingNIC(ep stack.LinkEndpoint) Option { } } -// withPromiscuousMode sets promiscuous mode in the given NIC. -func withPromiscuousMode(v bool) Option { - return func(s *gvStack) error { - if err := s.SetPromiscuousMode(s.nicID, v); err != nil { +// withPromiscuousMode sets promiscuous mode in the given NICs. +func withPromiscuousMode(nicID tcpip.NICID, v bool) option.Option { + return func(s *stack.Stack) error { + if err := s.SetPromiscuousMode(nicID, v); err != nil { return fmt.Errorf("set promiscuous mode: %s", err) } return nil } } -// withSpoofing sets address spoofing in the given NIC, allowing +// withSpoofing sets address spoofing in the given NICs, allowing // endpoints to bind to any address in the NIC. -func withSpoofing(v bool) Option { - return func(s *gvStack) error { - if err := s.SetSpoofing(s.nicID, v); err != nil { +func withSpoofing(nicID tcpip.NICID, v bool) option.Option { + return func(s *stack.Stack) error { + if err := s.SetSpoofing(nicID, v); err != nil { return fmt.Errorf("set spoofing: %s", err) } return nil diff --git a/listener/tun/ipstack/gvisor/opts.go b/listener/tun/ipstack/gvisor/option/option.go similarity index 93% rename from listener/tun/ipstack/gvisor/opts.go rename to listener/tun/ipstack/gvisor/option/option.go index 7fd5a65b..34508e0d 100644 --- a/listener/tun/ipstack/gvisor/opts.go +++ b/listener/tun/ipstack/gvisor/option/option.go @@ -1,4 +1,4 @@ -package gvisor +package option import ( "fmt" @@ -7,6 +7,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) @@ -56,11 +57,11 @@ const ( tcpDefaultBufferSize = 212 << 10 // 212 KiB ) -type Option func(*gvStack) error +type Option func(*stack.Stack) error // WithDefault sets all default values for stack. func WithDefault() Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opts := []Option{ WithDefaultTTL(defaultTimeToLive), WithForwarding(ipForwardingEnabled), @@ -110,7 +111,7 @@ func WithDefault() Option { // WithDefaultTTL sets the default TTL used by stack. func WithDefaultTTL(ttl uint8) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.DefaultTTLOption(ttl) if err := s.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set ipv4 default TTL: %s", err) @@ -124,7 +125,7 @@ func WithDefaultTTL(ttl uint8) Option { // WithForwarding sets packet forwarding between NICs for IPv4 & IPv6. func WithForwarding(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, v); err != nil { return fmt.Errorf("set ipv4 forwarding: %s", err) } @@ -138,7 +139,7 @@ func WithForwarding(v bool) Option { // WithICMPBurst sets the number of ICMP messages that can be sent // in a single burst. func WithICMPBurst(burst int) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { s.SetICMPBurst(burst) return nil } @@ -147,7 +148,7 @@ func WithICMPBurst(burst int) Option { // WithICMPLimit sets the maximum number of ICMP messages permitted // by rate limiter. func WithICMPLimit(limit rate.Limit) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { s.SetICMPLimit(limit) return nil } @@ -155,7 +156,7 @@ func WithICMPLimit(limit rate.Limit) Option { // WithTCPBufferSizeRange sets the receive and send buffer size range for TCP. func WithTCPBufferSizeRange(a, b, c int) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c} if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil { return fmt.Errorf("set TCP receive buffer size range: %s", err) @@ -170,7 +171,7 @@ func WithTCPBufferSizeRange(a, b, c int) Option { // WithTCPCongestionControl sets the current congestion control algorithm. func WithTCPCongestionControl(cc string) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.CongestionControlOption(cc) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP congestion control algorithm: %s", err) @@ -181,7 +182,7 @@ func WithTCPCongestionControl(cc string) Option { // WithTCPDelay enables or disables Nagle's algorithm in TCP. func WithTCPDelay(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.TCPDelayEnabled(v) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP delay: %s", err) @@ -192,7 +193,7 @@ func WithTCPDelay(v bool) Option { // WithTCPModerateReceiveBuffer sets receive buffer moderation for TCP. func WithTCPModerateReceiveBuffer(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.TCPModerateReceiveBufferOption(v) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP moderate receive buffer: %s", err) @@ -203,7 +204,7 @@ func WithTCPModerateReceiveBuffer(v bool) Option { // WithTCPSACKEnabled sets the SACK option for TCP. func WithTCPSACKEnabled(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.TCPSACKEnabled(v) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP SACK: %s", err) @@ -214,7 +215,7 @@ func WithTCPSACKEnabled(v bool) Option { // WithTCPRecovery sets the recovery option for TCP. func WithTCPRecovery(v tcpip.TCPRecovery) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { return fmt.Errorf("set TCP Recovery: %s", err) } diff --git a/listener/tun/ipstack/gvisor/icmp.go b/listener/tun/ipstack/gvisor/route.go similarity index 51% rename from listener/tun/ipstack/gvisor/icmp.go rename to listener/tun/ipstack/gvisor/route.go index 8b56d397..5a3d3bf4 100644 --- a/listener/tun/ipstack/gvisor/icmp.go +++ b/listener/tun/ipstack/gvisor/route.go @@ -1,22 +1,23 @@ package gvisor import ( + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -func withICMPHandler() Option { - return func(s *gvStack) error { - // Add default route table for IPv4 and IPv6. - // This will handle all incoming ICMP packets. +func withRouteTable(nicID tcpip.NICID) option.Option { + return func(s *stack.Stack) error { s.SetRouteTable([]tcpip.Route{ { Destination: header.IPv4EmptySubnet, - NIC: s.nicID, + NIC: nicID, }, { Destination: header.IPv6EmptySubnet, - NIC: s.nicID, + NIC: nicID, }, }) return nil diff --git a/listener/tun/ipstack/gvisor/stack.go b/listener/tun/ipstack/gvisor/stack.go index 104c0300..9061995d 100644 --- a/listener/tun/ipstack/gvisor/stack.go +++ b/listener/tun/ipstack/gvisor/stack.go @@ -5,6 +5,7 @@ import ( "github.com/Dreamacro/clash/listener/tun/device" "github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -18,23 +19,23 @@ import ( type gvStack struct { *stack.Stack device device.Device - - handler adapter.Handler - nicID tcpip.NICID } func (s *gvStack) Close() error { + var err error + if s.device != nil { + err = s.device.Close() + s.device.Wait() + } if s.Stack != nil { s.Stack.Close() + s.Stack.Wait() } - if s.device != nil { - _ = s.device.Close() - } - return nil + return err } // New allocates a new *gvStack with given options. -func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack.Stack, error) { +func New(device device.Device, handler adapter.Handler, opts ...option.Option) (ipstack.Stack, error) { s := &gvStack{ Stack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -49,19 +50,15 @@ func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack }, }), - device: device, - handler: handler, - nicID: defaultNICID, + device: device, } - opts = append(opts, - // Important: We must initiate transport protocol handlers - // before creating NIC, otherwise NIC would dispatch packets - // to stack and cause race condition. - withICMPHandler(), withTCPHandler(), withUDPHandler(), + // Generate unique NIC id. + nicID := tcpip.NICID(s.Stack.UniqueID()) - // Create stack NIC and then bind link endpoint. - withCreatingNIC(device.(stack.LinkEndpoint)), + opts = append(opts, + // Create stack NIC and then bind link endpoint to it. + withCreatingNIC(nicID, device), // In the past we did s.AddAddressRange to assign 0.0.0.0/0 // onto the interface. We need that to be able to terminate @@ -70,27 +67,34 @@ func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack // Promiscuous mode. https://github.com/google/gvisor/issues/3876 // // Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go - withPromiscuousMode(nicPromiscuousModeEnabled), + withPromiscuousMode(nicID, nicPromiscuousModeEnabled), - // Enable spoofing if a stack may send packets from unowned addresses. - // This change required changes to some netgophers since previously, - // promiscuous mode was enough to let the netstack respond to all - // incoming packets regardless of the packet's destination address. Now - // that a stack.Route is not held for each incoming packet, finding a route - // may fail with local addresses we don't own but accepted packets for - // while in promiscuous mode. Since we also want to be able to send from - // any address (in response the received promiscuous mode packets), we need - // to enable spoofing. + // Enable spoofing if a stack may send packets from unowned + // addresses. This change required changes to some netgophers + // since previously, promiscuous mode was enough to let the + // netstack respond to all incoming packets regardless of the + // packet's destination address. Now that a stack.Route is not + // held for each incoming packet, finding a route may fail with + // local addresses we don't own but accepted packets for while + // in promiscuous mode. Since we also want to be able to send + // from any address (in response the received promiscuous mode + // packets), we need to enable spoofing. // // Ref: https://github.com/google/gvisor/commit/8c0701462a84ff77e602f1626aec49479c308127 - withSpoofing(nicSpoofingEnabled), + withSpoofing(nicID, nicSpoofingEnabled), + + // Add default route table for IPv4 and IPv6. This will handle + // all incoming ICMP packets. + withRouteTable(nicID), + + // Initiate transport protocol (TCP/UDP) with given handler. + withTCPHandler(handler.HandleTCP), withUDPHandler(handler.HandleUDP), ) for _, opt := range opts { - if err := opt(s); err != nil { + if err := opt(s.Stack); err != nil { return nil, err } } - return s, nil } diff --git a/listener/tun/ipstack/gvisor/tcp.go b/listener/tun/ipstack/gvisor/tcp.go index 8b8277e0..8bffb932 100644 --- a/listener/tun/ipstack/gvisor/tcp.go +++ b/listener/tun/ipstack/gvisor/tcp.go @@ -4,6 +4,9 @@ import ( "fmt" "time" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -37,9 +40,9 @@ const ( tcpKeepaliveInterval = 30 * time.Second ) -func withTCPHandler() Option { - return func(s *gvStack) error { - tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { +func withTCPHandler(handle adapter.TCPHandleFunc) option.Option { + return func(s *stack.Stack) error { + tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { @@ -55,7 +58,7 @@ func withTCPHandler() Option { TCPConn: gonet.NewTCPConn(&wq, ep), id: r.ID(), } - s.handler.HandleTCPConn(conn) + handle(conn) }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) return nil diff --git a/listener/tun/ipstack/gvisor/udp.go b/listener/tun/ipstack/gvisor/udp.go index 6efbd204..688583a0 100644 --- a/listener/tun/ipstack/gvisor/udp.go +++ b/listener/tun/ipstack/gvisor/udp.go @@ -5,6 +5,7 @@ import ( "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -12,9 +13,9 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -func withUDPHandler() Option { - return func(s *gvStack) error { - udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) { +func withUDPHandler(handle adapter.UDPHandleFunc) option.Option { + return func(s *stack.Stack) error { + udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { @@ -23,10 +24,10 @@ func withUDPHandler() Option { } conn := &udpConn{ - UDPConn: gonet.NewUDPConn(s.Stack, &wq, ep), + UDPConn: gonet.NewUDPConn(s, &wq, ep), id: r.ID(), } - s.handler.HandleUDPConn(conn) + handle(conn) }) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) return nil diff --git a/listener/tun/ipstack/system/stack.go b/listener/tun/ipstack/system/stack.go index ad3f465d..d8b250ba 100644 --- a/listener/tun/ipstack/system/stack.go +++ b/listener/tun/ipstack/system/stack.go @@ -36,8 +36,6 @@ func (s sysStack) Close() error { return nil } -var ipv4LoopBack = netip.MustParsePrefix("127.0.0.0/8") - func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) { var ( gateway = tunAddress.Masked().Addr().Next() @@ -71,12 +69,6 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref rAddrIp, _ := netip.AddrFromSlice(rAddr.IP) rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port)) - if ipv4LoopBack.Contains(rAddrIp) { - conn.Close() - - continue - } - if D.ShouldHijackDns(dnsAddr, rAddrPort) { go func() { log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String()) @@ -149,12 +141,6 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref rAddrIp, _ := netip.AddrFromSlice(rAddr.IP) rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port)) - if ipv4LoopBack.Contains(rAddrIp) { - pool.Put(buf) - - continue - } - if D.ShouldHijackDns(dnsAddr, rAddrPort) { go func() { defer pool.Put(buf) diff --git a/listener/tun/tun_adapter.go b/listener/tun/tun_adapter.go index 1e1f6ad4..17058be7 100644 --- a/listener/tun/tun_adapter.go +++ b/listener/tun/tun_adapter.go @@ -13,6 +13,7 @@ import ( "github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/listener/tun/ipstack/commons" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "github.com/Dreamacro/clash/listener/tun/ipstack/system" "github.com/Dreamacro/clash/log" "net/netip" @@ -72,7 +73,7 @@ func New(tunConf *config.Tun, dnsConf *config.DNS, tcpIn chan<- C.ConnContext, u DNSAdds: tunConf.DNSHijack, TCPIn: tcpIn, UDPIn: udpIn, }, - gvisor.WithDefault(), + option.WithDefault(), ) if err != nil { diff --git a/main.go b/main.go index e4b85b9b..10f4307b 100644 --- a/main.go +++ b/main.go @@ -106,13 +106,9 @@ func main() { log.Fatalln("Parse config error: %s", err.Error()) } + defer executor.Shutdown() + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh - - // cleanup - log.Warnln("Clash cleanup") - hub.Cleanup() - - log.Warnln("Clash shutting down") }