From 57fdd223f1079af84582c6152ba108d366a6a161 Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Sat, 29 Jun 2019 00:58:59 +0800 Subject: [PATCH] Feature: custom dns ipv4/ipv6 dual stack --- adapters/outbound/direct.go | 4 ++-- adapters/outbound/http.go | 2 +- adapters/outbound/shadowsocks.go | 4 ++-- adapters/outbound/socks5.go | 4 ++-- adapters/outbound/util.go | 28 ++++++++++++++++++++++++++++ adapters/outbound/vmess.go | 4 ++-- config/config.go | 1 + dns/iputil.go | 32 ++++++++++++++++++++++++++++++++ dns/resolver.go | 18 ++++++++++++++---- hub/executor/executor.go | 9 ++++++--- tunnel/tunnel.go | 27 ++++----------------------- 11 files changed, 94 insertions(+), 39 deletions(-) create mode 100644 dns/iputil.go diff --git a/adapters/outbound/direct.go b/adapters/outbound/direct.go index 563f7074..491c170e 100644 --- a/adapters/outbound/direct.go +++ b/adapters/outbound/direct.go @@ -16,7 +16,7 @@ func (d *Direct) Dial(metadata *C.Metadata) (net.Conn, error) { address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort) } - c, err := net.DialTimeout("tcp", address, tcpTimeout) + c, err := dialTimeout("tcp", address, tcpTimeout) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (d *Direct) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) return nil, nil, err } - addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) + addr, err := resolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) if err != nil { return nil, nil, err } diff --git a/adapters/outbound/http.go b/adapters/outbound/http.go index 58f335dc..a617de5b 100644 --- a/adapters/outbound/http.go +++ b/adapters/outbound/http.go @@ -36,7 +36,7 @@ type HttpOption struct { } func (h *Http) Dial(metadata *C.Metadata) (net.Conn, error) { - c, err := net.DialTimeout("tcp", h.addr, tcpTimeout) + c, err := dialTimeout("tcp", h.addr, tcpTimeout) if err == nil && h.tls { cc := tls.Client(c, h.tlsConfig) err = cc.Handshake() diff --git a/adapters/outbound/shadowsocks.go b/adapters/outbound/shadowsocks.go index 7d422723..b05f5830 100644 --- a/adapters/outbound/shadowsocks.go +++ b/adapters/outbound/shadowsocks.go @@ -58,7 +58,7 @@ type v2rayObfsOption struct { } func (ss *ShadowSocks) Dial(metadata *C.Metadata) (net.Conn, error) { - c, err := net.DialTimeout("tcp", ss.server, tcpTimeout) + c, err := dialTimeout("tcp", ss.server, tcpTimeout) if err != nil { return nil, fmt.Errorf("%s connect error: %s", ss.server, err.Error()) } @@ -87,7 +87,7 @@ func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, return nil, nil, err } - addr, err := net.ResolveUDPAddr("udp", ss.server) + addr, err := resolveUDPAddr("udp", ss.server) if err != nil { return nil, nil, err } diff --git a/adapters/outbound/socks5.go b/adapters/outbound/socks5.go index 7f52b58e..6ec611c6 100644 --- a/adapters/outbound/socks5.go +++ b/adapters/outbound/socks5.go @@ -32,7 +32,7 @@ type Socks5Option struct { } func (ss *Socks5) Dial(metadata *C.Metadata) (net.Conn, error) { - c, err := net.DialTimeout("tcp", ss.addr, tcpTimeout) + c, err := dialTimeout("tcp", ss.addr, tcpTimeout) if err == nil && ss.tls { cc := tls.Client(c, ss.tlsConfig) @@ -58,7 +58,7 @@ func (ss *Socks5) Dial(metadata *C.Metadata) (net.Conn, error) { } func (ss *Socks5) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { - c, err := net.DialTimeout("tcp", ss.addr, tcpTimeout) + c, err := dialTimeout("tcp", ss.addr, tcpTimeout) if err == nil && ss.tls { cc := tls.Client(c, ss.tlsConfig) diff --git a/adapters/outbound/util.go b/adapters/outbound/util.go index d7c134b3..0cccf5d9 100644 --- a/adapters/outbound/util.go +++ b/adapters/outbound/util.go @@ -12,6 +12,7 @@ import ( "github.com/Dreamacro/clash/component/socks5" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/dns" ) const ( @@ -96,3 +97,30 @@ func (fuc *fakeUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { n, err := fuc.Conn.Read(b) return n, fuc.RemoteAddr(), err } + +func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + ip, err := dns.ResolveIP(host) + if err != nil { + return nil, err + } + + return net.DialTimeout(network, net.JoinHostPort(ip.String(), port), timeout) +} + +func resolveUDPAddr(network, address string) (*net.UDPAddr, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + ip, err := dns.ResolveIP(host) + if err != nil { + return nil, err + } + return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port)) +} diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index e1113bd1..7d74936d 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -32,7 +32,7 @@ type VmessOption struct { } func (v *Vmess) Dial(metadata *C.Metadata) (net.Conn, error) { - c, err := net.DialTimeout("tcp", v.server, tcpTimeout) + c, err := dialTimeout("tcp", v.server, tcpTimeout) if err != nil { return nil, fmt.Errorf("%s connect error", v.server) } @@ -42,7 +42,7 @@ func (v *Vmess) Dial(metadata *C.Metadata) (net.Conn, error) { } func (v *Vmess) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { - c, err := net.DialTimeout("tcp", v.server, tcpTimeout) + c, err := dialTimeout("tcp", v.server, tcpTimeout) if err != nil { return nil, nil, fmt.Errorf("%s connect error", v.server) } diff --git a/config/config.go b/config/config.go index 598ac834..fdf55db0 100644 --- a/config/config.go +++ b/config/config.go @@ -506,6 +506,7 @@ func parseDNS(cfg rawDNS) (*DNS, error) { dnsCfg := &DNS{ Enable: cfg.Enable, Listen: cfg.Listen, + IPv6: cfg.IPv6, EnhancedMode: cfg.EnhancedMode, } var err error diff --git a/dns/iputil.go b/dns/iputil.go new file mode 100644 index 00000000..9d8c7048 --- /dev/null +++ b/dns/iputil.go @@ -0,0 +1,32 @@ +package dns + +import ( + "errors" + "net" +) + +var ( + errIPNotFound = errors.New("ip not found") +) + +// ResolveIP with a host, return ip +func ResolveIP(host string) (net.IP, error) { + if DefaultResolver != nil { + if DefaultResolver.ipv6 { + return DefaultResolver.ResolveIP(host) + } + return DefaultResolver.ResolveIPv4(host) + } + + ip := net.ParseIP(host) + if ip != nil { + return ip, nil + } + + ipAddr, err := net.ResolveIPAddr("ip", host) + if err != nil { + return nil, err + } + + return ipAddr.IP, nil +} diff --git a/dns/resolver.go b/dns/resolver.go index bb0710d2..7821e570 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -18,6 +18,11 @@ import ( geoip2 "github.com/oschwald/geoip2-golang" ) +var ( + // DefaultResolver aim to resolve ip with host + DefaultResolver *Resolver +) + var ( globalSessionCache = tls.NewLRUClientSessionCache(64) @@ -47,11 +52,16 @@ type Resolver struct { // ResolveIP request with TypeA and TypeAAAA, priority return TypeAAAA func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { + ip = net.ParseIP(host) + if ip != nil { + return ip, nil + } + ch := make(chan net.IP) go func() { + defer close(ch) ip, err := r.resolveIP(host, D.TypeA) if err != nil { - close(ch) return } ch <- ip @@ -65,8 +75,8 @@ func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { return } - ip, closed := <-ch - if closed { + ip, open := <-ch + if !open { return nil, errors.New("can't found ip") } @@ -275,8 +285,8 @@ func New(config Config) *Resolver { }) r := &Resolver{ - main: transform(config.Main), ipv6: config.IPv6, + main: transform(config.Main), cache: cache.New(time.Second * 60), mapping: config.EnhancedMode == MAPPING, fakeip: config.EnhancedMode == FAKEIP, diff --git a/hub/executor/executor.go b/hub/executor/executor.go index a0bd1c86..f295108f 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -59,7 +59,7 @@ func updateExperimental(c *config.Experimental) { func updateDNS(c *config.DNS) { if c.Enable == false { - T.Instance().SetResolver(nil) + dns.DefaultResolver = nil dns.ReCreateServer("", nil) return } @@ -70,12 +70,15 @@ func updateDNS(c *config.DNS) { EnhancedMode: c.EnhancedMode, Pool: c.FakeIPRange, }) - T.Instance().SetResolver(r) + dns.DefaultResolver = r if err := dns.ReCreateServer(c.Listen, r); err != nil { log.Errorln("Start DNS server error: %s", err.Error()) return } - log.Infoln("DNS server listening at: %s", c.Listen) + + if c.Listen != "" { + log.Infoln("DNS server listening at: %s", c.Listen) + } } func updateProxies(proxies map[string]C.Proxy) { diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index faeaf46d..fedde7f2 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -26,7 +26,6 @@ type Tunnel struct { proxies map[string]C.Proxy configMux *sync.RWMutex traffic *C.Traffic - resolver *dns.Resolver // experimental features ignoreResolveFail bool @@ -86,15 +85,6 @@ func (t *Tunnel) SetMode(mode Mode) { t.mode = mode } -// SetResolver change the resolver of tunnel -func (t *Tunnel) SetResolver(resolver *dns.Resolver) { - t.resolver = resolver -} - -func (t *Tunnel) hasResolver() bool { - return t.resolver != nil -} - func (t *Tunnel) process() { queue := t.queue.Out() for { @@ -105,20 +95,11 @@ func (t *Tunnel) process() { } func (t *Tunnel) resolveIP(host string) (net.IP, error) { - if t.resolver == nil { - ipAddr, err := net.ResolveIPAddr("ip", host) - if err != nil { - return nil, err - } - - return ipAddr.IP, nil - } - - return t.resolver.ResolveIP(host) + return dns.ResolveIP(host) } func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool { - return t.hasResolver() && (t.resolver.IsMapping() || t.resolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil + return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil } func (t *Tunnel) handleConn(localConn C.ServerAdapter) { @@ -132,11 +113,11 @@ func (t *Tunnel) handleConn(localConn C.ServerAdapter) { // preprocess enhanced-mode metadata if t.needLookupIP(metadata) { - host, exist := t.resolver.IPToHost(*metadata.DstIP) + host, exist := dns.DefaultResolver.IPToHost(*metadata.DstIP) if exist { metadata.Host = host metadata.AddrType = C.AtypDomainName - if t.resolver.IsFakeIP() { + if dns.DefaultResolver.IsFakeIP() { metadata.DstIP = nil } }