diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go index 6fe1d74c..1b1e492a 100644 --- a/common/cache/lrucache.go +++ b/common/cache/lrucache.go @@ -42,6 +42,14 @@ func WithSize(maxSize int) Option { } } +// WithStale decide whether Stale return is enabled. +// If this feature is enabled, element will not get Evicted according to `WithAge`. +func WithStale(stale bool) Option { + return func(l *LruCache) { + l.staleReturn = stale + } +} + // LruCache is a thread-safe, in-memory lru-cache that evicts the // least recently used entries from memory when (if set) the entries are // older than maxAge (in seconds). Use the New constructor to create one. @@ -52,6 +60,7 @@ type LruCache struct { cache map[interface{}]*list.Element lru *list.List // Front is least-recent updateAgeOnGet bool + staleReturn bool onEvict EvictCallback } @@ -72,31 +81,28 @@ func NewLRUCache(options ...Option) *LruCache { // Get returns the interface{} representation of a cached response and a bool // set to true if the key was found. func (c *LruCache) Get(key interface{}) (interface{}, bool) { - c.mu.Lock() - defer c.mu.Unlock() - - le, ok := c.cache[key] - if !ok { + entry := c.get(key) + if entry == nil { return nil, false } - - if c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { - c.deleteElement(le) - c.maybeDeleteOldest() - - return nil, false - } - - c.lru.MoveToBack(le) - entry := le.Value.(*entry) - if c.maxAge > 0 && c.updateAgeOnGet { - entry.expires = time.Now().Unix() + c.maxAge - } value := entry.value return value, true } +// GetWithExpire returns the interface{} representation of a cached response, +// a time.Time Give expected expires, +// and a bool set to true if the key was found. +// This method will NOT check the maxAge of element and will NOT update the expires. +func (c *LruCache) GetWithExpire(key interface{}) (interface{}, time.Time, bool) { + entry := c.get(key) + if entry == nil { + return nil, time.Time{}, false + } + + return entry.value, time.Unix(entry.expires, 0), true +} + // Exist returns if key exist in cache but not put item to the head of linked list func (c *LruCache) Exist(key interface{}) bool { c.mu.Lock() @@ -108,21 +114,26 @@ func (c *LruCache) Exist(key interface{}) bool { // Set stores the interface{} representation of a response for a given key. func (c *LruCache) Set(key interface{}, value interface{}) { - c.mu.Lock() - defer c.mu.Unlock() - expires := int64(0) if c.maxAge > 0 { expires = time.Now().Unix() + c.maxAge } + c.SetWithExpire(key, value, time.Unix(expires, 0)) +} + +// SetWithExpire stores the interface{} representation of a response for a given key and given exires. +// The expires time will round to second. +func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires time.Time) { + c.mu.Lock() + defer c.mu.Unlock() if le, ok := c.cache[key]; ok { c.lru.MoveToBack(le) e := le.Value.(*entry) e.value = value - e.expires = expires + e.expires = expires.Unix() } else { - e := &entry{key: key, value: value, expires: expires} + e := &entry{key: key, value: value, expires: expires.Unix()} c.cache[key] = c.lru.PushBack(e) if c.maxSize > 0 { @@ -135,6 +146,30 @@ func (c *LruCache) Set(key interface{}, value interface{}) { c.maybeDeleteOldest() } +func (c *LruCache) get(key interface{}) *entry { + c.mu.Lock() + defer c.mu.Unlock() + + le, ok := c.cache[key] + if !ok { + return nil + } + + if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { + c.deleteElement(le) + c.maybeDeleteOldest() + + return nil + } + + c.lru.MoveToBack(le) + entry := le.Value.(*entry) + if c.maxAge > 0 && c.updateAgeOnGet { + entry.expires = time.Now().Unix() + c.maxAge + } + return entry +} + // Delete removes the value associated with a key. func (c *LruCache) Delete(key string) { c.mu.Lock() @@ -147,7 +182,7 @@ func (c *LruCache) Delete(key string) { } func (c *LruCache) maybeDeleteOldest() { - if c.maxAge > 0 { + if !c.staleReturn && c.maxAge > 0 { now := time.Now().Unix() for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() { c.deleteElement(le) diff --git a/common/cache/lrucache_test.go b/common/cache/lrucache_test.go index b296d6b9..c3c629d9 100644 --- a/common/cache/lrucache_test.go +++ b/common/cache/lrucache_test.go @@ -136,3 +136,31 @@ func TestEvict(t *testing.T) { assert.Equal(t, temp, 3) } + +func TestSetWithExpire(t *testing.T) { + c := NewLRUCache(WithAge(1)) + now := time.Now().Unix() + + tenSecBefore := time.Unix(now-10, 0) + c.SetWithExpire(1, 2, tenSecBefore) + + // res is expected not to exist, and expires should be empty time.Time + res, expires, exist := c.GetWithExpire(1) + assert.Equal(t, nil, res) + assert.Equal(t, time.Time{}, expires) + assert.Equal(t, false, exist) + +} + +func TestStale(t *testing.T) { + c := NewLRUCache(WithAge(1), WithStale(true)) + now := time.Now().Unix() + + tenSecBefore := time.Unix(now-10, 0) + c.SetWithExpire(1, 2, tenSecBefore) + + res, expires, exist := c.GetWithExpire(1) + assert.Equal(t, 2, res) + assert.Equal(t, tenSecBefore, expires) + assert.Equal(t, true, exist) +} diff --git a/dns/resolver.go b/dns/resolver.go index eec7d412..7f50a1a0 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -42,7 +42,7 @@ type Resolver struct { fallback []dnsClient fallbackFilters []fallbackFilter group singleflight.Group - cache *cache.Cache + lruCache *cache.LruCache } // ResolveIP request with TypeA and TypeAAAA, priority return TypeA @@ -96,22 +96,35 @@ func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) { } q := m.Question[0] - cache, expireTime := r.cache.GetWithExpire(q.String()) - if cache != nil { + cache, expireTime, hit := r.lruCache.GetWithExpire(q.String()) + if hit { + now := time.Now() msg = cache.(*D.Msg).Copy() - setMsgTTL(msg, uint32(expireTime.Sub(time.Now()).Seconds())) + if expireTime.Before(now) { + setMsgTTL(msg, uint32(1)) // Continue fetch + go r.exchangeWithoutCache(m) + } else { + setMsgTTL(msg, uint32(expireTime.Sub(time.Now()).Seconds())) + } return } + return r.exchangeWithoutCache(m) +} + +// ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache +func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) { + q := m.Question[0] + defer func() { if msg == nil { return } - putMsgToCache(r.cache, q.String(), msg) + putMsgToCache(r.lruCache, q.String(), msg) if r.mapping { ips := r.msgToIP(msg) for _, ip := range ips { - putMsgToCache(r.cache, ip.String(), msg) + putMsgToCache(r.lruCache, ip.String(), msg) } } }() @@ -141,7 +154,7 @@ func (r *Resolver) IPToHost(ip net.IP) (string, bool) { return r.pool.LookBack(ip) } - cache := r.cache.Get(ip.String()) + cache, _ := r.lruCache.Get(ip.String()) if cache == nil { return "", false } @@ -294,17 +307,17 @@ type Config struct { func New(config Config) *Resolver { defaultResolver := &Resolver{ - main: transform(config.Default, nil), - cache: cache.New(time.Second * 60), + main: transform(config.Default, nil), + lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), } r := &Resolver{ - ipv6: config.IPv6, - main: transform(config.Main, defaultResolver), - cache: cache.New(time.Second * 60), - mapping: config.EnhancedMode == MAPPING, - fakeip: config.EnhancedMode == FAKEIP, - pool: config.Pool, + ipv6: config.IPv6, + main: transform(config.Main, defaultResolver), + lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), + mapping: config.EnhancedMode == MAPPING, + fakeip: config.EnhancedMode == FAKEIP, + pool: config.Pool, } if len(config.Fallback) != 0 { diff --git a/dns/util.go b/dns/util.go index b1541091..b6adf8d6 100644 --- a/dns/util.go +++ b/dns/util.go @@ -79,21 +79,21 @@ func (e EnhancedMode) String() string { } } -func putMsgToCache(c *cache.Cache, key string, msg *D.Msg) { - var ttl time.Duration +func putMsgToCache(c *cache.LruCache, key string, msg *D.Msg) { + var ttl uint32 switch { case len(msg.Answer) != 0: - ttl = time.Duration(msg.Answer[0].Header().Ttl) * time.Second + ttl = msg.Answer[0].Header().Ttl case len(msg.Ns) != 0: - ttl = time.Duration(msg.Ns[0].Header().Ttl) * time.Second + ttl = msg.Ns[0].Header().Ttl case len(msg.Extra) != 0: - ttl = time.Duration(msg.Extra[0].Header().Ttl) * time.Second + ttl = msg.Extra[0].Header().Ttl default: log.Debugln("[DNS] response msg error: %#v", msg) return } - c.Put(key, msg.Copy(), ttl) + c.SetWithExpire(key, msg.Copy(), time.Now().Add(time.Second*time.Duration(ttl))) } func setMsgTTL(msg *D.Msg, ttl uint32) {