chore: share dnsClient in NewResolver

This commit is contained in:
wwqgtxx 2023-11-08 20:19:48 +08:00
parent 575c1d4129
commit f260d8cf01

View file

@ -19,6 +19,7 @@ import (
D "github.com/miekg/dns" D "github.com/miekg/dns"
"github.com/samber/lo" "github.com/samber/lo"
"golang.org/x/exp/maps"
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
) )
@ -370,6 +371,23 @@ type NameServer struct {
PreferH3 bool PreferH3 bool
} }
func (ns NameServer) Equal(ns2 NameServer) bool {
defer func() {
// C.ProxyAdapter compare maybe panic, just ignore
recover()
}()
if ns.Net == ns2.Net &&
ns.Addr == ns2.Addr &&
ns.Interface == ns2.Interface &&
ns.ProxyAdapter == ns2.ProxyAdapter &&
ns.ProxyName == ns2.ProxyName &&
maps.Equal(ns.Params, ns2.Params) &&
ns.PreferH3 == ns2.PreferH3 {
return true
}
return false
}
type FallbackFilter struct { type FallbackFilter struct {
GeoIP bool GeoIP bool
GeoIPCode string GeoIPCode string
@ -399,20 +417,47 @@ func NewResolver(config Config) *Resolver {
ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond,
} }
var nameServerCache []struct {
NameServer
dnsClient
}
cacheTransform := func(nameserver []NameServer) (result []dnsClient) {
LOOP:
for _, ns := range nameserver {
for _, nsc := range nameServerCache {
if nsc.NameServer.Equal(ns) {
result = append(result, nsc.dnsClient)
continue LOOP
}
}
// not in cache
dc := transform([]NameServer{ns}, defaultResolver)
if len(dc) > 0 {
dc := dc[0]
nameServerCache = append(nameServerCache, struct {
NameServer
dnsClient
}{NameServer: ns, dnsClient: dc})
result = append(result, dc)
}
}
return
}
r := &Resolver{ r := &Resolver{
ipv6: config.IPv6, ipv6: config.IPv6,
main: transform(config.Main, defaultResolver), main: cacheTransform(config.Main),
lruCache: cache.New(cache.WithSize[string, *D.Msg](4096), cache.WithStale[string, *D.Msg](true)), lruCache: cache.New(cache.WithSize[string, *D.Msg](4096), cache.WithStale[string, *D.Msg](true)),
hosts: config.Hosts, hosts: config.Hosts,
ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond,
} }
if len(config.Fallback) != 0 { if len(config.Fallback) != 0 {
r.fallback = transform(config.Fallback, defaultResolver) r.fallback = cacheTransform(config.Fallback)
} }
if len(config.ProxyServer) != 0 { if len(config.ProxyServer) != 0 {
r.proxyServer = transform(config.ProxyServer, defaultResolver) r.proxyServer = cacheTransform(config.ProxyServer)
} }
if len(config.Policy) != 0 { if len(config.Policy) != 0 {
@ -426,6 +471,7 @@ func NewResolver(config Config) *Resolver {
triePolicy = nil triePolicy = nil
} }
} }
for _, p := range config.Policy { for _, p := range config.Policy {
domain, nameserver := p.Extract() domain, nameserver := p.Extract()
domain = strings.ToLower(domain) domain = strings.ToLower(domain)
@ -439,7 +485,7 @@ func NewResolver(config Config) *Resolver {
insertTriePolicy() insertTriePolicy()
r.policy = append(r.policy, domainSetPolicy{ r.policy = append(r.policy, domainSetPolicy{
domainSetProvider: p, domainSetProvider: p,
dnsClients: transform(nameserver, defaultResolver), dnsClients: cacheTransform(nameserver),
}) })
continue continue
} }
@ -458,7 +504,7 @@ func NewResolver(config Config) *Resolver {
r.policy = append(r.policy, geositePolicy{ r.policy = append(r.policy, geositePolicy{
matcher: matcher, matcher: matcher,
inverse: inverse, inverse: inverse,
dnsClients: transform(nameserver, defaultResolver), dnsClients: cacheTransform(nameserver),
}) })
continue continue
} }
@ -466,7 +512,7 @@ func NewResolver(config Config) *Resolver {
if triePolicy == nil { if triePolicy == nil {
triePolicy = trie.New[[]dnsClient]() triePolicy = trie.New[[]dnsClient]()
} }
_ = triePolicy.Insert(domain, transform(nameserver, defaultResolver)) _ = triePolicy.Insert(domain, cacheTransform(nameserver))
} }
insertTriePolicy() insertTriePolicy()
} }