Feature: custom dns ipv4/ipv6 dual stack

This commit is contained in:
Dreamacro 2019-06-29 00:58:59 +08:00
parent bc3fc0c840
commit 57fdd223f1
11 changed files with 94 additions and 39 deletions

View file

@ -16,7 +16,7 @@ func (d *Direct) Dial(metadata *C.Metadata) (net.Conn, error) {
address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort) address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort)
} }
c, err := net.DialTimeout("tcp", address, tcpTimeout) c, err := dialTimeout("tcp", address, tcpTimeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -30,7 +30,7 @@ func (d *Direct) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error)
return nil, nil, err return nil, nil, err
} }
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) addr, err := resolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -36,7 +36,7 @@ type HttpOption struct {
} }
func (h *Http) Dial(metadata *C.Metadata) (net.Conn, error) { func (h *Http) Dial(metadata *C.Metadata) (net.Conn, error) {
c, err := net.DialTimeout("tcp", h.addr, tcpTimeout) c, err := dialTimeout("tcp", h.addr, tcpTimeout)
if err == nil && h.tls { if err == nil && h.tls {
cc := tls.Client(c, h.tlsConfig) cc := tls.Client(c, h.tlsConfig)
err = cc.Handshake() err = cc.Handshake()

View file

@ -58,7 +58,7 @@ type v2rayObfsOption struct {
} }
func (ss *ShadowSocks) Dial(metadata *C.Metadata) (net.Conn, error) { func (ss *ShadowSocks) Dial(metadata *C.Metadata) (net.Conn, error) {
c, err := net.DialTimeout("tcp", ss.server, tcpTimeout) c, err := dialTimeout("tcp", ss.server, tcpTimeout)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", ss.server, err.Error()) return nil, fmt.Errorf("%s connect error: %s", ss.server, err.Error())
} }
@ -87,7 +87,7 @@ func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr,
return nil, nil, err return nil, nil, err
} }
addr, err := net.ResolveUDPAddr("udp", ss.server) addr, err := resolveUDPAddr("udp", ss.server)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -32,7 +32,7 @@ type Socks5Option struct {
} }
func (ss *Socks5) Dial(metadata *C.Metadata) (net.Conn, error) { func (ss *Socks5) Dial(metadata *C.Metadata) (net.Conn, error) {
c, err := net.DialTimeout("tcp", ss.addr, tcpTimeout) c, err := dialTimeout("tcp", ss.addr, tcpTimeout)
if err == nil && ss.tls { if err == nil && ss.tls {
cc := tls.Client(c, ss.tlsConfig) cc := tls.Client(c, ss.tlsConfig)
@ -58,7 +58,7 @@ func (ss *Socks5) Dial(metadata *C.Metadata) (net.Conn, error) {
} }
func (ss *Socks5) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { func (ss *Socks5) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) {
c, err := net.DialTimeout("tcp", ss.addr, tcpTimeout) c, err := dialTimeout("tcp", ss.addr, tcpTimeout)
if err == nil && ss.tls { if err == nil && ss.tls {
cc := tls.Client(c, ss.tlsConfig) cc := tls.Client(c, ss.tlsConfig)

View file

@ -12,6 +12,7 @@ import (
"github.com/Dreamacro/clash/component/socks5" "github.com/Dreamacro/clash/component/socks5"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/dns"
) )
const ( const (
@ -96,3 +97,30 @@ func (fuc *fakeUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, err := fuc.Conn.Read(b) n, err := fuc.Conn.Read(b)
return n, fuc.RemoteAddr(), err return n, fuc.RemoteAddr(), err
} }
func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
ip, err := dns.ResolveIP(host)
if err != nil {
return nil, err
}
return net.DialTimeout(network, net.JoinHostPort(ip.String(), port), timeout)
}
func resolveUDPAddr(network, address string) (*net.UDPAddr, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
ip, err := dns.ResolveIP(host)
if err != nil {
return nil, err
}
return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port))
}

View file

@ -32,7 +32,7 @@ type VmessOption struct {
} }
func (v *Vmess) Dial(metadata *C.Metadata) (net.Conn, error) { func (v *Vmess) Dial(metadata *C.Metadata) (net.Conn, error) {
c, err := net.DialTimeout("tcp", v.server, tcpTimeout) c, err := dialTimeout("tcp", v.server, tcpTimeout)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error", v.server) return nil, fmt.Errorf("%s connect error", v.server)
} }
@ -42,7 +42,7 @@ func (v *Vmess) Dial(metadata *C.Metadata) (net.Conn, error) {
} }
func (v *Vmess) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { func (v *Vmess) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) {
c, err := net.DialTimeout("tcp", v.server, tcpTimeout) c, err := dialTimeout("tcp", v.server, tcpTimeout)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%s connect error", v.server) return nil, nil, fmt.Errorf("%s connect error", v.server)
} }

View file

@ -506,6 +506,7 @@ func parseDNS(cfg rawDNS) (*DNS, error) {
dnsCfg := &DNS{ dnsCfg := &DNS{
Enable: cfg.Enable, Enable: cfg.Enable,
Listen: cfg.Listen, Listen: cfg.Listen,
IPv6: cfg.IPv6,
EnhancedMode: cfg.EnhancedMode, EnhancedMode: cfg.EnhancedMode,
} }
var err error var err error

32
dns/iputil.go Normal file
View file

@ -0,0 +1,32 @@
package dns
import (
"errors"
"net"
)
var (
errIPNotFound = errors.New("ip not found")
)
// ResolveIP with a host, return ip
func ResolveIP(host string) (net.IP, error) {
if DefaultResolver != nil {
if DefaultResolver.ipv6 {
return DefaultResolver.ResolveIP(host)
}
return DefaultResolver.ResolveIPv4(host)
}
ip := net.ParseIP(host)
if ip != nil {
return ip, nil
}
ipAddr, err := net.ResolveIPAddr("ip", host)
if err != nil {
return nil, err
}
return ipAddr.IP, nil
}

View file

@ -18,6 +18,11 @@ import (
geoip2 "github.com/oschwald/geoip2-golang" geoip2 "github.com/oschwald/geoip2-golang"
) )
var (
// DefaultResolver aim to resolve ip with host
DefaultResolver *Resolver
)
var ( var (
globalSessionCache = tls.NewLRUClientSessionCache(64) globalSessionCache = tls.NewLRUClientSessionCache(64)
@ -47,11 +52,16 @@ type Resolver struct {
// ResolveIP request with TypeA and TypeAAAA, priority return TypeAAAA // ResolveIP request with TypeA and TypeAAAA, priority return TypeAAAA
func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) {
ip = net.ParseIP(host)
if ip != nil {
return ip, nil
}
ch := make(chan net.IP) ch := make(chan net.IP)
go func() { go func() {
defer close(ch)
ip, err := r.resolveIP(host, D.TypeA) ip, err := r.resolveIP(host, D.TypeA)
if err != nil { if err != nil {
close(ch)
return return
} }
ch <- ip ch <- ip
@ -65,8 +75,8 @@ func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) {
return return
} }
ip, closed := <-ch ip, open := <-ch
if closed { if !open {
return nil, errors.New("can't found ip") return nil, errors.New("can't found ip")
} }
@ -275,8 +285,8 @@ func New(config Config) *Resolver {
}) })
r := &Resolver{ r := &Resolver{
main: transform(config.Main),
ipv6: config.IPv6, ipv6: config.IPv6,
main: transform(config.Main),
cache: cache.New(time.Second * 60), cache: cache.New(time.Second * 60),
mapping: config.EnhancedMode == MAPPING, mapping: config.EnhancedMode == MAPPING,
fakeip: config.EnhancedMode == FAKEIP, fakeip: config.EnhancedMode == FAKEIP,

View file

@ -59,7 +59,7 @@ func updateExperimental(c *config.Experimental) {
func updateDNS(c *config.DNS) { func updateDNS(c *config.DNS) {
if c.Enable == false { if c.Enable == false {
T.Instance().SetResolver(nil) dns.DefaultResolver = nil
dns.ReCreateServer("", nil) dns.ReCreateServer("", nil)
return return
} }
@ -70,12 +70,15 @@ func updateDNS(c *config.DNS) {
EnhancedMode: c.EnhancedMode, EnhancedMode: c.EnhancedMode,
Pool: c.FakeIPRange, Pool: c.FakeIPRange,
}) })
T.Instance().SetResolver(r) dns.DefaultResolver = r
if err := dns.ReCreateServer(c.Listen, r); err != nil { if err := dns.ReCreateServer(c.Listen, r); err != nil {
log.Errorln("Start DNS server error: %s", err.Error()) log.Errorln("Start DNS server error: %s", err.Error())
return return
} }
log.Infoln("DNS server listening at: %s", c.Listen)
if c.Listen != "" {
log.Infoln("DNS server listening at: %s", c.Listen)
}
} }
func updateProxies(proxies map[string]C.Proxy) { func updateProxies(proxies map[string]C.Proxy) {

View file

@ -26,7 +26,6 @@ type Tunnel struct {
proxies map[string]C.Proxy proxies map[string]C.Proxy
configMux *sync.RWMutex configMux *sync.RWMutex
traffic *C.Traffic traffic *C.Traffic
resolver *dns.Resolver
// experimental features // experimental features
ignoreResolveFail bool ignoreResolveFail bool
@ -86,15 +85,6 @@ func (t *Tunnel) SetMode(mode Mode) {
t.mode = mode t.mode = mode
} }
// SetResolver change the resolver of tunnel
func (t *Tunnel) SetResolver(resolver *dns.Resolver) {
t.resolver = resolver
}
func (t *Tunnel) hasResolver() bool {
return t.resolver != nil
}
func (t *Tunnel) process() { func (t *Tunnel) process() {
queue := t.queue.Out() queue := t.queue.Out()
for { for {
@ -105,20 +95,11 @@ func (t *Tunnel) process() {
} }
func (t *Tunnel) resolveIP(host string) (net.IP, error) { func (t *Tunnel) resolveIP(host string) (net.IP, error) {
if t.resolver == nil { return dns.ResolveIP(host)
ipAddr, err := net.ResolveIPAddr("ip", host)
if err != nil {
return nil, err
}
return ipAddr.IP, nil
}
return t.resolver.ResolveIP(host)
} }
func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool { func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool {
return t.hasResolver() && (t.resolver.IsMapping() || t.resolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil
} }
func (t *Tunnel) handleConn(localConn C.ServerAdapter) { func (t *Tunnel) handleConn(localConn C.ServerAdapter) {
@ -132,11 +113,11 @@ func (t *Tunnel) handleConn(localConn C.ServerAdapter) {
// preprocess enhanced-mode metadata // preprocess enhanced-mode metadata
if t.needLookupIP(metadata) { if t.needLookupIP(metadata) {
host, exist := t.resolver.IPToHost(*metadata.DstIP) host, exist := dns.DefaultResolver.IPToHost(*metadata.DstIP)
if exist { if exist {
metadata.Host = host metadata.Host = host
metadata.AddrType = C.AtypDomainName metadata.AddrType = C.AtypDomainName
if t.resolver.IsFakeIP() { if dns.DefaultResolver.IsFakeIP() {
metadata.DstIP = nil metadata.DstIP = nil
} }
} }