diff --git a/adapters/outboundgroup/selector.go b/adapters/outboundgroup/selector.go index 88ac815d..1215dadb 100644 --- a/adapters/outboundgroup/selector.go +++ b/adapters/outboundgroup/selector.go @@ -59,6 +59,7 @@ func (s *Selector) Set(name string) error { for _, proxy := range getProvidersProxies(s.providers) { if proxy.Name() == name { s.selected = name + s.single.Reset() return nil } } diff --git a/common/singledo/singledo.go b/common/singledo/singledo.go index 84978115..58a95434 100644 --- a/common/singledo/singledo.go +++ b/common/singledo/singledo.go @@ -50,6 +50,10 @@ func (s *Single) Do(fn func() (interface{}, error)) (v interface{}, err error, s return call.val, call.err, false } +func (s *Single) Reset() { + s.last = time.Time{} +} + func NewSingle(wait time.Duration) *Single { return &Single{wait: wait} } diff --git a/common/singledo/singledo_test.go b/common/singledo/singledo_test.go index c1e48ca8..637557c4 100644 --- a/common/singledo/singledo_test.go +++ b/common/singledo/singledo_test.go @@ -19,7 +19,7 @@ func TestBasic(t *testing.T) { } var wg sync.WaitGroup - const n = 10 + const n = 5 wg.Add(n) for i := 0; i < n; i++ { go func() { @@ -33,7 +33,7 @@ func TestBasic(t *testing.T) { wg.Wait() assert.Equal(t, 1, foo) - assert.Equal(t, 9, shardCount) + assert.Equal(t, 4, shardCount) } func TestTimer(t *testing.T) { @@ -51,3 +51,18 @@ func TestTimer(t *testing.T) { assert.Equal(t, 1, foo) assert.True(t, shard) } + +func TestReset(t *testing.T) { + single := NewSingle(time.Millisecond * 30) + foo := 0 + call := func() (interface{}, error) { + foo++ + return nil, nil + } + + single.Do(call) + single.Reset() + single.Do(call) + + assert.Equal(t, 2, foo) +}