From 0778591524e180cc935176e772ede8d8a86c0f85 Mon Sep 17 00:00:00 2001 From: Rusty Pen Date: Wed, 19 May 2021 11:17:35 +0800 Subject: [PATCH] Feature: dns resolve domain through nameserver-policy (#1406) --- component/trie/domain.go | 6 +++--- config/config.go | 23 +++++++++++++++++++++++ dns/resolver.go | 35 +++++++++++++++++++++++++++++++++++ hub/executor/executor.go | 1 + 4 files changed, 62 insertions(+), 3 deletions(-) diff --git a/component/trie/domain.go b/component/trie/domain.go index c063686e..b4de4a70 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -23,7 +23,7 @@ type DomainTrie struct { root *Node } -func validAndSplitDomain(domain string) ([]string, bool) { +func ValidAndSplitDomain(domain string) ([]string, bool) { if domain != "" && domain[len(domain)-1] == '.' { return nil, false } @@ -54,7 +54,7 @@ func validAndSplitDomain(domain string) ([]string, bool) { // 4. .example.com // 5. +.example.com func (t *DomainTrie) Insert(domain string, data interface{}) error { - parts, valid := validAndSplitDomain(domain) + parts, valid := ValidAndSplitDomain(domain) if !valid { return ErrInvalidDomain } @@ -91,7 +91,7 @@ func (t *DomainTrie) insert(parts []string, data interface{}) { // 2. wildcard domain // 2. dot wildcard domain func (t *DomainTrie) Search(domain string) *Node { - parts, valid := validAndSplitDomain(domain) + parts, valid := ValidAndSplitDomain(domain) if !valid || parts[0] == "" { return nil } diff --git a/config/config.go b/config/config.go index 51faac91..90331649 100644 --- a/config/config.go +++ b/config/config.go @@ -64,6 +64,7 @@ type DNS struct { DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` FakeIPRange *fakeip.Pool Hosts *trie.DomainTrie + NameServerPolicy map[string]dns.NameServer } // FallbackFilter config @@ -106,6 +107,7 @@ type RawDNS struct { FakeIPRange string `yaml:"fake-ip-range"` FakeIPFilter []string `yaml:"fake-ip-filter"` DefaultNameserver []string `yaml:"default-nameserver"` + NameServerPolicy map[string]string `yaml:"nameserver-policy"` } type RawFallbackFilter struct { @@ -500,6 +502,23 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) { return nameservers, nil } +func parseNameServerPolicy(nsPolicy map[string]string) (map[string]dns.NameServer, error) { + policy := map[string]dns.NameServer{} + + for domain, server := range nsPolicy { + nameservers, err := parseNameServer([]string{server}) + if err != nil { + return nil, err + } + if _, valid := trie.ValidAndSplitDomain(domain); !valid { + return nil, fmt.Errorf("DNS ResoverRule invalid domain: %s", domain) + } + policy[domain] = nameservers[0] + } + + return policy, nil +} + func parseFallbackIPCIDR(ips []string) ([]*net.IPNet, error) { ipNets := []*net.IPNet{} @@ -537,6 +556,10 @@ func parseDNS(cfg RawDNS, hosts *trie.DomainTrie) (*DNS, error) { return nil, err } + if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy); err != nil { + return nil, err + } + if len(cfg.DefaultNameserver) == 0 { return nil, errors.New("default nameserver should have at least one nameserver") } diff --git a/dns/resolver.go b/dns/resolver.go index 0db6855d..f57fec52 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -43,6 +43,7 @@ type Resolver struct { fallbackIPFilters []fallbackIPFilter group singleflight.Group lruCache *cache.LruCache + policy *trie.DomainTrie } // ResolveIP request with TypeA and TypeAAAA, priority return TypeA @@ -131,6 +132,9 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) { return r.ipExchange(m) } + if matched := r.matchPolicy(m); len(matched) != 0 { + return r.batchExchange(matched, m) + } return r.batchExchange(r.main, m) }) @@ -172,6 +176,24 @@ func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err return } +func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { + if r.policy == nil { + return nil + } + + domain := r.msgToDomain(m) + if domain == "" { + return nil + } + + record := r.policy.Search(domain) + if record == nil { + return nil + } + + return record.Data.([]dnsClient) +} + func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { if r.fallback == nil || len(r.fallbackDomainFilters) == 0 { return false @@ -194,6 +216,11 @@ func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) { + if matched := r.matchPolicy(m); len(matched) != 0 { + res := <-r.asyncExchange(matched, m) + return res.Msg, res.Error + } + onlyFallback := r.shouldOnlyQueryFallback(m) if onlyFallback { @@ -293,6 +320,7 @@ type Config struct { FallbackFilter FallbackFilter Pool *fakeip.Pool Hosts *trie.DomainTrie + Policy map[string]NameServer } func NewResolver(config Config) *Resolver { @@ -312,6 +340,13 @@ func NewResolver(config Config) *Resolver { r.fallback = transform(config.Fallback, defaultResolver) } + if len(config.Policy) != 0 { + r.policy = trie.New() + for domain, nameserver := range config.Policy { + r.policy.Insert(domain, transform([]NameServer{nameserver}, defaultResolver)) + } + } + fallbackIPFilters := []fallbackIPFilter{} if config.FallbackFilter.GeoIP { fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{}) diff --git a/hub/executor/executor.go b/hub/executor/executor.go index bea2a1ab..caf75aff 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -128,6 +128,7 @@ func updateDNS(c *config.DNS) { Domain: c.FallbackFilter.Domain, }, Default: c.DefaultNameserver, + Policy: c.NameServerPolicy, } r := dns.NewResolver(cfg)