From 07d75d52e60a99e0d9cee35db406968e8218c3ac Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Sun, 11 Jun 2023 20:58:51 +0800 Subject: [PATCH] chore: Disable cache for RCode client --- dns/dhcp.go | 3 ++- dns/resolver.go | 15 ++++++++++----- dns/util.go | 11 +++++++---- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/dns/dhcp.go b/dns/dhcp.go index a6c1df76..7d420d89 100644 --- a/dns/dhcp.go +++ b/dns/dhcp.go @@ -59,7 +59,8 @@ func (d *dhcpClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, return nil, err } - return batchExchange(ctx, clients, m) + msg, _, err = batchExchange(ctx, clients, m) + return } func (d *dhcpClient) resolve(ctx context.Context) ([]dnsClient, error) { diff --git a/dns/resolver.go b/dns/resolver.go index 5ae7ba33..8f41a44e 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -182,6 +182,7 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M fn := func() (result any, err error) { ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) // reset timeout in singleflight defer cancel() + cache := false defer func() { if err != nil { @@ -192,7 +193,9 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M msg := result.(*D.Msg) - putMsgToCache(r.lruCache, q.String(), msg) + if cache { + putMsgToCache(r.lruCache, q.String(), msg) + } }() isIPReq := isIPRequest(q) @@ -201,9 +204,11 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M } if matched := r.matchPolicy(m); len(matched) != 0 { - return r.batchExchange(ctx, matched, m) + result, cache, err = r.batchExchange(ctx, matched, m) + return } - return r.batchExchange(ctx, r.main, m) + result, cache, err = r.batchExchange(ctx, r.main, m) + return } ch := r.group.DoChan(q.String(), fn) @@ -244,7 +249,7 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M return } -func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { +func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, cache bool, err error) { ctx, cancel := context.WithTimeout(ctx, resolver.DefaultDNSTimeout) defer cancel() @@ -371,7 +376,7 @@ func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) (i func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D.Msg) <-chan *result { ch := make(chan *result, 1) go func() { - res, err := r.batchExchange(ctx, client, msg) + res, _, err := r.batchExchange(ctx, client, msg) ch <- &result{Msg: res, Error: err} }() return ch diff --git a/dns/util.go b/dns/util.go index a409d679..2ba4d426 100644 --- a/dns/util.go +++ b/dns/util.go @@ -287,17 +287,20 @@ func listenPacket(ctx context.Context, proxyAdapter C.ProxyAdapter, proxyName st return proxyAdapter.ListenPacketContext(ctx, metadata, opts...) } -func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { +func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, cache bool, err error) { fast, ctx := picker.WithTimeout[*D.Msg](ctx, resolver.DefaultDNSTimeout) domain := msgToDomain(m) for _, client := range clients { - _, ignoreRCodeError := client.(rcodeClient) + _, cache = client.(rcodeClient) + cache = !cache fast.Go(func() (*D.Msg, error) { log.Debugln("[DNS] resolve %s from %s", domain, client.Address()) m, err := client.ExchangeContext(ctx, m) if err != nil { return nil, err - } else if !ignoreRCodeError && (m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused) { + } else if cache && (m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused) { + // currently, cache indicates whether this msg was from a RCode client, + // so we would ignore RCode errors from RCode clients. return nil, errors.New("server failure") } log.Debugln("[DNS] %s --> %s, from %s", domain, msgToIP(m), client.Address()) @@ -311,7 +314,7 @@ func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.M if fErr := fast.Error(); fErr != nil { err = fmt.Errorf("%w, first error: %s", err, fErr.Error()) } - return nil, err + return nil, true, err } msg = elm return