diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index 61258898..356bc660 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -1,9 +1,11 @@ package resolver import ( + "context" "errors" "net" "strings" + "time" "github.com/Dreamacro/clash/component/trie" ) @@ -18,6 +20,9 @@ var ( // DefaultHosts aim to resolve hosts DefaultHosts = trie.New() + + // DefaultDNSTimeout defined the default dns request timeout + DefaultDNSTimeout = time.Second * 5 ) var ( @@ -52,18 +57,16 @@ func ResolveIPv4(host string) (net.IP, error) { return DefaultResolver.ResolveIPv4(host) } - ipAddrs, err := net.LookupIP(host) + ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) + defer cancel() + ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host) if err != nil { return nil, err + } else if len(ipAddrs) == 0 { + return nil, ErrIPNotFound } - for _, ip := range ipAddrs { - if ip4 := ip.To4(); ip4 != nil { - return ip4, nil - } - } - - return nil, ErrIPNotFound + return ipAddrs[0], nil } // ResolveIPv6 with a host, return ipv6 @@ -90,18 +93,16 @@ func ResolveIPv6(host string) (net.IP, error) { return DefaultResolver.ResolveIPv6(host) } - ipAddrs, err := net.LookupIP(host) + ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) + defer cancel() + ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host) if err != nil { return nil, err + } else if len(ipAddrs) == 0 { + return nil, ErrIPNotFound } - for _, ip := range ipAddrs { - if ip.To4() == nil { - return ip, nil - } - } - - return nil, ErrIPNotFound + return ipAddrs[0], nil } // ResolveIP with a host, return ip diff --git a/dns/resolver.go b/dns/resolver.go index d110aa34..0db6855d 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -145,7 +145,7 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) { } func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { - fast, ctx := picker.WithTimeout(context.Background(), time.Second*5) + fast, ctx := picker.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) for _, client := range clients { r := client fast.Go(func() (interface{}, error) {