diff --git a/README.md b/README.md index b5328ce7..3be09c9d 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,7 @@ external-controller: 127.0.0.1:9090 # experimental feature experimental: ignore-resolve-fail: true # ignore dns resolve fail, default value is true + # interface-name: en0 # outbound interface name # authentication of local SOCKS5/HTTP(S) server # authentication: @@ -130,6 +131,9 @@ experimental: # enable: true # set true to enable dns (default is false) # ipv6: false # default is false # listen: 0.0.0.0:53 + # # default-nameserver: # resolve dns nameserver host, should fill pure IP + # # - 114.114.114.114 + # # - 8.8.8.8 # enhanced-mode: redir-host # or fake-ip # # fake-ip-range: 198.18.0.1/16 # if you don't know what it is, don't change it # fake-ip-filter: # fake ip white domain list diff --git a/adapters/outbound/direct.go b/adapters/outbound/direct.go index 1da750bc..c118425d 100644 --- a/adapters/outbound/direct.go +++ b/adapters/outbound/direct.go @@ -18,7 +18,7 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort) } - c, err := dialContext(ctx, "tcp", address) + c, err := dialer.DialContext(ctx, "tcp", address) if err != nil { return nil, err } diff --git a/adapters/outbound/http.go b/adapters/outbound/http.go index 5223f27d..e0ae0c41 100644 --- a/adapters/outbound/http.go +++ b/adapters/outbound/http.go @@ -13,6 +13,7 @@ import ( "net/url" "strconv" + "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" ) @@ -35,7 +36,7 @@ type HttpOption struct { } func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { - c, err := dialContext(ctx, "tcp", h.addr) + c, err := dialer.DialContext(ctx, "tcp", h.addr) if err == nil && h.tlsConfig != nil { cc := tls.Client(c, h.tlsConfig) err = cc.Handshake() diff --git a/adapters/outbound/shadowsocks.go b/adapters/outbound/shadowsocks.go index f0313e99..245d00ce 100644 --- a/adapters/outbound/shadowsocks.go +++ b/adapters/outbound/shadowsocks.go @@ -60,7 +60,7 @@ type v2rayObfsOption struct { } func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { - c, err := dialContext(ctx, "tcp", ss.server) + c, err := dialer.DialContext(ctx, "tcp", ss.server) if err != nil { return nil, fmt.Errorf("%s connect error: %w", ss.server, err) } diff --git a/adapters/outbound/snell.go b/adapters/outbound/snell.go index 4626bbef..f96d8fff 100644 --- a/adapters/outbound/snell.go +++ b/adapters/outbound/snell.go @@ -7,6 +7,7 @@ import ( "strconv" "github.com/Dreamacro/clash/common/structure" + "github.com/Dreamacro/clash/component/dialer" obfs "github.com/Dreamacro/clash/component/simple-obfs" "github.com/Dreamacro/clash/component/snell" C "github.com/Dreamacro/clash/constant" @@ -28,7 +29,7 @@ type SnellOption struct { } func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { - c, err := dialContext(ctx, "tcp", s.server) + c, err := dialer.DialContext(ctx, "tcp", s.server) if err != nil { return nil, fmt.Errorf("%s connect error: %w", s.server, err) } diff --git a/adapters/outbound/socks5.go b/adapters/outbound/socks5.go index 8c5b61f6..2c47bb44 100644 --- a/adapters/outbound/socks5.go +++ b/adapters/outbound/socks5.go @@ -36,7 +36,7 @@ type Socks5Option struct { } func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { - c, err := dialContext(ctx, "tcp", ss.addr) + c, err := dialer.DialContext(ctx, "tcp", ss.addr) if err == nil && ss.tls { cc := tls.Client(c, ss.tlsConfig) @@ -64,7 +64,7 @@ func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) defer cancel() - c, err := dialContext(ctx, "tcp", ss.addr) + c, err := dialer.DialContext(ctx, "tcp", ss.addr) if err != nil { err = fmt.Errorf("%s connect error: %w", ss.addr, err) return diff --git a/adapters/outbound/util.go b/adapters/outbound/util.go index e082954b..4a92c24c 100644 --- a/adapters/outbound/util.go +++ b/adapters/outbound/util.go @@ -2,7 +2,6 @@ package outbound import ( "bytes" - "context" "crypto/tls" "fmt" "net" @@ -11,10 +10,9 @@ import ( "sync" "time" - "github.com/Dreamacro/clash/component/dialer" + "github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/socks5" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/dns" ) const ( @@ -88,92 +86,13 @@ func serializesSocksAddr(metadata *C.Metadata) []byte { return bytes.Join(buf, nil) } -func dialContext(ctx context.Context, network, address string) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - return nil, err - } - - returned := make(chan struct{}) - defer close(returned) - - type dialResult struct { - net.Conn - error - resolved bool - ipv6 bool - done bool - } - results := make(chan dialResult) - var primary, fallback dialResult - - startRacer := func(ctx context.Context, host string, ipv6 bool) { - dialer := dialer.Dialer() - result := dialResult{ipv6: ipv6, done: true} - defer func() { - select { - case results <- result: - case <-returned: - if result.Conn != nil { - result.Conn.Close() - } - } - }() - - var ip net.IP - if ipv6 { - ip, result.error = dns.ResolveIPv6(host) - } else { - ip, result.error = dns.ResolveIPv4(host) - } - if result.error != nil { - return - } - result.resolved = true - - if ipv6 { - result.Conn, result.error = dialer.DialContext(ctx, "tcp6", net.JoinHostPort(ip.String(), port)) - } else { - result.Conn, result.error = dialer.DialContext(ctx, "tcp4", net.JoinHostPort(ip.String(), port)) - } - } - - go startRacer(ctx, host, false) - go startRacer(ctx, host, true) - - for { - select { - case res := <-results: - if res.error == nil { - return res.Conn, nil - } - - if !res.ipv6 { - primary = res - } else { - fallback = res - } - - if primary.done && fallback.done { - if primary.resolved { - return nil, primary.error - } else if fallback.resolved { - return nil, fallback.error - } else { - return nil, primary.error - } - } - } - } -} - func resolveUDPAddr(network, address string) (*net.UDPAddr, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, err } - ip, err := dns.ResolveIP(host) + ip, err := resolver.ResolveIP(host) if err != nil { return nil, err } diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index 401f3b87..197ba9a5 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" + "github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/vmess" C "github.com/Dreamacro/clash/constant" ) @@ -33,7 +34,7 @@ type VmessOption struct { } func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { - c, err := dialContext(ctx, "tcp", v.server) + c, err := dialer.DialContext(ctx, "tcp", v.server) if err != nil { return nil, fmt.Errorf("%s connect error", v.server) } @@ -45,7 +46,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) defer cancel() - c, err := dialContext(ctx, "tcp", v.server) + c, err := dialer.DialContext(ctx, "tcp", v.server) if err != nil { return nil, fmt.Errorf("%s connect error", v.server) } diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index bb8f122e..6dcecdc1 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -2,7 +2,10 @@ package dialer import ( "context" + "errors" "net" + + "github.com/Dreamacro/clash/component/resolver" ) func Dialer() *net.Dialer { @@ -28,11 +31,124 @@ func Dial(network, address string) (net.Conn, error) { } func DialContext(ctx context.Context, network, address string) (net.Conn, error) { - dailer := Dialer() - return dailer.DialContext(ctx, network, address) + switch network { + case "tcp4", "tcp6", "udp4", "udp6": + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + dialer := Dialer() + + var ip net.IP + switch network { + case "tcp4", "udp4": + ip, err = resolver.ResolveIPv4(host) + default: + ip, err = resolver.ResolveIPv6(host) + } + + if err != nil { + return nil, err + } + + if DialHook != nil { + DialHook(dialer, network, ip) + } + return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) + case "tcp", "udp": + return dualStackDailContext(ctx, network, address) + default: + return nil, errors.New("network invalid") + } } func ListenPacket(network, address string) (net.PacketConn, error) { lc := ListenConfig() + + if ListenPacketHook != nil && address == "" { + ip := ListenPacketHook() + if ip != nil { + address = net.JoinHostPort(ip.String(), "0") + } + } return lc.ListenPacket(context.Background(), network, address) } + +func dualStackDailContext(ctx context.Context, network, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + returned := make(chan struct{}) + defer close(returned) + + type dialResult struct { + net.Conn + error + resolved bool + ipv6 bool + done bool + } + results := make(chan dialResult) + var primary, fallback dialResult + + startRacer := func(ctx context.Context, network, host string, ipv6 bool) { + dialer := Dialer() + result := dialResult{ipv6: ipv6, done: true} + defer func() { + select { + case results <- result: + case <-returned: + if result.Conn != nil { + result.Conn.Close() + } + } + }() + + var ip net.IP + if ipv6 { + ip, result.error = resolver.ResolveIPv6(host) + } else { + ip, result.error = resolver.ResolveIPv4(host) + } + if result.error != nil { + return + } + result.resolved = true + + if DialHook != nil { + DialHook(dialer, network, ip) + } + result.Conn, result.error = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) + } + + go startRacer(ctx, network+"4", host, false) + go startRacer(ctx, network+"6", host, true) + + for { + select { + case res := <-results: + if res.error == nil { + return res.Conn, nil + } + + if !res.ipv6 { + primary = res + } else { + fallback = res + } + + if primary.done && fallback.done { + if primary.resolved { + return nil, primary.error + } else if fallback.resolved { + return nil, fallback.error + } else { + return nil, primary.error + } + } + } + } +} diff --git a/component/dialer/hook.go b/component/dialer/hook.go index 4ef71143..3546a393 100644 --- a/component/dialer/hook.go +++ b/component/dialer/hook.go @@ -1,11 +1,142 @@ package dialer -import "net" +import ( + "errors" + "net" + "time" -type DialerHookFunc = func(*net.Dialer) + "github.com/Dreamacro/clash/common/singledo" +) + +type DialerHookFunc = func(dialer *net.Dialer) +type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP) type ListenConfigHookFunc = func(*net.ListenConfig) +type ListenPacketHookFunc = func() net.IP var ( - DialerHook DialerHookFunc = nil - ListenConfigHook ListenConfigHookFunc = nil + DialerHook DialerHookFunc + DialHook DialHookFunc + ListenConfigHook ListenConfigHookFunc + ListenPacketHook ListenPacketHookFunc ) + +var ( + ErrAddrNotFound = errors.New("addr not found") + ErrNetworkNotSupport = errors.New("network not support") +) + +func lookupTCPAddr(ip net.IP, addrs []net.Addr) (*net.TCPAddr, error) { + ipv4 := ip.To4() != nil + + for _, elm := range addrs { + addr, ok := elm.(*net.IPNet) + if !ok { + continue + } + + addrV4 := addr.IP.To4() != nil + + if addrV4 && ipv4 { + return &net.TCPAddr{IP: addr.IP, Port: 0}, nil + } else if !addrV4 && !ipv4 { + return &net.TCPAddr{IP: addr.IP, Port: 0}, nil + } + } + + return nil, ErrAddrNotFound +} + +func lookupUDPAddr(ip net.IP, addrs []net.Addr) (*net.UDPAddr, error) { + ipv4 := ip.To4() != nil + + for _, elm := range addrs { + addr, ok := elm.(*net.IPNet) + if !ok { + continue + } + + addrV4 := addr.IP.To4() != nil + + if addrV4 && ipv4 { + return &net.UDPAddr{IP: addr.IP, Port: 0}, nil + } else if !addrV4 && !ipv4 { + return &net.UDPAddr{IP: addr.IP, Port: 0}, nil + } + } + + return nil, ErrAddrNotFound +} + +func ListenPacketWithInterface(name string) ListenPacketHookFunc { + single := singledo.NewSingle(5 * time.Second) + + return func() net.IP { + elm, err, _ := single.Do(func() (interface{}, error) { + iface, err := net.InterfaceByName(name) + if err != nil { + return nil, err + } + + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + return addrs, nil + }) + + if err != nil { + return nil + } + + addrs := elm.([]net.Addr) + + for _, elm := range addrs { + addr, ok := elm.(*net.IPNet) + if !ok || addr.IP.To4() == nil { + continue + } + + return addr.IP + } + + return nil + } +} + +func DialerWithInterface(name string) DialHookFunc { + single := singledo.NewSingle(5 * time.Second) + + return func(dialer *net.Dialer, network string, ip net.IP) { + elm, err, _ := single.Do(func() (interface{}, error) { + iface, err := net.InterfaceByName(name) + if err != nil { + return nil, err + } + + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + return addrs, nil + }) + + if err != nil { + return + } + + addrs := elm.([]net.Addr) + + switch network { + case "tcp", "tcp4", "tcp6": + if addr, err := lookupTCPAddr(ip, addrs); err == nil { + dialer.LocalAddr = addr + } + case "udp", "udp4", "udp6": + if addr, err := lookupUDPAddr(ip, addrs); err == nil { + dialer.LocalAddr = addr + } + } + } +} diff --git a/dns/iputil.go b/component/resolver/resolver.go similarity index 70% rename from dns/iputil.go rename to component/resolver/resolver.go index f186e5a0..f7fc1e89 100644 --- a/dns/iputil.go +++ b/component/resolver/resolver.go @@ -1,16 +1,32 @@ -package dns +package resolver import ( "errors" "net" "strings" + + trie "github.com/Dreamacro/clash/component/domain-trie" ) var ( - errIPNotFound = errors.New("couldn't find ip") - errIPVersion = errors.New("ip version error") + // DefaultResolver aim to resolve ip + DefaultResolver Resolver + + // DefaultHosts aim to resolve hosts + DefaultHosts = trie.New() ) +var ( + ErrIPNotFound = errors.New("couldn't find ip") + ErrIPVersion = errors.New("ip version error") +) + +type Resolver interface { + ResolveIP(host string) (ip net.IP, err error) + ResolveIPv4(host string) (ip net.IP, err error) + ResolveIPv6(host string) (ip net.IP, err error) +} + // ResolveIPv4 with a host, return ipv4 func ResolveIPv4(host string) (net.IP, error) { if node := DefaultHosts.Search(host); node != nil { @@ -24,7 +40,7 @@ func ResolveIPv4(host string) (net.IP, error) { if !strings.Contains(host, ":") { return ip, nil } - return nil, errIPVersion + return nil, ErrIPVersion } if DefaultResolver != nil { @@ -42,7 +58,7 @@ func ResolveIPv4(host string) (net.IP, error) { } } - return nil, errIPNotFound + return nil, ErrIPNotFound } // ResolveIPv6 with a host, return ipv6 @@ -58,7 +74,7 @@ func ResolveIPv6(host string) (net.IP, error) { if strings.Contains(host, ":") { return ip, nil } - return nil, errIPVersion + return nil, ErrIPVersion } if DefaultResolver != nil { @@ -76,7 +92,7 @@ func ResolveIPv6(host string) (net.IP, error) { } } - return nil, errIPNotFound + return nil, ErrIPNotFound } // ResolveIP with a host, return ip @@ -86,10 +102,7 @@ func ResolveIP(host string) (net.IP, error) { } if DefaultResolver != nil { - if DefaultResolver.ipv6 { - return DefaultResolver.ResolveIP(host) - } - return DefaultResolver.ResolveIPv4(host) + return DefaultResolver.ResolveIP(host) } ip := net.ParseIP(host) diff --git a/config/config.go b/config/config.go index be6ae542..e43006a7 100644 --- a/config/config.go +++ b/config/config.go @@ -1,6 +1,7 @@ package config import ( + "errors" "fmt" "net" "net/url" @@ -30,7 +31,7 @@ type General struct { Authentication []string `json:"authentication"` AllowLan bool `json:"allow-lan"` BindAddress string `json:"bind-address"` - Mode T.Mode `json:"mode"` + Mode T.TunnelMode `json:"mode"` LogLevel log.LogLevel `json:"log-level"` ExternalController string `json:"-"` ExternalUI string `json:"-"` @@ -39,14 +40,15 @@ type General struct { // DNS config type DNS struct { - Enable bool `yaml:"enable"` - IPv6 bool `yaml:"ipv6"` - NameServer []dns.NameServer `yaml:"nameserver"` - Fallback []dns.NameServer `yaml:"fallback"` - FallbackFilter FallbackFilter `yaml:"fallback-filter"` - Listen string `yaml:"listen"` - EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` - FakeIPRange *fakeip.Pool + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + NameServer []dns.NameServer `yaml:"nameserver"` + Fallback []dns.NameServer `yaml:"fallback"` + FallbackFilter FallbackFilter `yaml:"fallback-filter"` + Listen string `yaml:"listen"` + EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` + DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` + FakeIPRange *fakeip.Pool } // FallbackFilter config @@ -57,7 +59,8 @@ type FallbackFilter struct { // Experimental config type Experimental struct { - IgnoreResolveFail bool `yaml:"ignore-resolve-fail"` + IgnoreResolveFail bool `yaml:"ignore-resolve-fail"` + Interface string `yaml:"interface-name"` } // Config is clash config manager @@ -73,15 +76,16 @@ type Config struct { } type RawDNS struct { - Enable bool `yaml:"enable"` - IPv6 bool `yaml:"ipv6"` - NameServer []string `yaml:"nameserver"` - Fallback []string `yaml:"fallback"` - FallbackFilter RawFallbackFilter `yaml:"fallback-filter"` - Listen string `yaml:"listen"` - EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` - FakeIPRange string `yaml:"fake-ip-range"` - FakeIPFilter []string `yaml:"fake-ip-filter"` + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + NameServer []string `yaml:"nameserver"` + Fallback []string `yaml:"fallback"` + FallbackFilter RawFallbackFilter `yaml:"fallback-filter"` + Listen string `yaml:"listen"` + EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` + FakeIPRange string `yaml:"fake-ip-range"` + FakeIPFilter []string `yaml:"fake-ip-filter"` + DefaultNameserver []string `yaml:"default-nameserver"` } type RawFallbackFilter struct { @@ -96,7 +100,7 @@ type RawConfig struct { Authentication []string `yaml:"authentication"` AllowLan bool `yaml:"allow-lan"` BindAddress string `yaml:"bind-address"` - Mode T.Mode `yaml:"mode"` + Mode T.TunnelMode `yaml:"mode"` LogLevel log.LogLevel `yaml:"log-level"` ExternalController string `yaml:"external-controller"` ExternalUI string `yaml:"external-ui"` @@ -143,6 +147,10 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) { GeoIP: true, IPCIDR: []string{}, }, + DefaultNameserver: []string{ + "114.114.114.114", + "8.8.8.8", + }, }, } @@ -433,21 +441,21 @@ func parseHosts(cfg *RawConfig) (*trie.Trie, error) { return tree, nil } -func hostWithDefaultPort(host string, defPort string) (string, error) { +func hostWithDefaultPort(host string, defPort string) (string, string, error) { if !strings.Contains(host, ":") { host += ":" } hostname, port, err := net.SplitHostPort(host) if err != nil { - return "", err + return "", "", err } if port == "" { port = defPort } - return net.JoinHostPort(hostname, port), nil + return net.JoinHostPort(hostname, port), hostname, nil } func parseNameServer(servers []string) ([]dns.NameServer, error) { @@ -463,20 +471,21 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) { return nil, fmt.Errorf("DNS NameServer[%d] format error: %s", idx, err.Error()) } - var host, dnsNetType string + var addr, dnsNetType, host string switch u.Scheme { case "udp": - host, err = hostWithDefaultPort(u.Host, "53") + addr, host, err = hostWithDefaultPort(u.Host, "53") dnsNetType = "" // UDP case "tcp": - host, err = hostWithDefaultPort(u.Host, "53") + addr, host, err = hostWithDefaultPort(u.Host, "53") dnsNetType = "tcp" // TCP case "tls": - host, err = hostWithDefaultPort(u.Host, "853") + addr, host, err = hostWithDefaultPort(u.Host, "853") dnsNetType = "tcp-tls" // DNS over TLS case "https": clearURL := url.URL{Scheme: "https", Host: u.Host, Path: u.Path} - host = clearURL.String() + addr = clearURL.String() + _, host, err = hostWithDefaultPort(u.Host, "853") dnsNetType = "https" // DNS over HTTPS default: return nil, fmt.Errorf("DNS NameServer[%d] unsupport scheme: %s", idx, u.Scheme) @@ -490,7 +499,8 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) { nameservers, dns.NameServer{ Net: dnsNetType, - Addr: host, + Addr: addr, + Host: host, }, ) } @@ -534,6 +544,19 @@ func parseDNS(cfg RawDNS) (*DNS, error) { return nil, err } + if len(cfg.DefaultNameserver) == 0 { + return nil, errors.New("default nameserver should have at least one nameserver") + } + if dnsCfg.DefaultNameserver, err = parseNameServer(cfg.DefaultNameserver); err != nil { + return nil, err + } + // check default nameserver is pure ip addr + for _, ns := range dnsCfg.DefaultNameserver { + if net.ParseIP(ns.Host) == nil { + return nil, errors.New("default nameserver should be pure IP") + } + } + if cfg.EnhancedMode == dns.FAKEIP { _, ipnet, err := net.ParseCIDR(cfg.FakeIPRange) if err != nil { diff --git a/dns/client.go b/dns/client.go index a3d56669..1f1e6c6f 100644 --- a/dns/client.go +++ b/dns/client.go @@ -2,6 +2,7 @@ package dns import ( "context" + "strings" "github.com/Dreamacro/clash/component/dialer" @@ -10,7 +11,9 @@ import ( type client struct { *D.Client - Address string + r *Resolver + addr string + host string } func (c *client) Exchange(m *D.Msg) (msg *D.Msg, err error) { @@ -18,7 +21,22 @@ func (c *client) Exchange(m *D.Msg) (msg *D.Msg, err error) { } func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { - c.Client.Dialer = dialer.Dialer() + network := "udp" + if strings.HasPrefix(c.Client.Net, "tcp") { + network = "tcp" + } + + ip, err := c.r.ResolveIPv4(c.host) + if err != nil { + return nil, err + } + + d := dialer.Dialer() + if dialer.DialHook != nil { + dialer.DialHook(d, network, ip) + } + + c.Client.Dialer = d // miekg/dns ExchangeContext doesn't respond to context cancel. // this is a workaround @@ -28,7 +46,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err } ch := make(chan result, 1) go func() { - msg, _, err := c.Client.ExchangeContext(ctx, m, c.Address) + msg, _, err := c.Client.Exchange(m, c.addr) ch <- result{msg, err} }() diff --git a/dns/doh.go b/dns/doh.go index 4aecd355..8715393e 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "io/ioutil" + "net" "net/http" "github.com/Dreamacro/clash/component/dialer" @@ -17,13 +18,9 @@ const ( dotMimeType = "application/dns-message" ) -var dohTransport = &http.Transport{ - TLSClientConfig: &tls.Config{ClientSessionCache: globalSessionCache}, - DialContext: dialer.DialContext, -} - type dohClient struct { - url string + url string + transport *http.Transport } func (dc *dohClient) Exchange(m *D.Msg) (msg *D.Msg, err error) { @@ -58,7 +55,7 @@ func (dc *dohClient) newRequest(m *D.Msg) (*http.Request, error) { } func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) { - client := &http.Client{Transport: dohTransport} + client := &http.Client{Transport: dc.transport} resp, err := client.Do(req) if err != nil { return nil, err @@ -73,3 +70,25 @@ func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) { err = msg.Unpack(buf) return msg, err } + +func newDoHClient(url string, r *Resolver) *dohClient { + return &dohClient{ + url: url, + transport: &http.Transport{ + TLSClientConfig: &tls.Config{ClientSessionCache: globalSessionCache}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + ip, err := r.ResolveIPv4(host) + if err != nil { + return nil, err + } + + return dialer.DialContext(ctx, "tcp4", net.JoinHostPort(ip.String(), port)) + }, + }, + } +} diff --git a/dns/resolver.go b/dns/resolver.go index 58c2a3e9..480818c9 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -11,26 +11,18 @@ import ( "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/picker" - trie "github.com/Dreamacro/clash/component/domain-trie" "github.com/Dreamacro/clash/component/fakeip" + "github.com/Dreamacro/clash/component/resolver" D "github.com/miekg/dns" "golang.org/x/sync/singleflight" ) -var ( - // DefaultResolver aim to resolve ip - DefaultResolver *Resolver - - // DefaultHosts aim to resolve hosts - DefaultHosts = trie.New() -) - var ( globalSessionCache = tls.NewLRUClientSessionCache(64) ) -type resolver interface { +type dnsClient interface { Exchange(m *D.Msg) (msg *D.Msg, err error) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) } @@ -45,8 +37,8 @@ type Resolver struct { mapping bool fakeip bool pool *fakeip.Pool - main []resolver - fallback []resolver + main []dnsClient + fallback []dnsClient fallbackFilters []fallbackFilter group singleflight.Group cache *cache.Cache @@ -74,7 +66,7 @@ func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { ip, open := <-ch if !open { - return nil, errIPNotFound + return nil, resolver.ErrIPNotFound } return ip, nil @@ -174,7 +166,7 @@ func (r *Resolver) IsFakeIP(ip net.IP) bool { return false } -func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) { +func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { fast, ctx := picker.WithTimeout(context.Background(), time.Second*5) for _, client := range clients { r := client @@ -238,7 +230,7 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) ips := r.msgToIP(msg) ipLength := len(ips) if ipLength == 0 { - return nil, errIPNotFound + return nil, resolver.ErrIPNotFound } ip = ips[rand.Intn(ipLength)] @@ -260,7 +252,7 @@ func (r *Resolver) msgToIP(msg *D.Msg) []net.IP { return ips } -func (r *Resolver) asyncExchange(client []resolver, msg *D.Msg) <-chan *result { +func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result { ch := make(chan *result) go func() { res, err := r.batchExchange(client, msg) @@ -272,6 +264,7 @@ func (r *Resolver) asyncExchange(client []resolver, msg *D.Msg) <-chan *result { type NameServer struct { Net string Addr string + Host string } type FallbackFilter struct { @@ -281,6 +274,7 @@ type FallbackFilter struct { type Config struct { Main, Fallback []NameServer + Default []NameServer IPv6 bool EnhancedMode EnhancedMode FallbackFilter FallbackFilter @@ -288,9 +282,14 @@ type Config struct { } func New(config Config) *Resolver { + defaultResolver := &Resolver{ + main: transform(config.Default, nil), + cache: cache.New(time.Second * 60), + } + r := &Resolver{ ipv6: config.IPv6, - main: transform(config.Main), + main: transform(config.Main, defaultResolver), cache: cache.New(time.Second * 60), mapping: config.EnhancedMode == MAPPING, fakeip: config.EnhancedMode == FAKEIP, @@ -298,7 +297,7 @@ func New(config Config) *Resolver { } if len(config.Fallback) != 0 { - r.fallback = transform(config.Fallback) + r.fallback = transform(config.Fallback, defaultResolver) } fallbackFilters := []fallbackFilter{} diff --git a/dns/util.go b/dns/util.go index d5cb216d..2cf0cd9b 100644 --- a/dns/util.go +++ b/dns/util.go @@ -8,9 +8,9 @@ import ( "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/log" - yaml "gopkg.in/yaml.v2" D "github.com/miekg/dns" + yaml "gopkg.in/yaml.v2" ) var ( @@ -117,11 +117,11 @@ func isIPRequest(q D.Question) bool { return false } -func transform(servers []NameServer) []resolver { - ret := []resolver{} +func transform(servers []NameServer, resolver *Resolver) []dnsClient { + ret := []dnsClient{} for _, s := range servers { if s.Net == "https" { - ret = append(ret, &dohClient{url: s.Addr}) + ret = append(ret, newDoHClient(s.Addr, resolver)) continue } @@ -136,7 +136,8 @@ func transform(servers []NameServer) []resolver { UDPSize: 4096, Timeout: 5 * time.Second, }, - Address: s.Addr, + addr: s.Addr, + host: s.Host, }) } return ret diff --git a/hub/executor/executor.go b/hub/executor/executor.go index ac45c172..5916bfbb 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -8,14 +8,16 @@ import ( "github.com/Dreamacro/clash/adapters/provider" "github.com/Dreamacro/clash/component/auth" + "github.com/Dreamacro/clash/component/dialer" trie "github.com/Dreamacro/clash/component/domain-trie" + "github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/config" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/log" P "github.com/Dreamacro/clash/proxy" authStore "github.com/Dreamacro/clash/proxy/auth" - T "github.com/Dreamacro/clash/tunnel" + "github.com/Dreamacro/clash/tunnel" ) // forward compatibility before 1.0 @@ -83,7 +85,7 @@ func ApplyConfig(cfg *config.Config, force bool) { updateRules(cfg.Rules) updateDNS(cfg.DNS) updateHosts(cfg.Hosts) - updateExperimental(cfg.Experimental) + updateExperimental(cfg) } func GetGeneral() *config.General { @@ -100,20 +102,30 @@ func GetGeneral() *config.General { Authentication: authenticator, AllowLan: P.AllowLan(), BindAddress: P.BindAddress(), - Mode: T.Instance().Mode(), + Mode: tunnel.Mode(), LogLevel: log.Level(), } return general } -func updateExperimental(c *config.Experimental) { - T.Instance().UpdateExperimental(c.IgnoreResolveFail) +func updateExperimental(c *config.Config) { + cfg := c.Experimental + + tunnel.UpdateExperimental(cfg.IgnoreResolveFail) + if cfg.Interface != "" && c.DNS.Enable { + dialer.DialHook = dialer.DialerWithInterface(cfg.Interface) + dialer.ListenPacketHook = dialer.ListenPacketWithInterface(cfg.Interface) + } else { + dialer.DialHook = nil + dialer.ListenPacketHook = nil + } } func updateDNS(c *config.DNS) { if c.Enable == false { - dns.DefaultResolver = nil + resolver.DefaultResolver = nil + tunnel.SetResolver(nil) dns.ReCreateServer("", nil) return } @@ -127,8 +139,10 @@ func updateDNS(c *config.DNS) { GeoIP: c.FallbackFilter.GeoIP, IPCIDR: c.FallbackFilter.IPCIDR, }, + Default: c.DefaultNameserver, }) - dns.DefaultResolver = r + resolver.DefaultResolver = r + tunnel.SetResolver(r) if err := dns.ReCreateServer(c.Listen, r); err != nil { log.Errorln("Start DNS server error: %s", err.Error()) return @@ -140,11 +154,10 @@ func updateDNS(c *config.DNS) { } func updateHosts(tree *trie.Trie) { - dns.DefaultHosts = tree + resolver.DefaultHosts = tree } func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.ProxyProvider) { - tunnel := T.Instance() oldProviders := tunnel.Providers() // close providers goroutine @@ -156,12 +169,12 @@ func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.Pro } func updateRules(rules []C.Rule) { - T.Instance().UpdateRules(rules) + tunnel.UpdateRules(rules) } func updateGeneral(general *config.General) { log.SetLevel(general.LogLevel) - T.Instance().SetMode(general.Mode) + tunnel.SetMode(general.Mode) allowLan := general.AllowLan P.SetAllowLan(allowLan) diff --git a/hub/route/configs.go b/hub/route/configs.go index ea720a7f..dd97a840 100644 --- a/hub/route/configs.go +++ b/hub/route/configs.go @@ -8,7 +8,7 @@ import ( "github.com/Dreamacro/clash/hub/executor" "github.com/Dreamacro/clash/log" P "github.com/Dreamacro/clash/proxy" - T "github.com/Dreamacro/clash/tunnel" + "github.com/Dreamacro/clash/tunnel" "github.com/go-chi/chi" "github.com/go-chi/render" @@ -23,13 +23,13 @@ func configRouter() http.Handler { } type configSchema struct { - Port *int `json:"port"` - SocksPort *int `json:"socks-port"` - RedirPort *int `json:"redir-port"` - AllowLan *bool `json:"allow-lan"` - BindAddress *string `json:"bind-address"` - Mode *T.Mode `json:"mode"` - LogLevel *log.LogLevel `json:"log-level"` + Port *int `json:"port"` + SocksPort *int `json:"socks-port"` + RedirPort *int `json:"redir-port"` + AllowLan *bool `json:"allow-lan"` + BindAddress *string `json:"bind-address"` + Mode *tunnel.TunnelMode `json:"mode"` + LogLevel *log.LogLevel `json:"log-level"` } func getConfigs(w http.ResponseWriter, r *http.Request) { @@ -67,7 +67,7 @@ func patchConfigs(w http.ResponseWriter, r *http.Request) { P.ReCreateRedir(pointerOrDefault(general.RedirPort, ports.RedirPort)) if general.Mode != nil { - T.Instance().SetMode(*general.Mode) + tunnel.SetMode(*general.Mode) } if general.LogLevel != nil { diff --git a/hub/route/provider.go b/hub/route/provider.go index f69bec46..eb7f8320 100644 --- a/hub/route/provider.go +++ b/hub/route/provider.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/Dreamacro/clash/adapters/provider" - T "github.com/Dreamacro/clash/tunnel" + "github.com/Dreamacro/clash/tunnel" "github.com/go-chi/chi" "github.com/go-chi/render" @@ -25,7 +25,7 @@ func proxyProviderRouter() http.Handler { } func getProviders(w http.ResponseWriter, r *http.Request) { - providers := T.Instance().Providers() + providers := tunnel.Providers() render.JSON(w, r, render.M{ "providers": providers, }) @@ -63,7 +63,7 @@ func parseProviderName(next http.Handler) http.Handler { func findProviderByName(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { name := r.Context().Value(CtxKeyProviderName).(string) - providers := T.Instance().Providers() + providers := tunnel.Providers() provider, exist := providers[name] if !exist { render.Status(r, http.StatusNotFound) diff --git a/hub/route/proxies.go b/hub/route/proxies.go index ab769f9c..e3cb066c 100644 --- a/hub/route/proxies.go +++ b/hub/route/proxies.go @@ -10,7 +10,7 @@ import ( "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/outboundgroup" C "github.com/Dreamacro/clash/constant" - T "github.com/Dreamacro/clash/tunnel" + "github.com/Dreamacro/clash/tunnel" "github.com/go-chi/chi" "github.com/go-chi/render" @@ -40,7 +40,7 @@ func parseProxyName(next http.Handler) http.Handler { func findProxyByName(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { name := r.Context().Value(CtxKeyProxyName).(string) - proxies := T.Instance().Proxies() + proxies := tunnel.Proxies() proxy, exist := proxies[name] if !exist { render.Status(r, http.StatusNotFound) @@ -54,7 +54,7 @@ func findProxyByName(next http.Handler) http.Handler { } func getProxies(w http.ResponseWriter, r *http.Request) { - proxies := T.Instance().Proxies() + proxies := tunnel.Proxies() render.JSON(w, r, render.M{ "proxies": proxies, }) diff --git a/hub/route/rules.go b/hub/route/rules.go index 6a223ab5..d03ca63f 100644 --- a/hub/route/rules.go +++ b/hub/route/rules.go @@ -3,7 +3,7 @@ package route import ( "net/http" - T "github.com/Dreamacro/clash/tunnel" + "github.com/Dreamacro/clash/tunnel" "github.com/go-chi/chi" "github.com/go-chi/render" @@ -22,7 +22,7 @@ type Rule struct { } func getRules(w http.ResponseWriter, r *http.Request) { - rawRules := T.Instance().Rules() + rawRules := tunnel.Rules() rules := []Rule{} for _, rule := range rawRules { diff --git a/proxy/http/server.go b/proxy/http/server.go index 51941ef6..de7e0fca 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -16,10 +16,6 @@ import ( "github.com/Dreamacro/clash/tunnel" ) -var ( - tun = tunnel.Instance() -) - type HttpListener struct { net.Listener address string @@ -100,9 +96,9 @@ func handleConn(conn net.Conn, cache *cache.Cache) { if err != nil { return } - tun.Add(adapters.NewHTTPS(request, conn)) + tunnel.Add(adapters.NewHTTPS(request, conn)) return } - tun.Add(adapters.NewHTTP(request, conn)) + tunnel.Add(adapters.NewHTTP(request, conn)) } diff --git a/proxy/redir/tcp.go b/proxy/redir/tcp.go index fd455603..58352198 100644 --- a/proxy/redir/tcp.go +++ b/proxy/redir/tcp.go @@ -9,10 +9,6 @@ import ( "github.com/Dreamacro/clash/tunnel" ) -var ( - tun = tunnel.Instance() -) - type RedirListener struct { net.Listener address string @@ -59,5 +55,5 @@ func handleRedir(conn net.Conn) { return } conn.(*net.TCPConn).SetKeepAlive(true) - tun.Add(inbound.NewSocket(target, conn, C.REDIR, C.TCP)) + tunnel.Add(inbound.NewSocket(target, conn, C.REDIR, C.TCP)) } diff --git a/proxy/socks/tcp.go b/proxy/socks/tcp.go index 1d080dd6..eeff4bc4 100644 --- a/proxy/socks/tcp.go +++ b/proxy/socks/tcp.go @@ -13,10 +13,6 @@ import ( "github.com/Dreamacro/clash/tunnel" ) -var ( - tun = tunnel.Instance() -) - type SockListener struct { net.Listener address string @@ -68,5 +64,5 @@ func handleSocks(conn net.Conn) { io.Copy(ioutil.Discard, conn) return } - tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.TCP)) + tunnel.Add(adapters.NewSocket(target, conn, C.SOCKS, C.TCP)) } diff --git a/proxy/socks/udp.go b/proxy/socks/udp.go index 2239af3c..46e681ff 100644 --- a/proxy/socks/udp.go +++ b/proxy/socks/udp.go @@ -7,6 +7,7 @@ import ( "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/component/socks5" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/tunnel" ) type SockUDPListener struct { @@ -62,5 +63,5 @@ func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) { payload: payload, bufRef: buf, } - tun.AddPacket(adapters.NewPacket(target, packet, C.SOCKS)) + tunnel.AddPacket(adapters.NewPacket(target, packet, C.SOCKS)) } diff --git a/tunnel/connection.go b/tunnel/connection.go index 6771aa40..98f3bbce 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -14,7 +14,7 @@ import ( "github.com/Dreamacro/clash/common/pool" ) -func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { +func handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { req := request.R host := req.Host @@ -81,17 +81,17 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { } } -func (t *Tunnel) handleUDPToRemote(packet C.UDPPacket, pc net.PacketConn, addr net.Addr) { +func handleUDPToRemote(packet C.UDPPacket, pc net.PacketConn, addr net.Addr) { if _, err := pc.WriteTo(packet.Data(), addr); err != nil { return } DefaultManager.Upload() <- int64(len(packet.Data())) } -func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string) { +func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string) { buf := pool.BufPool.Get().([]byte) defer pool.BufPool.Put(buf[:cap(buf)]) - defer t.natTable.Delete(key) + defer natTable.Delete(key) defer pc.Close() for { @@ -109,7 +109,7 @@ func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key str } } -func (t *Tunnel) handleSocket(request *adapters.SocketAdapter, outbound net.Conn) { +func handleSocket(request *adapters.SocketAdapter, outbound net.Conn) { relay(request, outbound) } diff --git a/tunnel/mode.go b/tunnel/mode.go index 5f7d9639..b00b73ab 100644 --- a/tunnel/mode.go +++ b/tunnel/mode.go @@ -5,11 +5,11 @@ import ( "errors" ) -type Mode int +type TunnelMode int var ( // ModeMapping is a mapping for Mode enum - ModeMapping = map[string]Mode{ + ModeMapping = map[string]TunnelMode{ Global.String(): Global, Rule.String(): Rule, Direct.String(): Direct, @@ -17,13 +17,13 @@ var ( ) const ( - Global Mode = iota + Global TunnelMode = iota Rule Direct ) // UnmarshalJSON unserialize Mode -func (m *Mode) UnmarshalJSON(data []byte) error { +func (m *TunnelMode) UnmarshalJSON(data []byte) error { var tp string json.Unmarshal(data, &tp) mode, exist := ModeMapping[tp] @@ -35,7 +35,7 @@ func (m *Mode) UnmarshalJSON(data []byte) error { } // UnmarshalYAML unserialize Mode with yaml -func (m *Mode) UnmarshalYAML(unmarshal func(interface{}) error) error { +func (m *TunnelMode) UnmarshalYAML(unmarshal func(interface{}) error) error { var tp string unmarshal(&tp) mode, exist := ModeMapping[tp] @@ -47,11 +47,11 @@ func (m *Mode) UnmarshalYAML(unmarshal func(interface{}) error) error { } // MarshalJSON serialize Mode -func (m Mode) MarshalJSON() ([]byte, error) { +func (m TunnelMode) MarshalJSON() ([]byte, error) { return json.Marshal(m.String()) } -func (m Mode) String() string { +func (m TunnelMode) String() string { switch m { case Global: return "Global" diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index f572399a..0a37bb1f 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -10,6 +10,7 @@ import ( "github.com/Dreamacro/clash/adapters/inbound" "github.com/Dreamacro/clash/adapters/provider" "github.com/Dreamacro/clash/component/nat" + "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/log" @@ -18,136 +19,136 @@ import ( ) var ( - tunnel *Tunnel - once sync.Once - - // default timeout for UDP session - udpTimeout = 60 * time.Second -) - -// Tunnel handle relay inbound proxy and outbound proxy -type Tunnel struct { - tcpQueue *channels.InfiniteChannel - udpQueue *channels.InfiniteChannel - natTable *nat.Table - rules []C.Rule - proxies map[string]C.Proxy - providers map[string]provider.ProxyProvider - configMux sync.RWMutex + tcpQueue = channels.NewInfiniteChannel() + udpQueue = channels.NewInfiniteChannel() + natTable = nat.New() + rules []C.Rule + proxies = make(map[string]C.Proxy) + providers map[string]provider.ProxyProvider + configMux sync.RWMutex + enhancedMode *dns.Resolver // experimental features ignoreResolveFail bool // Outbound Rule - mode Mode + mode = Rule + + // default timeout for UDP session + udpTimeout = 60 * time.Second +) + +func init() { + go process() } // Add request to queue -func (t *Tunnel) Add(req C.ServerAdapter) { - t.tcpQueue.In() <- req +func Add(req C.ServerAdapter) { + tcpQueue.In() <- req } // AddPacket add udp Packet to queue -func (t *Tunnel) AddPacket(packet *inbound.PacketAdapter) { - t.udpQueue.In() <- packet +func AddPacket(packet *inbound.PacketAdapter) { + udpQueue.In() <- packet } // Rules return all rules -func (t *Tunnel) Rules() []C.Rule { - return t.rules +func Rules() []C.Rule { + return rules } // UpdateRules handle update rules -func (t *Tunnel) UpdateRules(rules []C.Rule) { - t.configMux.Lock() - t.rules = rules - t.configMux.Unlock() +func UpdateRules(newRules []C.Rule) { + configMux.Lock() + rules = newRules + configMux.Unlock() } // Proxies return all proxies -func (t *Tunnel) Proxies() map[string]C.Proxy { - return t.proxies +func Proxies() map[string]C.Proxy { + return proxies } // Providers return all compatible providers -func (t *Tunnel) Providers() map[string]provider.ProxyProvider { - return t.providers +func Providers() map[string]provider.ProxyProvider { + return providers } // UpdateProxies handle update proxies -func (t *Tunnel) UpdateProxies(proxies map[string]C.Proxy, providers map[string]provider.ProxyProvider) { - t.configMux.Lock() - t.proxies = proxies - t.providers = providers - t.configMux.Unlock() +func UpdateProxies(newProxies map[string]C.Proxy, newProviders map[string]provider.ProxyProvider) { + configMux.Lock() + proxies = newProxies + providers = newProviders + configMux.Unlock() } // UpdateExperimental handle update experimental config -func (t *Tunnel) UpdateExperimental(ignoreResolveFail bool) { - t.configMux.Lock() - t.ignoreResolveFail = ignoreResolveFail - t.configMux.Unlock() +func UpdateExperimental(value bool) { + configMux.Lock() + ignoreResolveFail = value + configMux.Unlock() } // Mode return current mode -func (t *Tunnel) Mode() Mode { - return t.mode +func Mode() TunnelMode { + return mode } // SetMode change the mode of tunnel -func (t *Tunnel) SetMode(mode Mode) { - t.mode = mode +func SetMode(m TunnelMode) { + mode = m +} + +// SetResolver set custom dns resolver for enhanced mode +func SetResolver(r *dns.Resolver) { + enhancedMode = r } // processUDP starts a loop to handle udp packet -func (t *Tunnel) processUDP() { - queue := t.udpQueue.Out() +func processUDP() { + queue := udpQueue.Out() for elm := range queue { conn := elm.(*inbound.PacketAdapter) - t.handleUDPConn(conn) + handleUDPConn(conn) } } -func (t *Tunnel) process() { +func process() { numUDPWorkers := 4 if runtime.NumCPU() > numUDPWorkers { numUDPWorkers = runtime.NumCPU() } for i := 0; i < numUDPWorkers; i++ { - go t.processUDP() + go processUDP() } - queue := t.tcpQueue.Out() + queue := tcpQueue.Out() for elm := range queue { conn := elm.(C.ServerAdapter) - go t.handleTCPConn(conn) + go handleTCPConn(conn) } } -func (t *Tunnel) resolveIP(host string) (net.IP, error) { - return dns.ResolveIP(host) +func needLookupIP(metadata *C.Metadata) bool { + return enhancedMode != nil && (enhancedMode.IsMapping() || enhancedMode.FakeIPEnabled()) && metadata.Host == "" && metadata.DstIP != nil } -func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool { - return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.FakeIPEnabled()) && metadata.Host == "" && metadata.DstIP != nil -} - -func (t *Tunnel) preHandleMetadata(metadata *C.Metadata) error { +func preHandleMetadata(metadata *C.Metadata) error { // handle IP string on host if ip := net.ParseIP(metadata.Host); ip != nil { metadata.DstIP = ip } // preprocess enhanced-mode metadata - if t.needLookupIP(metadata) { - host, exist := dns.DefaultResolver.IPToHost(metadata.DstIP) + if needLookupIP(metadata) { + host, exist := enhancedMode.IPToHost(metadata.DstIP) if exist { metadata.Host = host metadata.AddrType = C.AtypDomainName - if dns.DefaultResolver.FakeIPEnabled() { + if enhancedMode.FakeIPEnabled() { metadata.DstIP = nil } - } else if dns.DefaultResolver.IsFakeIP(metadata.DstIP) { + } else if enhancedMode.IsFakeIP(metadata.DstIP) { return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) } } @@ -155,18 +156,18 @@ func (t *Tunnel) preHandleMetadata(metadata *C.Metadata) error { return nil } -func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { +func resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { var proxy C.Proxy var rule C.Rule - switch t.mode { + switch mode { case Direct: - proxy = t.proxies["DIRECT"] + proxy = proxies["DIRECT"] case Global: - proxy = t.proxies["GLOBAL"] + proxy = proxies["GLOBAL"] // Rule default: var err error - proxy, rule, err = t.match(metadata) + proxy, rule, err = match(metadata) if err != nil { return nil, nil, err } @@ -174,23 +175,23 @@ func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) return proxy, rule, nil } -func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { +func handleUDPConn(packet *inbound.PacketAdapter) { metadata := packet.Metadata() if !metadata.Valid() { log.Warnln("[Metadata] not valid: %#v", metadata) return } - if err := t.preHandleMetadata(metadata); err != nil { + if err := preHandleMetadata(metadata); err != nil { log.Debugln("[Metadata PreHandle] error: %s", err) return } key := packet.LocalAddr().String() - pc := t.natTable.Get(key) + pc := natTable.Get(key) if pc != nil { if !metadata.Resolved() { - ip, err := t.resolveIP(metadata.Host) + ip, err := resolver.ResolveIP(metadata.Host) if err != nil { log.Warnln("[UDP] Resolve %s failed: %s, %#v", metadata.Host, err.Error(), metadata) return @@ -198,20 +199,20 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { metadata.DstIP = ip } - t.handleUDPToRemote(packet, pc, metadata.UDPAddr()) + handleUDPToRemote(packet, pc, metadata.UDPAddr()) return } lockKey := key + "-lock" - wg, loaded := t.natTable.GetOrCreateLock(lockKey) + wg, loaded := natTable.GetOrCreateLock(lockKey) go func() { if !loaded { wg.Add(1) - proxy, rule, err := t.resolveMetadata(metadata) + proxy, rule, err := resolveMetadata(metadata) if err != nil { log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) - t.natTable.Delete(lockKey) + natTable.Delete(lockKey) wg.Done() return } @@ -219,7 +220,7 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { rawPc, err := proxy.DialUDP(metadata) if err != nil { log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error()) - t.natTable.Delete(lockKey) + natTable.Delete(lockKey) wg.Done() return } @@ -228,36 +229,36 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { switch true { case rule != nil: log.Infoln("[UDP] %s --> %v match %s using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String()) - case t.mode == Global: + case mode == Global: log.Infoln("[UDP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String()) - case t.mode == Direct: + case mode == Direct: log.Infoln("[UDP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String()) default: log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String()) } - t.natTable.Set(key, pc) - t.natTable.Delete(lockKey) + natTable.Set(key, pc) + natTable.Delete(lockKey) wg.Done() - go t.handleUDPToLocal(packet.UDPPacket, pc, key) + go handleUDPToLocal(packet.UDPPacket, pc, key) } wg.Wait() - pc := t.natTable.Get(key) + pc := natTable.Get(key) if pc != nil { if !metadata.Resolved() { - ip, err := dns.ResolveIP(metadata.Host) + ip, err := resolver.ResolveIP(metadata.Host) if err != nil { return } metadata.DstIP = ip } - t.handleUDPToRemote(packet, pc, metadata.UDPAddr()) + handleUDPToRemote(packet, pc, metadata.UDPAddr()) } }() } -func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) { +func handleTCPConn(localConn C.ServerAdapter) { defer localConn.Close() metadata := localConn.Metadata() @@ -266,12 +267,12 @@ func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) { return } - if err := t.preHandleMetadata(metadata); err != nil { + if err := preHandleMetadata(metadata); err != nil { log.Debugln("[Metadata PreHandle] error: %s", err) return } - proxy, rule, err := t.resolveMetadata(metadata) + proxy, rule, err := resolveMetadata(metadata) if err != nil { log.Warnln("Parse metadata failed: %v", err) return @@ -288,9 +289,9 @@ func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) { switch true { case rule != nil: log.Infoln("[TCP] %s --> %v match %s using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), remoteConn.Chains().String()) - case t.mode == Global: + case mode == Global: log.Infoln("[TCP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String()) - case t.mode == Direct: + case mode == Direct: log.Infoln("[TCP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String()) default: log.Infoln("[TCP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String()) @@ -298,33 +299,33 @@ func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) { switch adapter := localConn.(type) { case *inbound.HTTPAdapter: - t.handleHTTP(adapter, remoteConn) + handleHTTP(adapter, remoteConn) case *inbound.SocketAdapter: - t.handleSocket(adapter, remoteConn) + handleSocket(adapter, remoteConn) } } -func (t *Tunnel) shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool { +func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool { return !rule.NoResolveIP() && metadata.Host != "" && metadata.DstIP == nil } -func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { - t.configMux.RLock() - defer t.configMux.RUnlock() +func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { + configMux.RLock() + defer configMux.RUnlock() var resolved bool - if node := dns.DefaultHosts.Search(metadata.Host); node != nil { + if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { ip := node.Data.(net.IP) metadata.DstIP = ip resolved = true } - for _, rule := range t.rules { - if !resolved && t.shouldResolveIP(rule, metadata) { - ip, err := t.resolveIP(metadata.Host) + for _, rule := range rules { + if !resolved && shouldResolveIP(rule, metadata) { + ip, err := resolver.ResolveIP(metadata.Host) if err != nil { - if !t.ignoreResolveFail { + if !ignoreResolveFail { return nil, nil, fmt.Errorf("[DNS] resolve %s error: %s", metadata.Host, err.Error()) } log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error()) @@ -336,7 +337,7 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { } if rule.Match(metadata) { - adapter, ok := t.proxies[rule.Adapter()] + adapter, ok := proxies[rule.Adapter()] if !ok { continue } @@ -348,24 +349,6 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { return adapter, rule, nil } } - return t.proxies["DIRECT"], nil, nil -} -func newTunnel() *Tunnel { - return &Tunnel{ - tcpQueue: channels.NewInfiniteChannel(), - udpQueue: channels.NewInfiniteChannel(), - natTable: nat.New(), - proxies: make(map[string]C.Proxy), - mode: Rule, - } -} - -// Instance return singleton instance of Tunnel -func Instance() *Tunnel { - once.Do(func() { - tunnel = newTunnel() - go tunnel.process() - }) - return tunnel + return proxies["DIRECT"], nil, nil }