diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index d10e39cb..daf98a09 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -3,6 +3,7 @@ package resolver import ( "context" "errors" + "fmt" "math/rand" "net" "strings" @@ -33,29 +34,32 @@ var ( ) type Resolver interface { + LookupIP(ctx context.Context, host string) ([]net.IP, error) + LookupIPv4(ctx context.Context, host string) ([]net.IP, error) + LookupIPv6(ctx context.Context, host string) ([]net.IP, error) ResolveIP(host string) (ip net.IP, err error) ResolveIPv4(host string) (ip net.IP, err error) ResolveIPv6(host string) (ip net.IP, err error) } -// ResolveIPv4 with a host, return ipv4 -func ResolveIPv4(host string) (net.IP, error) { +// LookupIPv4 with a host, return ipv4 list +func LookupIPv4(ctx context.Context, host string) ([]net.IP, error) { if node := DefaultHosts.Search(host); node != nil { if ip := node.Data.(net.IP).To4(); ip != nil { - return ip, nil + return []net.IP{ip}, nil } } ip := net.ParseIP(host) if ip != nil { if !strings.Contains(host, ":") { - return ip, nil + return []net.IP{ip}, nil } return nil, ErrIPVersion } if DefaultResolver != nil { - return DefaultResolver.ResolveIPv4(host) + return DefaultResolver.LookupIPv4(ctx, host) } ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) @@ -67,31 +71,42 @@ func ResolveIPv4(host string) (net.IP, error) { return nil, ErrIPNotFound } - return ipAddrs[rand.Intn(len(ipAddrs))], nil + return ipAddrs, nil } -// ResolveIPv6 with a host, return ipv6 -func ResolveIPv6(host string) (net.IP, error) { +// ResolveIPv4 with a host, return ipv4 +func ResolveIPv4(host string) (net.IP, error) { + ips, err := LookupIPv4(context.Background(), host) + if err != nil { + return nil, err + } else if len(ips) == 0 { + return nil, fmt.Errorf("%w: %s", ErrIPNotFound, host) + } + return ips[rand.Intn(len(ips))], nil +} + +// LookupIPv6 with a host, return ipv6 list +func LookupIPv6(ctx context.Context, host string) ([]net.IP, error) { if DisableIPv6 { return nil, ErrIPv6Disabled } if node := DefaultHosts.Search(host); node != nil { if ip := node.Data.(net.IP).To16(); ip != nil { - return ip, nil + return []net.IP{ip}, nil } } ip := net.ParseIP(host) if ip != nil { if strings.Contains(host, ":") { - return ip, nil + return []net.IP{ip}, nil } return nil, ErrIPVersion } if DefaultResolver != nil { - return DefaultResolver.ResolveIPv6(host) + return DefaultResolver.LookupIPv6(ctx, host) } ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) @@ -103,38 +118,62 @@ func ResolveIPv6(host string) (net.IP, error) { return nil, ErrIPNotFound } - return ipAddrs[rand.Intn(len(ipAddrs))], nil + return ipAddrs, nil } -// ResolveIPWithResolver same as ResolveIP, but with a resolver -func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) { +// ResolveIPv6 with a host, return ipv6 +func ResolveIPv6(host string) (net.IP, error) { + ips, err := LookupIPv6(context.Background(), host) + if err != nil { + return nil, err + } else if len(ips) == 0 { + return nil, fmt.Errorf("%w: %s", ErrIPNotFound, host) + } + return ips[rand.Intn(len(ips))], nil +} + +// LookupIPWithResolver same as ResolveIP, but with a resolver +func LookupIPWithResolver(ctx context.Context, host string, r Resolver) ([]net.IP, error) { if node := DefaultHosts.Search(host); node != nil { - return node.Data.(net.IP), nil + return []net.IP{node.Data.(net.IP)}, nil } if r != nil { if DisableIPv6 { - return r.ResolveIPv4(host) + return r.LookupIPv4(ctx, host) } - return r.ResolveIP(host) + return r.LookupIP(ctx, host) } else if DisableIPv6 { - return ResolveIPv4(host) + return LookupIPv4(ctx, host) } ip := net.ParseIP(host) if ip != nil { - return ip, nil + return []net.IP{ip}, nil } - ipAddr, err := net.ResolveIPAddr("ip", host) + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) if err != nil { return nil, err + } else if len(ips) == 0 { + return nil, ErrIPNotFound } - return ipAddr.IP, nil + return ips, nil +} + +// ResolveIP with a host, return ip +func LookupIP(ctx context.Context, host string) ([]net.IP, error) { + return LookupIPWithResolver(ctx, host, DefaultResolver) } // ResolveIP with a host, return ip func ResolveIP(host string) (net.IP, error) { - return ResolveIPWithResolver(host, DefaultResolver) + ips, err := LookupIP(context.Background(), host) + if err != nil { + return nil, err + } else if len(ips) == 0 { + return nil, fmt.Errorf("%w: %s", ErrIPNotFound, host) + } + return ips[rand.Intn(len(ips))], nil } diff --git a/dns/client.go b/dns/client.go index 5cb1fe02..366a179d 100644 --- a/dns/client.go +++ b/dns/client.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "math/rand" "net" "strings" @@ -36,9 +37,13 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) return nil, fmt.Errorf("dns %s not a valid ip", c.host) } } else { - if ip, err = resolver.ResolveIPWithResolver(c.host, c.r); err != nil { + ips, err := resolver.LookupIPWithResolver(ctx, c.host, c.r) + if err != nil { return nil, fmt.Errorf("use default dns resolve failed: %w", err) + } else if len(ips) == 0 { + return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, c.host) } + ip = ips[rand.Intn(len(ips))] } network := "udp" diff --git a/dns/doh.go b/dns/doh.go index 7e4ed469..79820f9c 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "crypto/tls" + "fmt" "io" + "math/rand" "net" "net/http" @@ -91,10 +93,13 @@ func newDoHClient(url, iface string, r *Resolver) *dohClient { return nil, err } - ip, err := resolver.ResolveIPWithResolver(host, r) + ips, err := resolver.LookupIPWithResolver(ctx, host, r) if err != nil { return nil, err + } else if len(ips) == 0 { + return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) } + ip := ips[rand.Intn(len(ips))] options := []dialer.Option{} if iface != "" { diff --git a/dns/resolver.go b/dns/resolver.go index cec415ac..5eeda77f 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -42,19 +42,23 @@ type Resolver struct { policy *trie.DomainTrie } -// ResolveIP request with TypeA and TypeAAAA, priority return TypeA -func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { - ch := make(chan net.IP, 1) +// LookupIP request with TypeA and TypeAAAA, priority return TypeA +func (r *Resolver) LookupIP(ctx context.Context, host string) (ip []net.IP, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + ch := make(chan []net.IP, 1) + go func() { defer close(ch) - ip, err := r.resolveIP(host, D.TypeAAAA) + ip, err := r.lookupIP(ctx, host, D.TypeAAAA) if err != nil { return } ch <- ip }() - ip, err = r.resolveIP(host, D.TypeA) + ip, err = r.lookupIP(ctx, host, D.TypeA) if err == nil { return } @@ -67,14 +71,47 @@ func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { return ip, nil } +// ResolveIP request with TypeA and TypeAAAA, priority return TypeA +func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { + ips, err := r.LookupIP(context.Background(), host) + if err != nil { + return nil, err + } else if len(ips) == 0 { + return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) + } + return ips[rand.Intn(len(ips))], nil +} + +// LookupIPv4 request with TypeA +func (r *Resolver) LookupIPv4(ctx context.Context, host string) ([]net.IP, error) { + return r.lookupIP(ctx, host, D.TypeA) +} + // ResolveIPv4 request with TypeA func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) { - return r.resolveIP(host, D.TypeA) + ips, err := r.lookupIP(context.Background(), host, D.TypeA) + if err != nil { + return nil, err + } else if len(ips) == 0 { + return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) + } + return ips[rand.Intn(len(ips))], nil +} + +// LookupIPv6 request with TypeAAAA +func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]net.IP, error) { + return r.lookupIP(ctx, host, D.TypeAAAA) } // ResolveIPv6 request with TypeAAAA func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) { - return r.resolveIP(host, D.TypeAAAA) + ips, err := r.lookupIP(context.Background(), host, D.TypeAAAA) + if err != nil { + return nil, err + } else if len(ips) == 0 { + return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) + } + return ips[rand.Intn(len(ips))], nil } func (r *Resolver) shouldIPFallback(ip net.IP) bool { @@ -253,14 +290,15 @@ func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err er return } -func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) { - ip = net.ParseIP(host) +func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) ([]net.IP, error) { + ip := net.ParseIP(host) if ip != nil { - isIPv4 := ip.To4() != nil + ip4 := ip.To4() + isIPv4 := ip4 != nil if dnsType == D.TypeAAAA && !isIPv4 { - return ip, nil + return []net.IP{ip}, nil } else if dnsType == D.TypeA && isIPv4 { - return ip, nil + return []net.IP{ip4}, nil } else { return nil, resolver.ErrIPVersion } @@ -275,13 +313,10 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) } ips := msgToIP(msg) - ipLength := len(ips) - if ipLength == 0 { + if len(ips) == 0 { return nil, resolver.ErrIPNotFound } - - ip = ips[rand.Intn(ipLength)] - return + return ips, nil } func (r *Resolver) msgToDomain(msg *D.Msg) string { diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 4a4685b2..f5d9ed24 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -180,11 +180,13 @@ func handleUDPConn(packet *inbound.PacketAdapter) { // local resolve UDP dns if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ips, err := resolver.LookupIP(context.Background(), metadata.Host) if err != nil { return + } else if len(ips) == 0 { + return } - metadata.DstIP = ip + metadata.DstIP = ips[0] } key := packet.LocalAddr().String()