fix: DNS cache

This commit is contained in:
H1JK 2023-07-14 09:55:43 +08:00
parent 0b1aff5759
commit 492a731ec1
2 changed files with 22 additions and 8 deletions

View file

@ -47,6 +47,7 @@ func (p *Picker[T]) Wait() T {
p.wg.Wait() p.wg.Wait()
if p.cancel != nil { if p.cancel != nil {
p.cancel() p.cancel()
p.cancel = nil
} }
return p.result return p.result
} }
@ -69,6 +70,7 @@ func (p *Picker[T]) Go(f func() (T, error)) {
p.result = ret p.result = ret
if p.cancel != nil { if p.cancel != nil {
p.cancel() p.cancel()
p.cancel = nil
} }
}) })
} else { } else {
@ -78,3 +80,13 @@ func (p *Picker[T]) Go(f func() (T, error)) {
} }
}() }()
} }
// Close cancels the picker context and releases resources associated with it.
// If Wait has been called, then there is no need to call Close.
func (p *Picker[T]) Close() error {
if p.cancel != nil {
p.cancel()
p.cancel = nil
}
return nil
}

View file

@ -288,12 +288,16 @@ func listenPacket(ctx context.Context, proxyAdapter C.ProxyAdapter, proxyName st
} }
func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, cache bool, err error) { func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, cache bool, err error) {
cache = true
fast, ctx := picker.WithTimeout[*D.Msg](ctx, resolver.DefaultDNSTimeout) fast, ctx := picker.WithTimeout[*D.Msg](ctx, resolver.DefaultDNSTimeout)
defer fast.Close()
domain := msgToDomain(m) domain := msgToDomain(m)
for _, client := range clients { for _, client := range clients {
if _, isRCodeClient := client.(rcodeClient); isRCodeClient {
msg, err = client.Exchange(m)
return msg, false, err
}
client := client // shadow define client to ensure the value captured by the closure will not be changed in the next loop client := client // shadow define client to ensure the value captured by the closure will not be changed in the next loop
_, cache = client.(rcodeClient)
cache = !cache
fast.Go(func() (*D.Msg, error) { fast.Go(func() (*D.Msg, error) {
log.Debugln("[DNS] resolve %s from %s", domain, client.Address()) log.Debugln("[DNS] resolve %s from %s", domain, client.Address())
m, err := client.ExchangeContext(ctx, m) m, err := client.ExchangeContext(ctx, m)
@ -302,21 +306,19 @@ func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.M
} else if cache && (m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused) { } else if cache && (m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused) {
// currently, cache indicates whether this msg was from a RCode client, // currently, cache indicates whether this msg was from a RCode client,
// so we would ignore RCode errors from RCode clients. // so we would ignore RCode errors from RCode clients.
return nil, errors.New("server failure") return nil, errors.New("server failure: " + D.RcodeToString[m.Rcode])
} }
log.Debugln("[DNS] %s --> %s, from %s", domain, msgToIP(m), client.Address()) log.Debugln("[DNS] %s --> %s, from %s", domain, msgToIP(m), client.Address())
return m, nil return m, nil
}) })
} }
elm := fast.Wait() msg = fast.Wait()
if elm == nil { if msg == nil {
err := errors.New("all DNS requests failed") err = errors.New("all DNS requests failed")
if fErr := fast.Error(); fErr != nil { if fErr := fast.Error(); fErr != nil {
err = fmt.Errorf("%w, first error: %w", err, fErr) err = fmt.Errorf("%w, first error: %w", err, fErr)
} }
return nil, true, err
} }
msg = elm
return return
} }