diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 2cac7095..940f16fa 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -209,8 +209,8 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn return NewConn(c, v), err } -// DialUDP implements C.ProxyAdapter -func (v *Vless) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { +// ListenPacketContext implements C.ProxyAdapter +func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { ip, err := resolver.ResolveIP(metadata.Host) @@ -231,8 +231,6 @@ func (v *Vless) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { c, err = v.client.StreamConn(c, parseVlessAddr(metadata)) } else { - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) - defer cancel() c, err = dialer.DialContext(ctx, "tcp", v.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 75bbb868..0951b190 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -9,6 +9,22 @@ import ( ) func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) { + opt := &config{} + + for _, o := range options { + o(opt) + } + + if !opt.skipDefault { + for _, o := range DefaultOptions { + o(opt) + } + } + + for _, o := range options { + o(opt) + } + switch network { case "tcp4", "tcp6", "udp4", "udp6": host, port, err := net.SplitHostPort(address) @@ -19,17 +35,25 @@ func DialContext(ctx context.Context, network, address string, options ...Option var ip net.IP switch network { case "tcp4", "udp4": - ip, err = resolver.ResolveIPv4(host) + if opt.interfaceName != "" { + ip, err = resolver.ResolveIPv4WithMain(host) + } else { + ip, err = resolver.ResolveIPv4(host) + } default: - ip, err = resolver.ResolveIPv6(host) + if opt.interfaceName != "" { + ip, err = resolver.ResolveIPv6WithMain(host) + } else { + ip, err = resolver.ResolveIPv6(host) + } } if err != nil { return nil, err } - return dialContext(ctx, network, ip, port, options) + return dialContext(ctx, network, ip, port, opt) case "tcp", "udp": - return dualStackDialContext(ctx, network, address, options) + return dualStackDialContext(ctx, network, address, opt) default: return nil, errors.New("network invalid") } @@ -38,6 +62,10 @@ func DialContext(ctx context.Context, network, address string, options ...Option func ListenPacket(ctx context.Context, network, address string, options ...Option) (net.PacketConn, error) { cfg := &config{} + for _, o := range options { + o(cfg) + } + if !cfg.skipDefault { for _, o := range DefaultOptions { o(cfg) @@ -63,19 +91,7 @@ func ListenPacket(ctx context.Context, network, address string, options ...Optio return lc.ListenPacket(ctx, network, address) } -func dialContext(ctx context.Context, network string, destination net.IP, port string, options []Option) (net.Conn, error) { - opt := &config{} - - if !opt.skipDefault { - for _, o := range DefaultOptions { - o(opt) - } - } - - for _, o := range options { - o(opt) - } - +func dialContext(ctx context.Context, network string, destination net.IP, port string, opt *config) (net.Conn, error) { dialer := &net.Dialer{} if opt.interfaceName != "" { if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil { @@ -86,7 +102,7 @@ func dialContext(ctx context.Context, network string, destination net.IP, port s return dialer.DialContext(ctx, network, net.JoinHostPort(destination.String(), port)) } -func dualStackDialContext(ctx context.Context, network, address string, options []Option) (net.Conn, error) { +func dualStackDialContext(ctx context.Context, network, address string, opt *config) (net.Conn, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, err @@ -119,16 +135,24 @@ func dualStackDialContext(ctx context.Context, network, address string, options var ip net.IP if ipv6 { - ip, result.error = resolver.ResolveIPv6(host) + if opt.interfaceName != "" { + ip, result.error = resolver.ResolveIPv6WithMain(host) + } else { + ip, result.error = resolver.ResolveIPv6(host) + } } else { - ip, result.error = resolver.ResolveIPv4(host) + if opt.interfaceName != "" { + ip, result.error = resolver.ResolveIPv4WithMain(host) + } else { + ip, result.error = resolver.ResolveIPv4(host) + } } if result.error != nil { return } result.resolved = true - result.Conn, result.error = dialContext(ctx, network, ip, port, options) + result.Conn, result.error = dialContext(ctx, network, ip, port, opt) } go startRacer(ctx, network+"4", host, false) diff --git a/component/geodata/utils.go b/component/geodata/utils.go new file mode 100644 index 00000000..3a48dc86 --- /dev/null +++ b/component/geodata/utils.go @@ -0,0 +1,30 @@ +package geodata + +import ( + "github.com/Dreamacro/clash/component/geodata/router" +) + +func LoadGeoSiteMatcher(countryCode string) (*router.DomainMatcher, int, error) { + geoLoaderName := "standard" + geoLoader, err := GetGeoDataLoader(geoLoaderName) + if err != nil { + return nil, 0, err + } + + domains, err := geoLoader.LoadGeoSite(countryCode) + if err != nil { + return nil, 0, err + } + + /** + linear: linear algorithm + matcher, err := router.NewDomainMatcher(domains) + mph:minimal perfect hash algorithm + */ + matcher, err := router.NewMphMatcherGroup(domains) + if err != nil { + return nil, 0, err + } + + return matcher, len(domains), nil +} diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index d10e39cb..a7300bd7 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -15,6 +15,9 @@ var ( // DefaultResolver aim to resolve ip DefaultResolver Resolver + // MainResolver resolve ip with main domain server + MainResolver Resolver + // DisableIPv6 means don't resolve ipv6 host // default value is true DisableIPv6 = true @@ -40,6 +43,14 @@ type Resolver interface { // ResolveIPv4 with a host, return ipv4 func ResolveIPv4(host string) (net.IP, error) { + return ResolveIPv4WithResolver(host, DefaultResolver) +} + +func ResolveIPv4WithMain(host string) (net.IP, error) { + return ResolveIPv4WithResolver(host, MainResolver) +} + +func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) { if node := DefaultHosts.Search(host); node != nil { if ip := node.Data.(net.IP).To4(); ip != nil { return ip, nil @@ -54,8 +65,8 @@ func ResolveIPv4(host string) (net.IP, error) { return nil, ErrIPVersion } - if DefaultResolver != nil { - return DefaultResolver.ResolveIPv4(host) + if r != nil { + return r.ResolveIPv4(host) } ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) @@ -72,6 +83,14 @@ func ResolveIPv4(host string) (net.IP, error) { // ResolveIPv6 with a host, return ipv6 func ResolveIPv6(host string) (net.IP, error) { + return ResolveIPv6WithResolver(host, DefaultResolver) +} + +func ResolveIPv6WithMain(host string) (net.IP, error) { + return ResolveIPv6WithResolver(host, MainResolver) +} + +func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { if DisableIPv6 { return nil, ErrIPv6Disabled } @@ -90,8 +109,8 @@ func ResolveIPv6(host string) (net.IP, error) { return nil, ErrIPVersion } - if DefaultResolver != nil { - return DefaultResolver.ResolveIPv6(host) + if r != nil { + return r.ResolveIPv6(host) } ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) @@ -138,3 +157,8 @@ func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) { func ResolveIP(host string) (net.IP, error) { return ResolveIPWithResolver(host, DefaultResolver) } + +// ResolveIPWithMainResolver with a host, use main resolver, return ip +func ResolveIPWithMainResolver(host string) (net.IP, error) { + return ResolveIPWithResolver(host, MainResolver) +} diff --git a/config/config.go b/config/config.go index eed0b4d0..f44fb568 100644 --- a/config/config.go +++ b/config/config.go @@ -16,6 +16,8 @@ import ( "github.com/Dreamacro/clash/adapter/provider" "github.com/Dreamacro/clash/component/auth" "github.com/Dreamacro/clash/component/fakeip" + "github.com/Dreamacro/clash/component/geodata" + "github.com/Dreamacro/clash/component/geodata/router" S "github.com/Dreamacro/clash/component/script" "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" @@ -75,10 +77,11 @@ type DNS struct { // FallbackFilter config type FallbackFilter struct { - GeoIP bool `yaml:"geoip"` - GeoIPCode string `yaml:"geoip-code"` - IPCIDR []*net.IPNet `yaml:"ipcidr"` - Domain []string `yaml:"domain"` + GeoIP bool `yaml:"geoip"` + GeoIPCode string `yaml:"geoip-code"` + IPCIDR []*net.IPNet `yaml:"ipcidr"` + Domain []string `yaml:"domain"` + GeoSite []*router.DomainMatcher `yaml:"geosite"` } // Profile config @@ -139,6 +142,7 @@ type RawFallbackFilter struct { GeoIPCode string `yaml:"geoip-code"` IPCIDR []string `yaml:"ipcidr"` Domain []string `yaml:"domain"` + GeoSite []string `yaml:"geosite"` } type RawConfig struct { @@ -206,6 +210,7 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) { GeoIP: true, GeoIPCode: "CN", IPCIDR: []string{}, + GeoSite: []string{}, }, DefaultNameserver: []string{ "114.114.114.114", @@ -265,7 +270,7 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) { } config.Hosts = hosts - dnsCfg, err := parseDNS(rawCfg, hosts) + dnsCfg, err := parseDNS(rawCfg, hosts, rules) if err != nil { return nil, err } @@ -648,8 +653,9 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) { nameservers = append( nameservers, dns.NameServer{ - Net: dnsNetType, - Addr: addr, + Net: dnsNetType, + Addr: addr, + ProxyAdapter: u.Fragment, }, ) } @@ -687,7 +693,37 @@ func parseFallbackIPCIDR(ips []string) ([]*net.IPNet, error) { return ipNets, nil } -func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie) (*DNS, error) { +func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]*router.DomainMatcher, error) { + sites := []*router.DomainMatcher{} + + for _, country := range countries { + found := false + for _, rule := range rules { + if rule.RuleType() == C.GEOSITE { + if strings.EqualFold(country, rule.Payload()) { + found = true + sites = append(sites, rule.(C.RuleGeoSite).GetDomainMatcher()) + log.Infoln("Start initial GeoSite dns fallback filter from rule `%s`", country) + } + } + } + + if !found { + matcher, recordsCount, err := geodata.LoadGeoSiteMatcher(country) + if err != nil { + return nil, err + } + + sites = append(sites, matcher) + + log.Infoln("Start initial GeoSite dns fallback filter `%s`, records: %d", country, recordsCount) + } + } + runtime.GC() + return sites, nil +} + +func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS, error) { cfg := rawCfg.DNS if cfg.Enable && len(cfg.NameServer) == 0 { return nil, fmt.Errorf("if DNS configuration is turned on, NameServer cannot be empty") @@ -699,7 +735,8 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie) (*DNS, error) { IPv6: cfg.IPv6, EnhancedMode: cfg.EnhancedMode, FallbackFilter: FallbackFilter{ - IPCIDR: []*net.IPNet{}, + IPCIDR: []*net.IPNet{}, + GeoSite: []*router.DomainMatcher{}, }, } var err error @@ -744,6 +781,18 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie) (*DNS, error) { } } + if len(dnsCfg.Fallback) != 0 { + if host == nil { + host = trie.New() + } + for _, fb := range dnsCfg.Fallback { + if net.ParseIP(fb.Addr) != nil { + continue + } + host.Insert(fb.Addr, true) + } + } + pool, err := fakeip.New(fakeip.Options{ IPNet: ipnet, Size: 1000, @@ -757,12 +806,19 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie) (*DNS, error) { dnsCfg.FakeIPRange = pool } - dnsCfg.FallbackFilter.GeoIP = cfg.FallbackFilter.GeoIP - dnsCfg.FallbackFilter.GeoIPCode = cfg.FallbackFilter.GeoIPCode - if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil { - dnsCfg.FallbackFilter.IPCIDR = fallbackip + if len(cfg.Fallback) != 0 { + dnsCfg.FallbackFilter.GeoIP = cfg.FallbackFilter.GeoIP + dnsCfg.FallbackFilter.GeoIPCode = cfg.FallbackFilter.GeoIPCode + if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil { + dnsCfg.FallbackFilter.IPCIDR = fallbackip + } + dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain + fallbackGeoSite, err := parseFallbackGeoSite(cfg.FallbackFilter.GeoSite, rules) + if err != nil { + return nil, fmt.Errorf("load GeoSite dns fallback filter error, %w", err) + } + dnsCfg.FallbackFilter.GeoSite = fallbackGeoSite } - dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain if cfg.UseHosts { dnsCfg.Hosts = hosts diff --git a/constant/rule.go b/constant/rule.go index 87ae2a32..e2087604 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -1,7 +1,5 @@ package constant -import "net" - // Rule Type const ( Domain RuleType = iota @@ -59,5 +57,3 @@ type Rule interface { ShouldResolveIP() bool RuleExtra() *RuleExtra } - -var TunBroadcastAddr = net.IPv4(198, 18, 255, 255) diff --git a/constant/rule_extra.go b/constant/rule_extra.go index c27f276e..119b42ca 100644 --- a/constant/rule_extra.go +++ b/constant/rule_extra.go @@ -1,6 +1,12 @@ package constant -import "net" +import ( + "net" + + "github.com/Dreamacro/clash/component/geodata/router" +) + +var TunBroadcastAddr = net.IPv4(198, 18, 255, 255) type RuleExtra struct { Network NetWork @@ -23,3 +29,7 @@ func (re *RuleExtra) NotMatchSourceIP(srcIP net.IP) bool { } return true } + +type RuleGeoSite interface { + GetDomainMatcher() *router.DomainMatcher +} diff --git a/dns/client.go b/dns/client.go index 5cb1fe02..04ee7d11 100644 --- a/dns/client.go +++ b/dns/client.go @@ -15,10 +15,11 @@ import ( type client struct { *D.Client - r *Resolver - port string - host string - iface string + r *Resolver + port string + host string + iface string + proxyAdapter string } func (c *client) Exchange(m *D.Msg) (*D.Msg, error) { @@ -30,14 +31,15 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) ip net.IP err error ) - if c.r == nil { - // a default ip dns - if ip = net.ParseIP(c.host); ip == nil { + + if ip = net.ParseIP(c.host); ip == nil { + if c.r == nil { return nil, fmt.Errorf("dns %s not a valid ip", c.host) - } - } else { - if ip, err = resolver.ResolveIPWithResolver(c.host, c.r); err != nil { - return nil, fmt.Errorf("use default dns resolve failed: %w", err) + } else { + if ip, err = resolver.ResolveIPWithResolver(c.host, c.r); err != nil { + return nil, fmt.Errorf("use default dns resolve failed: %w", err) + } + c.host = ip.String() } } @@ -46,11 +48,17 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) network = "tcp" } - options := []dialer.Option{} - if c.iface != "" { - options = append(options, dialer.WithInterface(c.iface)) + var conn net.Conn + if c.proxyAdapter != "" && network == "tcp" { + conn, err = dialContextWithProxyAdapter(ctx, c.proxyAdapter, ip, c.port) + } else { + options := []dialer.Option{} + if c.iface != "" { + options = append(options, dialer.WithInterface(c.iface)) + } + conn, err = dialer.DialContext(ctx, network, net.JoinHostPort(c.host, c.port), options...) } - conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), c.port), options...) + if err != nil { return nil, err } diff --git a/dns/doh.go b/dns/doh.go index 94312355..5d36f706 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -19,8 +19,9 @@ const ( ) type dohClient struct { - url string - transport *http.Transport + url string + proxyAdapter string + transport *http.Transport } func (dc *dohClient) Exchange(m *D.Msg) (msg *D.Msg, err error) { @@ -62,7 +63,7 @@ func (dc *dohClient) newRequest(m *D.Msg) (*http.Request, error) { return req, nil } -func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) { +func (dc *dohClient) doRequest(req *http.Request) (*D.Msg, error) { client := &http.Client{Transport: dc.transport} resp, err := client.Do(req) if err != nil { @@ -74,14 +75,15 @@ func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) { if err != nil { return nil, err } - msg = &D.Msg{} + msg := &D.Msg{} err = msg.Unpack(buf) return msg, err } -func newDoHClient(url string, r *Resolver) *dohClient { +func newDoHClient(url string, r *Resolver, proxyAdapter string) *dohClient { return &dohClient{ - url: url, + url: url, + proxyAdapter: proxyAdapter, transport: &http.Transport{ ForceAttemptHTTP2: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -95,7 +97,11 @@ func newDoHClient(url string, r *Resolver) *dohClient { return nil, err } - return dialer.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), port)) + if proxyAdapter == "" { + return dialer.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), port)) + } else { + return dialContextWithProxyAdapter(ctx, proxyAdapter, ip, port) + } }, }, } diff --git a/dns/filters.go b/dns/filters.go index 5939020e..9c1be0d0 100644 --- a/dns/filters.go +++ b/dns/filters.go @@ -4,6 +4,7 @@ import ( "net" "strings" + "github.com/Dreamacro/clash/component/geodata/router" "github.com/Dreamacro/clash/component/mmdb" "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" @@ -49,3 +50,16 @@ func NewDomainFilter(domains []string) *domainFilter { func (df *domainFilter) Match(domain string) bool { return df.tree.Search(domain) != nil } + +type geoSiteFilter struct { + matchers []*router.DomainMatcher +} + +func (gsf *geoSiteFilter) Match(domain string) bool { + for _, matcher := range gsf.matchers { + if matcher.ApplyDomain(domain) { + return true + } + } + return false +} diff --git a/dns/resolver.go b/dns/resolver.go index 91efbb8b..31929a54 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -12,6 +12,7 @@ import ( "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/picker" "github.com/Dreamacro/clash/component/fakeip" + "github.com/Dreamacro/clash/component/geodata/router" "github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" @@ -149,7 +150,7 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M return } -func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { +func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (*D.Msg, error) { fast, ctx := picker.WithTimeout(ctx, resolver.DefaultDNSTimeout) for _, client := range clients { r := client @@ -173,8 +174,8 @@ func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D. return nil, err } - msg = elm.(*D.Msg) - return + msg := elm.(*D.Msg) + return msg, nil } func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { @@ -215,7 +216,7 @@ func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { return false } -func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { +func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (*D.Msg, error) { if matched := r.matchPolicy(m); len(matched) != 0 { res := <-r.asyncExchange(ctx, matched, m) return res.Msg, res.Error @@ -230,27 +231,22 @@ func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err er msgCh := r.asyncExchange(ctx, r.main, m) - if r.fallback == nil { // directly return if no fallback servers are available + if r.fallback == nil || len(r.fallback) == 0 { // directly return if no fallback servers are available res := <-msgCh - msg, err = res.Msg, res.Error - return + return res.Msg, res.Error } - fallbackMsg := r.asyncExchange(ctx, r.fallback, m) res := <-msgCh if res.Error == nil { if ips := msgToIP(res.Msg); len(ips) != 0 { if !r.shouldIPFallback(ips[0]) { - msg = res.Msg // no need to wait for fallback result - err = res.Error - return msg, err + return res.Msg, res.Error // no need to wait for fallback result } } } - res = <-fallbackMsg - msg, err = res.Msg, res.Error - return + res = <-r.asyncExchange(ctx, r.fallback, m) + return res.Msg, res.Error } func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) { @@ -302,9 +298,10 @@ func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D } type NameServer struct { - Net string - Addr string - Interface string + Net string + Addr string + Interface string + ProxyAdapter string } type FallbackFilter struct { @@ -312,6 +309,7 @@ type FallbackFilter struct { GeoIPCode string IPCIDR []*net.IPNet Domain []string + GeoSite []*router.DomainMatcher } type Config struct { @@ -360,10 +358,27 @@ func NewResolver(config Config) *Resolver { } r.fallbackIPFilters = fallbackIPFilters + fallbackDomainFilters := []fallbackDomainFilter{} if len(config.FallbackFilter.Domain) != 0 { - fallbackDomainFilters := []fallbackDomainFilter{NewDomainFilter(config.FallbackFilter.Domain)} - r.fallbackDomainFilters = fallbackDomainFilters + fallbackDomainFilters = append(fallbackDomainFilters, NewDomainFilter(config.FallbackFilter.Domain)) } + if len(config.FallbackFilter.GeoSite) != 0 { + fallbackDomainFilters = append(fallbackDomainFilters, &geoSiteFilter{ + matchers: config.FallbackFilter.GeoSite, + }) + } + r.fallbackDomainFilters = fallbackDomainFilters + + return r +} + +func NewMainResolver(old *Resolver) *Resolver { + r := &Resolver{ + ipv6: old.ipv6, + main: old.main, + lruCache: old.lruCache, + hosts: old.hosts, + } return r } diff --git a/dns/util.go b/dns/util.go index d11870f8..a0e38933 100644 --- a/dns/util.go +++ b/dns/util.go @@ -1,12 +1,16 @@ package dns import ( + "context" "crypto/tls" + "fmt" "net" "time" "github.com/Dreamacro/clash/common/cache" + C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" + "github.com/Dreamacro/clash/tunnel" D "github.com/miekg/dns" ) @@ -51,7 +55,7 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient { for _, s := range servers { switch s.Net { case "https": - ret = append(ret, newDoHClient(s.Addr, resolver)) + ret = append(ret, newDoHClient(s.Addr, resolver, s.ProxyAdapter)) continue case "dhcp": ret = append(ret, newDHCPClient(s.Addr)) @@ -70,10 +74,11 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient { UDPSize: 4096, Timeout: 5 * time.Second, }, - port: port, - host: host, - iface: s.Interface, - r: resolver, + port: port, + host: host, + iface: s.Interface, + r: resolver, + proxyAdapter: s.ProxyAdapter, }) } return ret @@ -104,3 +109,26 @@ func msgToIP(msg *D.Msg) []net.IP { return ips } + +func dialContextWithProxyAdapter(ctx context.Context, adapterName string, dstIP net.IP, port string) (net.Conn, error) { + adapter, ok := tunnel.Proxies()[adapterName] + if !ok { + return nil, fmt.Errorf("proxy dapter [%s] not found", adapterName) + } + + addrType := C.AtypIPv4 + + if dstIP.To4() == nil { + addrType = C.AtypIPv6 + } + + metadata := &C.Metadata{ + NetWork: C.TCP, + AddrType: addrType, + Host: "", + DstIP: dstIP, + DstPort: port, + } + + return adapter.DialContext(ctx, metadata) +} diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 76669ca7..9600c7a2 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -118,6 +118,7 @@ func updateExperimental(c *config.Config) {} func updateDNS(c *config.DNS, general *config.General) { if !c.Enable { resolver.DefaultResolver = nil + resolver.MainResolver = nil resolver.DefaultHostMapper = nil dns.ReCreateServer("", nil, nil) return @@ -135,12 +136,14 @@ func updateDNS(c *config.DNS, general *config.General) { GeoIPCode: c.FallbackFilter.GeoIPCode, IPCIDR: c.FallbackFilter.IPCIDR, Domain: c.FallbackFilter.Domain, + GeoSite: c.FallbackFilter.GeoSite, }, Default: c.DefaultNameserver, Policy: c.NameServerPolicy, } r := dns.NewResolver(cfg) + mr := dns.NewMainResolver(r) m := dns.NewEnhancer(cfg) // reuse cache of old host mapper @@ -149,6 +152,7 @@ func updateDNS(c *config.DNS, general *config.General) { } resolver.DefaultResolver = r + resolver.MainResolver = mr resolver.DefaultHostMapper = m if general.Tun.Enable && !strings.EqualFold(general.Tun.Stack, "gvisor") { resolver.DefaultLocalServer = dns.NewLocalServer(r, m) @@ -201,11 +205,12 @@ func updateGeneral(general *config.General, force bool) { if general.Interface != "" { dialer.DefaultOptions = []dialer.Option{dialer.WithInterface(general.Interface)} - log.Infoln("Use interface name: %s", general.Interface) } else { dialer.DefaultOptions = nil } + log.Infoln("Use interface name: %s", general.Interface) + iface.FlushCache() if !force { diff --git a/listener/tun/ipstack/lwip/dns.go b/listener/tun/ipstack/lwip/dns.go index b922d584..6a314c08 100644 --- a/listener/tun/ipstack/lwip/dns.go +++ b/listener/tun/ipstack/lwip/dns.go @@ -28,10 +28,6 @@ func hijackUDPDns(conn golwip.UDPConn, pkt []byte, addr *net.UDPAddr) { _ = conn.Close() }(conn) - if err := conn.SetDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil { - return - } - answer, err := D.RelayDnsPacket(pkt) if err != nil { return @@ -81,7 +77,7 @@ func hijackTCPDns(conn net.Conn) { type dnsHandler struct{} -func NewDnsHandler() golwip.DnsHandler { +func newDnsHandler() golwip.DnsHandler { return &dnsHandler{} } diff --git a/listener/tun/ipstack/lwip/tcp.go b/listener/tun/ipstack/lwip/tcp.go index 1c7cba9c..c62a6beb 100644 --- a/listener/tun/ipstack/lwip/tcp.go +++ b/listener/tun/ipstack/lwip/tcp.go @@ -15,7 +15,7 @@ type tcpHandler struct { tcpIn chan<- C.ConnContext } -func NewTCPHandler(dnsIP net.IP, tcpIn chan<- C.ConnContext) golwip.TCPConnHandler { +func newTCPHandler(dnsIP net.IP, tcpIn chan<- C.ConnContext) golwip.TCPConnHandler { return &tcpHandler{dnsIP, tcpIn} } diff --git a/listener/tun/ipstack/lwip/tun.go b/listener/tun/ipstack/lwip/tun.go index 1493e80f..9037be6e 100644 --- a/listener/tun/ipstack/lwip/tun.go +++ b/listener/tun/ipstack/lwip/tun.go @@ -50,7 +50,7 @@ func NewAdapter(device dev.TunDevice, conf config.Tun, mtu int, tcpIn chan<- C.C }) // Set custom buffer pool - golwip.SetPoolAllocator(&lwipPool{}) + golwip.SetPoolAllocator(newLWIPPool()) // Setup TCP/IP stack. lwipStack, err := golwip.NewLWIPStack(mtu) @@ -59,9 +59,9 @@ func NewAdapter(device dev.TunDevice, conf config.Tun, mtu int, tcpIn chan<- C.C } adapter.lwipStack = lwipStack - golwip.RegisterDnsHandler(NewDnsHandler()) - golwip.RegisterTCPConnHandler(NewTCPHandler(dnsIP, tcpIn)) - golwip.RegisterUDPConnHandler(NewUDPHandler(dnsIP, udpIn)) + golwip.RegisterDnsHandler(newDnsHandler()) + golwip.RegisterTCPConnHandler(newTCPHandler(dnsIP, tcpIn)) + golwip.RegisterUDPConnHandler(newUDPHandler(dnsIP, udpIn)) // Copy packets from tun device to lwip stack, it's the loop. go func(lwipStack golwip.LWIPStack, device dev.TunDevice, mtu int) { @@ -95,7 +95,7 @@ func (l *lwipAdapter) Close() { func (l *lwipAdapter) stopLocked() { if l.lwipStack != nil { - l.lwipStack.Close() + _ = l.lwipStack.Close() } if l.device != nil { @@ -115,3 +115,7 @@ func (p lwipPool) Get(size int) []byte { func (p lwipPool) Put(buf []byte) error { return pool.Put(buf) } + +func newLWIPPool() golwip.LWIPPool { + return &lwipPool{} +} diff --git a/listener/tun/ipstack/lwip/udp.go b/listener/tun/ipstack/lwip/udp.go index 5a2b9c58..747796bf 100644 --- a/listener/tun/ipstack/lwip/udp.go +++ b/listener/tun/ipstack/lwip/udp.go @@ -42,7 +42,7 @@ type udpHandler struct { udpIn chan<- *inbound.PacketAdapter } -func NewUDPHandler(dnsIP net.IP, udpIn chan<- *inbound.PacketAdapter) golwip.UDPConnHandler { +func newUDPHandler(dnsIP net.IP, udpIn chan<- *inbound.PacketAdapter) golwip.UDPConnHandler { return &udpHandler{dnsIP, udpIn} } diff --git a/listener/tun/ipstack/system/dns.go b/listener/tun/ipstack/system/dns.go index e79a57d7..77159206 100644 --- a/listener/tun/ipstack/system/dns.go +++ b/listener/tun/ipstack/system/dns.go @@ -11,7 +11,7 @@ import ( "github.com/kr328/tun2socket/redirect" ) -const defaultDnsReadTimeout = time.Second * 10 +const defaultDnsReadTimeout = time.Second * 30 func shouldHijackDns(dnsAddr binding.Address, targetAddr binding.Address) bool { if targetAddr.Port != 53 { diff --git a/rule/geosite.go b/rule/geosite.go index 9849549d..875320bd 100644 --- a/rule/geosite.go +++ b/rule/geosite.go @@ -47,29 +47,17 @@ func (gs *GEOSITE) RuleExtra() *C.RuleExtra { return gs.ruleExtra } +func (gs *GEOSITE) GetDomainMatcher() *router.DomainMatcher { + return gs.matcher +} + func NewGEOSITE(country string, adapter string, ruleExtra *C.RuleExtra) (*GEOSITE, error) { - geoLoaderName := "standard" - geoLoader, err := geodata.GetGeoDataLoader(geoLoaderName) + matcher, recordsCount, err := geodata.LoadGeoSiteMatcher(country) if err != nil { return nil, fmt.Errorf("load GeoSite data error, %s", err.Error()) } - domains, err := geoLoader.LoadGeoSite(country) - if err != nil { - return nil, fmt.Errorf("load GeoSite data error, %s", err.Error()) - } - - /** - linear: linear algorithm - matcher, err := router.NewDomainMatcher(domains) - mph:minimal perfect hash algorithm - */ - matcher, err := router.NewMphMatcherGroup(domains) - if err != nil { - return nil, fmt.Errorf("load GeoSite data error, %s", err.Error()) - } - - log.Infoln("Start initial GeoSite rule %s => %s, records: %d", country, adapter, len(domains)) + log.Infoln("Start initial GeoSite rule %s => %s, records: %d", country, adapter, recordsCount) geoSite := &GEOSITE{ country: country,