diff --git a/dns/dhcp.go b/dns/dhcp.go index f964cec8..8e0d5d4c 100644 --- a/dns/dhcp.go +++ b/dns/dhcp.go @@ -29,7 +29,7 @@ type dhcpClient struct { ifaceAddr *net.IPNet done chan struct{} - resolver *Resolver + clients []dnsClient err error } @@ -41,15 +41,15 @@ func (d *dhcpClient) Exchange(m *D.Msg) (msg *D.Msg, err error) { } func (d *dhcpClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { - res, err := d.resolve(ctx) + clients, err := d.resolve(ctx) if err != nil { return nil, err } - return res.ExchangeContext(ctx, m) + return batchExchange(ctx, clients, m) } -func (d *dhcpClient) resolve(ctx context.Context) (*Resolver, error) { +func (d *dhcpClient) resolve(ctx context.Context) ([]dnsClient, error) { d.lock.Lock() invalidated, err := d.invalidate() @@ -64,8 +64,9 @@ func (d *dhcpClient) resolve(ctx context.Context) (*Resolver, error) { ctx, cancel := context.WithTimeout(context.Background(), DHCPTimeout) defer cancel() - var res *Resolver + var res []dnsClient dns, err := dhcp.ResolveDNSFromDHCP(ctx, d.ifaceName) + // dns never empty if err is nil if err == nil { nameserver := make([]NameServer, 0, len(dns)) for _, item := range dns { @@ -75,9 +76,7 @@ func (d *dhcpClient) resolve(ctx context.Context) (*Resolver, error) { }) } - res = NewResolver(Config{ - Main: nameserver, - }) + res = transform(nameserver, nil) } d.lock.Lock() @@ -86,7 +85,7 @@ func (d *dhcpClient) resolve(ctx context.Context) (*Resolver, error) { close(done) d.done = nil - d.resolver = res + d.clients = res d.err = err }() } @@ -96,7 +95,7 @@ func (d *dhcpClient) resolve(ctx context.Context) (*Resolver, error) { for { d.lock.Lock() - res, err, done := d.resolver, d.err, d.done + res, err, done := d.clients, d.err, d.done d.lock.Unlock() diff --git a/dns/resolver.go b/dns/resolver.go index 2f993600..4d3634a4 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -10,7 +10,6 @@ import ( "time" "github.com/Dreamacro/clash/common/cache" - "github.com/Dreamacro/clash/common/picker" "github.com/Dreamacro/clash/component/fakeip" "github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/trie" @@ -187,31 +186,10 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M } func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { - fast, ctx := picker.WithTimeout(ctx, resolver.DefaultDNSTimeout) - for _, client := range clients { - r := client - fast.Go(func() (any, error) { - m, err := r.ExchangeContext(ctx, m) - if err != nil { - return nil, err - } else if m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused { - return nil, errors.New("server failure") - } - return m, nil - }) - } + ctx, cancel := context.WithTimeout(ctx, resolver.DefaultDNSTimeout) + defer cancel() - elm := fast.Wait() - if elm == nil { - err := errors.New("all DNS requests failed") - if fErr := fast.Error(); fErr != nil { - err = fmt.Errorf("%w, first error: %s", err, fErr.Error()) - } - return nil, err - } - - msg = elm.(*D.Msg) - return + return batchExchange(ctx, clients, m) } func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { diff --git a/dns/util.go b/dns/util.go index 99590d63..df3a3331 100644 --- a/dns/util.go +++ b/dns/util.go @@ -1,11 +1,15 @@ package dns import ( + "context" "crypto/tls" + "errors" + "fmt" "net" "time" "github.com/Dreamacro/clash/common/cache" + "github.com/Dreamacro/clash/common/picker" "github.com/Dreamacro/clash/log" D "github.com/miekg/dns" @@ -102,3 +106,31 @@ func msgToIP(msg *D.Msg) []net.IP { return ips } + +func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { + fast, ctx := picker.WithContext(ctx) + for _, client := range clients { + r := client + fast.Go(func() (any, error) { + m, err := r.ExchangeContext(ctx, m) + if err != nil { + return nil, err + } else if m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused { + return nil, errors.New("server failure") + } + return m, nil + }) + } + + elm := fast.Wait() + if elm == nil { + err := errors.New("all DNS requests failed") + if fErr := fast.Error(); fErr != nil { + err = fmt.Errorf("%w, first error: %s", err, fErr.Error()) + } + return nil, err + } + + msg = elm.(*D.Msg) + return +}