Feature: sync missing resolver logic from premium, but still net.IP on opensource

This commit is contained in:
Dreamacro 2022-08-13 13:07:35 +08:00
parent 5940f62794
commit 3946d771e5
5 changed files with 129 additions and 43 deletions

View file

@ -3,6 +3,7 @@ package resolver
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"math/rand" "math/rand"
"net" "net"
"strings" "strings"
@ -33,29 +34,32 @@ var (
) )
type Resolver interface { type Resolver interface {
LookupIP(ctx context.Context, host string) ([]net.IP, error)
LookupIPv4(ctx context.Context, host string) ([]net.IP, error)
LookupIPv6(ctx context.Context, host string) ([]net.IP, error)
ResolveIP(host string) (ip net.IP, err error) ResolveIP(host string) (ip net.IP, err error)
ResolveIPv4(host string) (ip net.IP, err error) ResolveIPv4(host string) (ip net.IP, err error)
ResolveIPv6(host string) (ip net.IP, err error) ResolveIPv6(host string) (ip net.IP, err error)
} }
// ResolveIPv4 with a host, return ipv4 // LookupIPv4 with a host, return ipv4 list
func ResolveIPv4(host string) (net.IP, error) { func LookupIPv4(ctx context.Context, host string) ([]net.IP, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data.(net.IP).To4(); ip != nil { if ip := node.Data.(net.IP).To4(); ip != nil {
return ip, nil return []net.IP{ip}, nil
} }
} }
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip != nil { if ip != nil {
if !strings.Contains(host, ":") { if !strings.Contains(host, ":") {
return ip, nil return []net.IP{ip}, nil
} }
return nil, ErrIPVersion return nil, ErrIPVersion
} }
if DefaultResolver != nil { if DefaultResolver != nil {
return DefaultResolver.ResolveIPv4(host) return DefaultResolver.LookupIPv4(ctx, host)
} }
ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout)
@ -67,31 +71,42 @@ func ResolveIPv4(host string) (net.IP, error) {
return nil, ErrIPNotFound return nil, ErrIPNotFound
} }
return ipAddrs[rand.Intn(len(ipAddrs))], nil return ipAddrs, nil
} }
// ResolveIPv6 with a host, return ipv6 // ResolveIPv4 with a host, return ipv4
func ResolveIPv6(host string) (net.IP, error) { func ResolveIPv4(host string) (net.IP, error) {
ips, err := LookupIPv4(context.Background(), host)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("%w: %s", ErrIPNotFound, host)
}
return ips[rand.Intn(len(ips))], nil
}
// LookupIPv6 with a host, return ipv6 list
func LookupIPv6(ctx context.Context, host string) ([]net.IP, error) {
if DisableIPv6 { if DisableIPv6 {
return nil, ErrIPv6Disabled return nil, ErrIPv6Disabled
} }
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data.(net.IP).To16(); ip != nil { if ip := node.Data.(net.IP).To16(); ip != nil {
return ip, nil return []net.IP{ip}, nil
} }
} }
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip != nil { if ip != nil {
if strings.Contains(host, ":") { if strings.Contains(host, ":") {
return ip, nil return []net.IP{ip}, nil
} }
return nil, ErrIPVersion return nil, ErrIPVersion
} }
if DefaultResolver != nil { if DefaultResolver != nil {
return DefaultResolver.ResolveIPv6(host) return DefaultResolver.LookupIPv6(ctx, host)
} }
ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout)
@ -103,38 +118,62 @@ func ResolveIPv6(host string) (net.IP, error) {
return nil, ErrIPNotFound return nil, ErrIPNotFound
} }
return ipAddrs[rand.Intn(len(ipAddrs))], nil return ipAddrs, nil
} }
// ResolveIPWithResolver same as ResolveIP, but with a resolver // ResolveIPv6 with a host, return ipv6
func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) { func ResolveIPv6(host string) (net.IP, error) {
ips, err := LookupIPv6(context.Background(), host)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("%w: %s", ErrIPNotFound, host)
}
return ips[rand.Intn(len(ips))], nil
}
// LookupIPWithResolver same as ResolveIP, but with a resolver
func LookupIPWithResolver(ctx context.Context, host string, r Resolver) ([]net.IP, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
return node.Data.(net.IP), nil return []net.IP{node.Data.(net.IP)}, nil
} }
if r != nil { if r != nil {
if DisableIPv6 { if DisableIPv6 {
return r.ResolveIPv4(host) return r.LookupIPv4(ctx, host)
} }
return r.ResolveIP(host) return r.LookupIP(ctx, host)
} else if DisableIPv6 { } else if DisableIPv6 {
return ResolveIPv4(host) return LookupIPv4(ctx, host)
} }
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip != nil { if ip != nil {
return ip, nil return []net.IP{ip}, nil
} }
ipAddr, err := net.ResolveIPAddr("ip", host) ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
if err != nil { if err != nil {
return nil, err return nil, err
} else if len(ips) == 0 {
return nil, ErrIPNotFound
} }
return ipAddr.IP, nil return ips, nil
}
// ResolveIP with a host, return ip
func LookupIP(ctx context.Context, host string) ([]net.IP, error) {
return LookupIPWithResolver(ctx, host, DefaultResolver)
} }
// ResolveIP with a host, return ip // ResolveIP with a host, return ip
func ResolveIP(host string) (net.IP, error) { func ResolveIP(host string) (net.IP, error) {
return ResolveIPWithResolver(host, DefaultResolver) ips, err := LookupIP(context.Background(), host)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("%w: %s", ErrIPNotFound, host)
}
return ips[rand.Intn(len(ips))], nil
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"math/rand"
"net" "net"
"strings" "strings"
@ -36,9 +37,13 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error)
return nil, fmt.Errorf("dns %s not a valid ip", c.host) return nil, fmt.Errorf("dns %s not a valid ip", c.host)
} }
} else { } else {
if ip, err = resolver.ResolveIPWithResolver(c.host, c.r); err != nil { ips, err := resolver.LookupIPWithResolver(ctx, c.host, c.r)
if err != nil {
return nil, fmt.Errorf("use default dns resolve failed: %w", err) return nil, fmt.Errorf("use default dns resolve failed: %w", err)
} else if len(ips) == 0 {
return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, c.host)
} }
ip = ips[rand.Intn(len(ips))]
} }
network := "udp" network := "udp"

View file

@ -4,7 +4,9 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"fmt"
"io" "io"
"math/rand"
"net" "net"
"net/http" "net/http"
@ -91,10 +93,13 @@ func newDoHClient(url, iface string, r *Resolver) *dohClient {
return nil, err return nil, err
} }
ip, err := resolver.ResolveIPWithResolver(host, r) ips, err := resolver.LookupIPWithResolver(ctx, host, r)
if err != nil { if err != nil {
return nil, err return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
} }
ip := ips[rand.Intn(len(ips))]
options := []dialer.Option{} options := []dialer.Option{}
if iface != "" { if iface != "" {

View file

@ -42,19 +42,23 @@ type Resolver struct {
policy *trie.DomainTrie policy *trie.DomainTrie
} }
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA // LookupIP request with TypeA and TypeAAAA, priority return TypeA
func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { func (r *Resolver) LookupIP(ctx context.Context, host string) (ip []net.IP, err error) {
ch := make(chan net.IP, 1) ctx, cancel := context.WithCancel(ctx)
defer cancel()
ch := make(chan []net.IP, 1)
go func() { go func() {
defer close(ch) defer close(ch)
ip, err := r.resolveIP(host, D.TypeAAAA) ip, err := r.lookupIP(ctx, host, D.TypeAAAA)
if err != nil { if err != nil {
return return
} }
ch <- ip ch <- ip
}() }()
ip, err = r.resolveIP(host, D.TypeA) ip, err = r.lookupIP(ctx, host, D.TypeA)
if err == nil { if err == nil {
return return
} }
@ -67,14 +71,47 @@ func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) {
return ip, nil return ip, nil
} }
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA
func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) {
ips, err := r.LookupIP(context.Background(), host)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
}
return ips[rand.Intn(len(ips))], nil
}
// LookupIPv4 request with TypeA
func (r *Resolver) LookupIPv4(ctx context.Context, host string) ([]net.IP, error) {
return r.lookupIP(ctx, host, D.TypeA)
}
// ResolveIPv4 request with TypeA // ResolveIPv4 request with TypeA
func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) { func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) {
return r.resolveIP(host, D.TypeA) ips, err := r.lookupIP(context.Background(), host, D.TypeA)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
}
return ips[rand.Intn(len(ips))], nil
}
// LookupIPv6 request with TypeAAAA
func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]net.IP, error) {
return r.lookupIP(ctx, host, D.TypeAAAA)
} }
// ResolveIPv6 request with TypeAAAA // ResolveIPv6 request with TypeAAAA
func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) { func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) {
return r.resolveIP(host, D.TypeAAAA) ips, err := r.lookupIP(context.Background(), host, D.TypeAAAA)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
}
return ips[rand.Intn(len(ips))], nil
} }
func (r *Resolver) shouldIPFallback(ip net.IP) bool { func (r *Resolver) shouldIPFallback(ip net.IP) bool {
@ -253,14 +290,15 @@ func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err er
return return
} }
func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) { func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) ([]net.IP, error) {
ip = net.ParseIP(host) ip := net.ParseIP(host)
if ip != nil { if ip != nil {
isIPv4 := ip.To4() != nil ip4 := ip.To4()
isIPv4 := ip4 != nil
if dnsType == D.TypeAAAA && !isIPv4 { if dnsType == D.TypeAAAA && !isIPv4 {
return ip, nil return []net.IP{ip}, nil
} else if dnsType == D.TypeA && isIPv4 { } else if dnsType == D.TypeA && isIPv4 {
return ip, nil return []net.IP{ip4}, nil
} else { } else {
return nil, resolver.ErrIPVersion return nil, resolver.ErrIPVersion
} }
@ -275,13 +313,10 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
} }
ips := msgToIP(msg) ips := msgToIP(msg)
ipLength := len(ips) if len(ips) == 0 {
if ipLength == 0 {
return nil, resolver.ErrIPNotFound return nil, resolver.ErrIPNotFound
} }
return ips, nil
ip = ips[rand.Intn(ipLength)]
return
} }
func (r *Resolver) msgToDomain(msg *D.Msg) string { func (r *Resolver) msgToDomain(msg *D.Msg) string {

View file

@ -180,11 +180,13 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
// local resolve UDP dns // local resolve UDP dns
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host) ips, err := resolver.LookupIP(context.Background(), metadata.Host)
if err != nil { if err != nil {
return return
} else if len(ips) == 0 {
return
} }
metadata.DstIP = ip metadata.DstIP = ips[0]
} }
key := packet.LocalAddr().String() key := packet.LocalAddr().String()