diff --git a/common/cache/cache.go b/common/cache/cache.go index f33591ea..bb16adca 100644 --- a/common/cache/cache.go +++ b/common/cache/cache.go @@ -44,6 +44,21 @@ func (c *cache) Get(key interface{}) interface{} { return elm.Payload } +// GetWithExpire element in Cache with Expire Time +func (c *cache) GetWithExpire(key interface{}) (item interface{}, expired time.Time) { + item, exist := c.mapping.Load(key) + if !exist { + return + } + elm := item.(*element) + // expired + if time.Since(elm.Expired) > 0 { + c.mapping.Delete(key) + return + } + return elm.Payload, elm.Expired +} + func (c *cache) cleanup() { c.mapping.Range(func(k, v interface{}) bool { key := k.(string) diff --git a/dns/client.go b/dns/client.go index f8c711b1..a0d3a383 100644 --- a/dns/client.go +++ b/dns/client.go @@ -52,9 +52,16 @@ func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) { } q := m.Question[0] - cache := r.cache.Get(q.String()) + cache, expireTime := r.cache.GetWithExpire(q.String()) if cache != nil { - return cache.(*D.Msg).Copy(), nil + msg = cache.(*D.Msg).Copy() + if len(msg.Answer) > 0 { + ttl := uint32(expireTime.Sub(time.Now()).Seconds()) + for _, answer := range msg.Answer { + answer.Header().Ttl = ttl + } + } + return } defer func() { if msg != nil {