From 4994510c87b364f1198772ce41a8d7313d9b80b0 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Wed, 11 Sep 2019 17:00:55 +0800 Subject: [PATCH] Feature: move hosts to the top --- config/config.go | 53 ++++++++++++-------- dns/iputil.go | 66 ++++++++++++++++++++++++ dns/middleware.go | 105 +++++++++++++-------------------------- dns/resolver.go | 48 ++++++++---------- hub/executor/executor.go | 7 ++- tunnel/tunnel.go | 7 +++ 6 files changed, 164 insertions(+), 122 deletions(-) diff --git a/config/config.go b/config/config.go index e1baac8f..6d0942dd 100644 --- a/config/config.go +++ b/config/config.go @@ -44,7 +44,6 @@ type DNS struct { IPv6 bool `yaml:"ipv6"` NameServer []dns.NameServer `yaml:"nameserver"` Fallback []dns.NameServer `yaml:"fallback"` - Hosts *trie.Trie `yaml:"-"` Listen string `yaml:"listen"` EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` FakeIPRange *fakeip.Pool @@ -60,20 +59,20 @@ type Config struct { General *General DNS *DNS Experimental *Experimental + Hosts *trie.Trie Rules []C.Rule Users []auth.AuthUser Proxies map[string]C.Proxy } type rawDNS struct { - Enable bool `yaml:"enable"` - IPv6 bool `yaml:"ipv6"` - NameServer []string `yaml:"nameserver"` - Hosts map[string]string `yaml:"hosts"` - Fallback []string `yaml:"fallback"` - Listen string `yaml:"listen"` - EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` - FakeIPRange string `yaml:"fake-ip-range"` + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + NameServer []string `yaml:"nameserver"` + Fallback []string `yaml:"fallback"` + Listen string `yaml:"listen"` + EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` + FakeIPRange string `yaml:"fake-ip-range"` } type rawConfig struct { @@ -89,6 +88,7 @@ type rawConfig struct { ExternalUI string `yaml:"external-ui"` Secret string `yaml:"secret"` + Hosts map[string]string `yaml:"hosts"` DNS rawDNS `yaml:"dns"` Experimental Experimental `yaml:"experimental"` Proxy []map[string]interface{} `yaml:"Proxy"` @@ -135,6 +135,7 @@ func readConfig(path string) (*rawConfig, error) { Mode: T.Rule, Authentication: []string{}, LogLevel: log.INFO, + Hosts: map[string]string{}, Rule: []string{}, Proxy: []map[string]interface{}{}, ProxyGroup: []map[string]interface{}{}, @@ -144,7 +145,6 @@ func readConfig(path string) (*rawConfig, error) { DNS: rawDNS{ Enable: false, FakeIPRange: "198.18.0.1/16", - Hosts: map[string]string{}, }, } err = yaml.Unmarshal([]byte(data), &rawConfig) @@ -185,6 +185,12 @@ func Parse(path string) (*Config, error) { } config.DNS = dnsCfg + hosts, err := parseHosts(rawCfg) + if err != nil { + return nil, err + } + config.Hosts = hosts + config.Users = parseAuthentication(rawCfg.Authentication) return config, nil @@ -460,6 +466,21 @@ func parseRules(cfg *rawConfig, proxies map[string]C.Proxy) ([]C.Rule, error) { return rules, nil } +func parseHosts(cfg *rawConfig) (*trie.Trie, error) { + tree := trie.New() + if len(cfg.Hosts) != 0 { + for domain, ipStr := range cfg.Hosts { + ip := net.ParseIP(ipStr) + if ip == nil { + return nil, fmt.Errorf("%s is not a valid IP", ipStr) + } + tree.Insert(domain, ip) + } + } + + return tree, nil +} + func hostWithDefaultPort(host string, defPort string) (string, error) { if !strings.Contains(host, ":") { host += ":" @@ -544,18 +565,6 @@ func parseDNS(cfg rawDNS) (*DNS, error) { return nil, err } - if len(cfg.Hosts) != 0 { - tree := trie.New() - for domain, ipStr := range cfg.Hosts { - ip := net.ParseIP(ipStr) - if ip == nil { - return nil, fmt.Errorf("%s is not a valid IP", ipStr) - } - tree.Insert(domain, ip) - } - dnsCfg.Hosts = tree - } - if cfg.EnhancedMode == dns.FAKEIP { _, ipnet, err := net.ParseCIDR(cfg.FakeIPRange) if err != nil { diff --git a/dns/iputil.go b/dns/iputil.go index 106c25d8..66dd0c0e 100644 --- a/dns/iputil.go +++ b/dns/iputil.go @@ -9,8 +9,74 @@ var ( errIPNotFound = errors.New("cannot found ip") ) +// ResolveIPv4 with a host, return ipv4 +func ResolveIPv4(host string) (net.IP, error) { + if node := DefaultHosts.Search(host); node != nil { + if ip := node.Data.(net.IP).To4(); ip != nil { + return ip, nil + } + } + + ip := net.ParseIP(host) + if ip4 := ip.To4(); ip4 != nil { + return ip4, nil + } + + if DefaultResolver != nil { + return DefaultResolver.ResolveIPv4(host) + } + + ipAddrs, err := net.LookupIP(host) + if err != nil { + return nil, err + } + + for _, ip := range ipAddrs { + if ip4 := ip.To4(); ip4 != nil { + return ip4, nil + } + } + + return nil, errIPNotFound +} + +// ResolveIPv6 with a host, return ipv6 +func ResolveIPv6(host string) (net.IP, error) { + if node := DefaultHosts.Search(host); node != nil { + if ip := node.Data.(net.IP).To16(); ip != nil { + return ip, nil + } + } + + ip := net.ParseIP(host) + if ip6 := ip.To16(); ip6 != nil { + return ip6, nil + } + + if DefaultResolver != nil { + return DefaultResolver.ResolveIPv6(host) + } + + ipAddrs, err := net.LookupIP(host) + if err != nil { + return nil, err + } + + for _, ip := range ipAddrs { + if ip6 := ip.To16(); ip6 != nil { + return ip6, nil + } + } + + return nil, errIPNotFound +} + // ResolveIP with a host, return ip func ResolveIP(host string) (net.IP, error) { + if node := DefaultHosts.Search(host); node != nil { + return node.Data.(net.IP), nil + } + if DefaultResolver != nil { if DefaultResolver.ipv6 { return DefaultResolver.ResolveIP(host) diff --git a/dns/middleware.go b/dns/middleware.go index 4c7e1ec2..6d30f8d6 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -1,8 +1,6 @@ package dns import ( - "fmt" - "net" "strings" "github.com/Dreamacro/clash/component/fakeip" @@ -12,34 +10,40 @@ import ( ) type handler func(w D.ResponseWriter, r *D.Msg) +type middleware func(next handler) handler -func withFakeIP(pool *fakeip.Pool) handler { - return func(w D.ResponseWriter, r *D.Msg) { - q := r.Question[0] - host := strings.TrimRight(q.Name, ".") +func withFakeIP(fakePool *fakeip.Pool) middleware { + return func(next handler) handler { + return func(w D.ResponseWriter, r *D.Msg) { + q := r.Question[0] + if q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA { + next(w, r) + return + } - rr := &D.A{} - rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} - ip := pool.Lookup(host) - rr.A = ip - msg := r.Copy() - msg.Answer = []D.RR{rr} + host := strings.TrimRight(q.Name, ".") - setMsgTTL(msg, 1) - msg.SetReply(r) - w.WriteMsg(msg) - return + rr := &D.A{} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} + ip := fakePool.Lookup(host) + rr.A = ip + msg := r.Copy() + msg.Answer = []D.RR{rr} + + setMsgTTL(msg, 1) + msg.SetReply(r) + w.WriteMsg(msg) + return + } } } func withResolver(resolver *Resolver) handler { return func(w D.ResponseWriter, r *D.Msg) { msg, err := resolver.Exchange(r) - if err != nil { q := r.Question[0] - qString := fmt.Sprintf("%s %s %s", q.Name, D.Class(q.Qclass).String(), D.Type(q.Qtype).String()) - log.Debugln("[DNS Server] Exchange %s failed: %v", qString, err) + log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err) D.HandleFailed(w, r) return } @@ -49,64 +53,23 @@ func withResolver(resolver *Resolver) handler { } } -func withHost(resolver *Resolver, next handler) handler { - hosts := resolver.hosts - if hosts == nil { - panic("dns/withHost: hosts should not be nil") +func compose(middlewares []middleware, endpoint handler) handler { + length := len(middlewares) + h := endpoint + for i := length - 1; i >= 0; i-- { + middleware := middlewares[i] + h = middleware(h) } - return func(w D.ResponseWriter, r *D.Msg) { - q := r.Question[0] - if q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA { - next(w, r) - return - } - - domain := strings.TrimRight(q.Name, ".") - host := hosts.Search(domain) - if host == nil { - next(w, r) - return - } - - ip := host.Data.(net.IP) - if q.Qtype == D.TypeAAAA && ip.To16() == nil { - next(w, r) - return - } else if q.Qtype == D.TypeA && ip.To4() == nil { - next(w, r) - return - } - - var rr D.RR - if q.Qtype == D.TypeAAAA { - record := &D.AAAA{} - record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL} - record.AAAA = ip - rr = record - } else { - record := &D.A{} - record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} - record.A = ip - rr = record - } - - msg := r.Copy() - msg.Answer = []D.RR{rr} - msg.SetReply(r) - w.WriteMsg(msg) - return - } + return h } func newHandler(resolver *Resolver) handler { + middlewares := []middleware{} + if resolver.IsFakeIP() { - return withFakeIP(resolver.pool) + middlewares = append(middlewares, withFakeIP(resolver.pool)) } - if resolver.hosts != nil { - return withHost(resolver, withResolver(resolver)) - } - - return withResolver(resolver) + return compose(middlewares, withResolver(resolver)) } diff --git a/dns/resolver.go b/dns/resolver.go index 4e66f17c..6ed65764 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -21,8 +21,11 @@ import ( ) var ( - // DefaultResolver aim to resolve ip with host + // DefaultResolver aim to resolve ip DefaultResolver *Resolver + + // DefaultHosts aim to resolve hosts + DefaultHosts = trie.New() ) var ( @@ -46,7 +49,6 @@ type Resolver struct { ipv6 bool mapping bool fakeip bool - hosts *trie.Trie pool *fakeip.Pool fallback []resolver main []resolver @@ -56,11 +58,6 @@ type Resolver struct { // ResolveIP request with TypeA and TypeAAAA, priority return TypeAAAA func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { - ip = net.ParseIP(host) - if ip != nil { - return ip, nil - } - ch := make(chan net.IP) go func() { defer close(ch) @@ -89,26 +86,12 @@ func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { // ResolveIPv4 request with TypeA func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) { - ip = net.ParseIP(host) - if ip != nil { - return ip, nil - } + return r.resolveIP(host, D.TypeA) +} - query := &D.Msg{} - query.SetQuestion(D.Fqdn(host), D.TypeA) - - msg, err := r.Exchange(query) - if err != nil { - return nil, err - } - - ips := r.msgToIP(msg) - if len(ips) == 0 { - return nil, errIPNotFound - } - - ip = ips[0] - return +// ResolveIPv6 request with TypeAAAA +func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) { + return r.resolveIP(host, D.TypeAAAA) } // Exchange a batch of dns request, and it use cache @@ -232,6 +215,17 @@ func (r *Resolver) fallbackExchange(m *D.Msg) (msg *D.Msg, err error) { } func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) { + ip = net.ParseIP(host) + if dnsType == D.TypeAAAA { + if ip6 := ip.To16(); ip6 != nil { + return ip6, nil + } + } else { + if ip4 := ip.To4(); ip4 != nil { + return ip4, nil + } + } + query := &D.Msg{} query.SetQuestion(D.Fqdn(host), dnsType) @@ -282,7 +276,6 @@ type Config struct { Main, Fallback []NameServer IPv6 bool EnhancedMode EnhancedMode - Hosts *trie.Trie Pool *fakeip.Pool } @@ -297,7 +290,6 @@ func New(config Config) *Resolver { cache: cache.New(time.Second * 60), mapping: config.EnhancedMode == MAPPING, fakeip: config.EnhancedMode == FAKEIP, - hosts: config.Hosts, pool: config.Pool, } if len(config.Fallback) != 0 { diff --git a/hub/executor/executor.go b/hub/executor/executor.go index b00ae2d4..2883d5fc 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -2,6 +2,7 @@ package executor import ( "github.com/Dreamacro/clash/component/auth" + trie "github.com/Dreamacro/clash/component/domain-trie" "github.com/Dreamacro/clash/config" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/dns" @@ -30,6 +31,7 @@ func ApplyConfig(cfg *config.Config, force bool) { updateProxies(cfg.Proxies) updateRules(cfg.Rules) updateDNS(cfg.DNS) + updateHosts(cfg.Hosts) updateExperimental(cfg.Experimental) } @@ -68,7 +70,6 @@ func updateDNS(c *config.DNS) { Main: c.NameServer, Fallback: c.Fallback, IPv6: c.IPv6, - Hosts: c.Hosts, EnhancedMode: c.EnhancedMode, Pool: c.FakeIPRange, }) @@ -83,6 +84,10 @@ func updateDNS(c *config.DNS) { } } +func updateHosts(tree *trie.Trie) { + dns.DefaultHosts = tree +} + func updateProxies(proxies map[string]C.Proxy) { tunnel := T.Instance() oldProxies := tunnel.Proxies() diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 837374ad..87aafec3 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -213,6 +213,13 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { defer t.configMux.RUnlock() var resolved bool + + if node := dns.DefaultHosts.Search(metadata.Host); node != nil { + ip := node.Data.(net.IP) + metadata.DstIP = &ip + resolved = true + } + for _, rule := range t.rules { if !resolved && t.shouldResolveIP(rule, metadata) { ip, err := t.resolveIP(metadata.Host)