From 1a21c8ebfdf6e35f47158eed90993b9e46ed42aa Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Sun, 14 Jul 2019 19:29:58 +0800 Subject: [PATCH] Feature(dns): support custom hosts --- README.md | 6 ++ component/domain-trie/node.go | 26 +++++++ component/domain-trie/tire.go | 84 ++++++++++++++++++++ component/domain-trie/trie_test.go | 69 ++++++++++++++++ config/config.go | 30 +++++-- dns/middleware.go | 121 +++++++++++++++++++++++++++++ dns/resolver.go | 23 ++++-- dns/server.go | 76 +++--------------- go.mod | 2 +- hub/executor/executor.go | 1 + 10 files changed, 359 insertions(+), 79 deletions(-) create mode 100644 component/domain-trie/node.go create mode 100644 component/domain-trie/tire.go create mode 100644 component/domain-trie/trie_test.go create mode 100644 dns/middleware.go diff --git a/README.md b/README.md index a031ecea..f0d4f2d7 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,12 @@ experimental: # listen: 0.0.0.0:53 # enhanced-mode: redir-host # or fake-ip # # fake-ip-range: 198.18.0.1/16 # if you don't know what it is, don't change it + # # experimental hosts, support wildcard (e.g. *.clash.dev Even *.foo.*.example.com) + # # static domain has a higher priority than wildcard domain (foo.example.com > *.example.com) + # # NOTE: hosts don't work with `fake-ip` + # hosts: + # '*.clash.dev': 127.0.0.1 + # 'alpha.clash.dev': '::1' # nameserver: # - 114.114.114.114 # - tls://dns.rubyfish.cn:853 # dns over tls diff --git a/component/domain-trie/node.go b/component/domain-trie/node.go new file mode 100644 index 00000000..be8ba91f --- /dev/null +++ b/component/domain-trie/node.go @@ -0,0 +1,26 @@ +package trie + +// Node is the trie's node +type Node struct { + Data interface{} + children map[string]*Node +} + +func (n *Node) getChild(s string) *Node { + return n.children[s] +} + +func (n *Node) hasChild(s string) bool { + return n.getChild(s) != nil +} + +func (n *Node) addChild(s string, child *Node) { + n.children[s] = child +} + +func newNode(data interface{}) *Node { + return &Node{ + Data: data, + children: map[string]*Node{}, + } +} diff --git a/component/domain-trie/tire.go b/component/domain-trie/tire.go new file mode 100644 index 00000000..8d570214 --- /dev/null +++ b/component/domain-trie/tire.go @@ -0,0 +1,84 @@ +package trie + +import ( + "errors" + "strings" +) + +const ( + wildcard = "*" + domainStep = "." +) + +var ( + // ErrInvalidDomain means insert domain is invalid + ErrInvalidDomain = errors.New("invalid domain") +) + +// Trie contains the main logic for adding and searching nodes for domain segments. +// support wildcard domain (e.g *.google.com) +type Trie struct { + root *Node +} + +// Insert adds a node to the trie. +// Support +// 1. www.example.com +// 2. *.example.com +// 3. subdomain.*.example.com +func (t *Trie) Insert(domain string, data interface{}) error { + parts := strings.Split(domain, domainStep) + if len(parts) < 2 { + return ErrInvalidDomain + } + + node := t.root + // reverse storage domain part to save space + for i := len(parts) - 1; i >= 0; i-- { + part := parts[i] + if !node.hasChild(part) { + node.addChild(part, newNode(nil)) + } + + node = node.getChild(part) + } + + node.Data = data + return nil +} + +// Search is the most important part of the Trie. +// Priority as: +// 1. static part +// 2. wildcard domain +func (t *Trie) Search(domain string) *Node { + parts := strings.Split(domain, domainStep) + if len(parts) < 2 { + return nil + } + + n := t.root + for i := len(parts) - 1; i >= 0; i-- { + part := parts[i] + + var child *Node + if !n.hasChild(part) { + if !n.hasChild(wildcard) { + return nil + } + + child = n.getChild(wildcard) + } else { + child = n.getChild(part) + } + + n = child + } + + return n +} + +// New returns a new, empty Trie. +func New() *Trie { + return &Trie{root: newNode(nil)} +} diff --git a/component/domain-trie/trie_test.go b/component/domain-trie/trie_test.go new file mode 100644 index 00000000..cd80ce3d --- /dev/null +++ b/component/domain-trie/trie_test.go @@ -0,0 +1,69 @@ +package trie + +import ( + "net" + "testing" +) + +func TestTrie_Basic(t *testing.T) { + tree := New() + domains := []string{ + "example.com", + "google.com", + } + + for _, domain := range domains { + tree.Insert(domain, net.ParseIP("127.0.0.1")) + } + + node := tree.Search("example.com") + if node == nil { + t.Error("should not recv nil") + } + + if !node.Data.(net.IP).Equal(net.IP{127, 0, 0, 1}) { + t.Error("should equal 127.0.0.1") + } +} + +func TestTrie_Wildcard(t *testing.T) { + tree := New() + domains := []string{ + "*.example.com", + "sub.*.example.com", + "*.dev", + } + + for _, domain := range domains { + tree.Insert(domain, nil) + } + + if tree.Search("sub.example.com") == nil { + t.Error("should not recv nil") + } + + if tree.Search("sub.foo.example.com") == nil { + t.Error("should not recv nil") + } + + if tree.Search("foo.sub.example.com") != nil { + t.Error("should recv nil") + } + + if tree.Search("foo.example.dev") != nil { + t.Error("should recv nil") + } +} + +func TestTrie_Boundary(t *testing.T) { + tree := New() + tree.Insert("*.dev", nil) + + if err := tree.Insert("com", nil); err == nil { + t.Error("should recv err") + } + + if tree.Search("dev") != nil { + t.Error("should recv nil") + } +} diff --git a/config/config.go b/config/config.go index 2bbaf9bd..d11e243c 100644 --- a/config/config.go +++ b/config/config.go @@ -12,6 +12,7 @@ import ( adapters "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/common/structure" "github.com/Dreamacro/clash/component/auth" + trie "github.com/Dreamacro/clash/component/domain-trie" "github.com/Dreamacro/clash/component/fakeip" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/dns" @@ -42,6 +43,7 @@ type DNS struct { IPv6 bool `yaml:"ipv6"` NameServer []dns.NameServer `yaml:"nameserver"` Fallback []dns.NameServer `yaml:"fallback"` + Hosts *trie.Trie `yaml:"-"` Listen string `yaml:"listen"` EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` FakeIPRange *fakeip.Pool @@ -63,13 +65,14 @@ type Config struct { } type rawDNS struct { - Enable bool `yaml:"enable"` - IPv6 bool `yaml:"ipv6"` - NameServer []string `yaml:"nameserver"` - Fallback []string `yaml:"fallback"` - Listen string `yaml:"listen"` - EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` - FakeIPRange string `yaml:"fake-ip-range"` + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + NameServer []string `yaml:"nameserver"` + Hosts map[string]string `yaml:"hosts"` + Fallback []string `yaml:"fallback"` + Listen string `yaml:"listen"` + EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` + FakeIPRange string `yaml:"fake-ip-range"` } type rawConfig struct { @@ -134,6 +137,7 @@ func readConfig(path string) (*rawConfig, error) { DNS: rawDNS{ Enable: false, FakeIPRange: "198.18.0.1/16", + Hosts: map[string]string{}, }, } err = yaml.Unmarshal([]byte(data), &rawConfig) @@ -518,6 +522,18 @@ func parseDNS(cfg rawDNS) (*DNS, error) { return nil, err } + if len(cfg.Hosts) != 0 { + tree := trie.New() + for domain, ipStr := range cfg.Hosts { + ip := net.ParseIP(ipStr) + if ip == nil { + return nil, fmt.Errorf("%s is not a valid IP", ipStr) + } + tree.Insert(domain, ip) + } + dnsCfg.Hosts = tree + } + if cfg.EnhancedMode == dns.FAKEIP { _, ipnet, err := net.ParseCIDR(cfg.FakeIPRange) if err != nil { diff --git a/dns/middleware.go b/dns/middleware.go new file mode 100644 index 00000000..1b9198bd --- /dev/null +++ b/dns/middleware.go @@ -0,0 +1,121 @@ +package dns + +import ( + "fmt" + "net" + "strings" + + "github.com/Dreamacro/clash/common/cache" + "github.com/Dreamacro/clash/component/fakeip" + "github.com/Dreamacro/clash/log" + + D "github.com/miekg/dns" +) + +type handler func(w D.ResponseWriter, r *D.Msg) + +func withFakeIP(cache *cache.Cache, pool *fakeip.Pool) handler { + return func(w D.ResponseWriter, r *D.Msg) { + q := r.Question[0] + + cacheItem := cache.Get("fakeip:" + q.String()) + if cache != nil { + msg := cacheItem.(*D.Msg).Copy() + setMsgTTL(msg, 1) + msg.SetReply(r) + w.WriteMsg(msg) + return + } + + rr := &D.A{} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} + ip := pool.Get() + rr.A = ip + msg := r.Copy() + msg.Answer = []D.RR{rr} + putMsgToCache(cache, "fakeip:"+q.String(), msg) + putMsgToCache(cache, ip.String(), msg) + + setMsgTTL(msg, 1) + return + } +} + +func withResolver(resolver *Resolver) handler { + return func(w D.ResponseWriter, r *D.Msg) { + msg, err := resolver.Exchange(r) + + if err != nil { + q := r.Question[0] + qString := fmt.Sprintf("%s %s %s", q.Name, D.Class(q.Qclass).String(), D.Type(q.Qtype).String()) + log.Debugln("[DNS Server] Exchange %s failed: %v", qString, err) + D.HandleFailed(w, r) + return + } + msg.SetReply(r) + w.WriteMsg(msg) + return + } +} + +func withHost(resolver *Resolver, next handler) handler { + hosts := resolver.hosts + if hosts == nil { + panic("dns/withHost: hosts should not be nil") + } + + return func(w D.ResponseWriter, r *D.Msg) { + q := r.Question[0] + if q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA { + next(w, r) + return + } + + domain := strings.TrimRight(q.Name, ".") + host := hosts.Search(domain) + if host == nil { + next(w, r) + return + } + + ip := host.Data.(net.IP) + if q.Qtype == D.TypeAAAA && ip.To16() == nil { + next(w, r) + return + } else if q.Qtype == D.TypeA && ip.To4() == nil { + next(w, r) + return + } + + var rr D.RR + if q.Qtype == D.TypeAAAA { + record := &D.AAAA{} + record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL} + record.AAAA = ip + rr = record + } else { + record := &D.A{} + record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} + record.A = ip + rr = record + } + + msg := r.Copy() + msg.Answer = []D.RR{rr} + msg.SetReply(r) + w.WriteMsg(msg) + return + } +} + +func newHandler(resolver *Resolver) handler { + if resolver.IsFakeIP() { + return withFakeIP(resolver.cache, resolver.pool) + } + + if resolver.hosts != nil { + return withHost(resolver, withResolver(resolver)) + } + + return withResolver(resolver) +} diff --git a/dns/resolver.go b/dns/resolver.go index 044976f8..a41e68fd 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -11,11 +11,13 @@ import ( "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/picker" + trie "github.com/Dreamacro/clash/component/domain-trie" "github.com/Dreamacro/clash/component/fakeip" C "github.com/Dreamacro/clash/constant" D "github.com/miekg/dns" geoip2 "github.com/oschwald/geoip2-golang" + "golang.org/x/sync/singleflight" ) var ( @@ -44,9 +46,11 @@ type Resolver struct { ipv6 bool mapping bool fakeip bool + hosts *trie.Trie pool *fakeip.Pool fallback []resolver main []resolver + group singleflight.Group cache *cache.Cache } @@ -134,13 +138,20 @@ func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) { } }() - isIPReq := isIPRequest(q) - if isIPReq { - msg, err = r.fallbackExchange(m) - return + ret, err, _ := r.group.Do(q.String(), func() (interface{}, error) { + isIPReq := isIPRequest(q) + if isIPReq { + msg, err := r.fallbackExchange(m) + return msg, err + } + + return r.batchExchange(r.main, m) + }) + + if err == nil { + msg = ret.(*D.Msg) } - msg, err = r.batchExchange(r.main, m) return } @@ -266,6 +277,7 @@ type Config struct { Main, Fallback []NameServer IPv6 bool EnhancedMode EnhancedMode + Hosts *trie.Trie Pool *fakeip.Pool } @@ -280,6 +292,7 @@ func New(config Config) *Resolver { cache: cache.New(time.Second * 60), mapping: config.EnhancedMode == MAPPING, fakeip: config.EnhancedMode == FAKEIP, + hosts: config.Hosts, pool: config.Pool, } if len(config.Fallback) != 0 { diff --git a/dns/server.go b/dns/server.go index 21dcac09..9649587d 100644 --- a/dns/server.go +++ b/dns/server.go @@ -1,12 +1,8 @@ package dns import ( - "errors" - "fmt" "net" - "github.com/Dreamacro/clash/log" - "github.com/miekg/dns" D "github.com/miekg/dns" ) @@ -19,79 +15,26 @@ var ( type Server struct { *D.Server - r *Resolver + handler handler } func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) { - if s.r.IsFakeIP() { - msg, err := s.handleFakeIP(r) - if err != nil { - D.HandleFailed(w, r) - return - } - msg.SetReply(r) - w.WriteMsg(msg) - return - } - - msg, err := s.r.Exchange(r) - - if err != nil { - if len(r.Question) > 0 { - q := r.Question[0] - qString := fmt.Sprintf("%s %s %s", q.Name, D.Class(q.Qclass).String(), D.Type(q.Qtype).String()) - log.Debugln("[DNS Server] Exchange %s failed: %v", qString, err) - } + if len(r.Question) == 0 { D.HandleFailed(w, r) return } - msg.SetReply(r) - w.WriteMsg(msg) + + s.handler(w, r) } -func (s *Server) handleFakeIP(r *D.Msg) (msg *D.Msg, err error) { - if len(r.Question) == 0 { - err = errors.New("should have one question at least") - return - } - - q := r.Question[0] - - cache := s.r.cache.Get("fakeip:" + q.String()) - if cache != nil { - msg = cache.(*D.Msg).Copy() - setMsgTTL(msg, 1) - return - } - - var ip net.IP - defer func() { - if msg == nil { - return - } - - putMsgToCache(s.r.cache, "fakeip:"+q.String(), msg) - putMsgToCache(s.r.cache, ip.String(), msg) - - setMsgTTL(msg, 1) - }() - - rr := &D.A{} - rr.Hdr = dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: dnsDefaultTTL} - ip = s.r.pool.Get() - rr.A = ip - msg = r.Copy() - msg.Answer = []D.RR{rr} - return -} - -func (s *Server) setReslover(r *Resolver) { - s.r = r +func (s *Server) setHandler(handler handler) { + s.handler = handler } func ReCreateServer(addr string, resolver *Resolver) error { if addr == address { - server.setReslover(resolver) + handler := newHandler(resolver) + server.setHandler(handler) return nil } @@ -116,7 +59,8 @@ func ReCreateServer(addr string, resolver *Resolver) error { } address = addr - server = &Server{r: resolver} + handler := newHandler(resolver) + server = &Server{handler: handler} server.Server = &D.Server{Addr: addr, PacketConn: p, Handler: server} go func() { diff --git a/go.mod b/go.mod index c64d6c20..23fe1c71 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/sirupsen/logrus v1.4.2 golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 - golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect + golang.org/x/sync v0.0.0-20190423024810-112230192c58 gopkg.in/eapache/channels.v1 v1.1.0 gopkg.in/yaml.v2 v2.2.2 ) diff --git a/hub/executor/executor.go b/hub/executor/executor.go index f295108f..82d5f6e0 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -67,6 +67,7 @@ func updateDNS(c *config.DNS) { Main: c.NameServer, Fallback: c.Fallback, IPv6: c.IPv6, + Hosts: c.Hosts, EnhancedMode: c.EnhancedMode, Pool: c.FakeIPRange, })