From 36b5d1f18f28f2d01022ead8b90d04cc6b31cf79 Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Fri, 25 Jan 2019 15:38:14 +0800 Subject: [PATCH] Fix: DNS server returns the correct TTL --- common/cache/cache.go | 15 +++++++++++++++ dns/client.go | 11 +++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/common/cache/cache.go b/common/cache/cache.go index f33591ea..bb16adca 100644 --- a/common/cache/cache.go +++ b/common/cache/cache.go @@ -44,6 +44,21 @@ func (c *cache) Get(key interface{}) interface{} { return elm.Payload } +// GetWithExpire element in Cache with Expire Time +func (c *cache) GetWithExpire(key interface{}) (item interface{}, expired time.Time) { + item, exist := c.mapping.Load(key) + if !exist { + return + } + elm := item.(*element) + // expired + if time.Since(elm.Expired) > 0 { + c.mapping.Delete(key) + return + } + return elm.Payload, elm.Expired +} + func (c *cache) cleanup() { c.mapping.Range(func(k, v interface{}) bool { key := k.(string) diff --git a/dns/client.go b/dns/client.go index f8c711b1..a0d3a383 100644 --- a/dns/client.go +++ b/dns/client.go @@ -52,9 +52,16 @@ func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) { } q := m.Question[0] - cache := r.cache.Get(q.String()) + cache, expireTime := r.cache.GetWithExpire(q.String()) if cache != nil { - return cache.(*D.Msg).Copy(), nil + msg = cache.(*D.Msg).Copy() + if len(msg.Answer) > 0 { + ttl := uint32(expireTime.Sub(time.Now()).Seconds()) + for _, answer := range msg.Answer { + answer.Header().Ttl = ttl + } + } + return } defer func() { if msg != nil {