diff --git a/common/picker/picker.go b/common/picker/picker.go index 97004460..3a7688ca 100644 --- a/common/picker/picker.go +++ b/common/picker/picker.go @@ -47,6 +47,7 @@ func (p *Picker[T]) Wait() T { p.wg.Wait() if p.cancel != nil { p.cancel() + p.cancel = nil } return p.result } @@ -69,6 +70,7 @@ func (p *Picker[T]) Go(f func() (T, error)) { p.result = ret if p.cancel != nil { p.cancel() + p.cancel = nil } }) } else { @@ -78,3 +80,13 @@ func (p *Picker[T]) Go(f func() (T, error)) { } }() } + +// Close cancels the picker context and releases resources associated with it. +// If Wait has been called, then there is no need to call Close. +func (p *Picker[T]) Close() error { + if p.cancel != nil { + p.cancel() + p.cancel = nil + } + return nil +} diff --git a/dns/util.go b/dns/util.go index 73d117b8..739fd16b 100644 --- a/dns/util.go +++ b/dns/util.go @@ -288,12 +288,16 @@ func listenPacket(ctx context.Context, proxyAdapter C.ProxyAdapter, proxyName st } func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, cache bool, err error) { + cache = true fast, ctx := picker.WithTimeout[*D.Msg](ctx, resolver.DefaultDNSTimeout) + defer fast.Close() domain := msgToDomain(m) for _, client := range clients { + if _, isRCodeClient := client.(rcodeClient); isRCodeClient { + msg, err = client.Exchange(m) + return msg, false, err + } client := client // shadow define client to ensure the value captured by the closure will not be changed in the next loop - _, cache = client.(rcodeClient) - cache = !cache fast.Go(func() (*D.Msg, error) { log.Debugln("[DNS] resolve %s from %s", domain, client.Address()) m, err := client.ExchangeContext(ctx, m) @@ -302,21 +306,19 @@ func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.M } else if cache && (m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused) { // currently, cache indicates whether this msg was from a RCode client, // so we would ignore RCode errors from RCode clients. - return nil, errors.New("server failure") + return nil, errors.New("server failure: " + D.RcodeToString[m.Rcode]) } log.Debugln("[DNS] %s --> %s, from %s", domain, msgToIP(m), client.Address()) return m, nil }) } - elm := fast.Wait() - if elm == nil { - err := errors.New("all DNS requests failed") + msg = fast.Wait() + if msg == nil { + err = errors.New("all DNS requests failed") if fErr := fast.Error(); fErr != nil { err = fmt.Errorf("%w, first error: %w", err, fErr) } - return nil, true, err } - msg = elm return }