diff --git a/common/observable/observable_test.go b/common/observable/observable_test.go index d965fa3b..67e29341 100644 --- a/common/observable/observable_test.go +++ b/common/observable/observable_test.go @@ -1,8 +1,8 @@ package observable import ( - "runtime" "sync" + "sync/atomic" "testing" "time" @@ -38,20 +38,20 @@ func TestObservable_MutilSubscribe(t *testing.T) { src := NewObservable(iter) ch1, _ := src.Subscribe() ch2, _ := src.Subscribe() - count := 0 + var count int32 var wg sync.WaitGroup wg.Add(2) waitCh := func(ch <-chan interface{}) { for range ch { - count++ + atomic.AddInt32(&count, 1) } wg.Done() } go waitCh(ch1) go waitCh(ch2) wg.Wait() - assert.Equal(t, count, 10) + assert.Equal(t, int32(10), count) } func TestObservable_UnSubscribe(t *testing.T) { @@ -82,9 +82,6 @@ func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) { } func TestObservable_SubscribeGoroutineLeak(t *testing.T) { - // waiting for other goroutine recycle - time.Sleep(120 * time.Millisecond) - init := runtime.NumGoroutine() iter := iterator([]interface{}{1, 2, 3, 4, 5}) src := NewObservable(iter) max := 100 @@ -107,6 +104,12 @@ func TestObservable_SubscribeGoroutineLeak(t *testing.T) { go waitCh(ch) } wg.Wait() - now := runtime.NumGoroutine() - assert.Equal(t, init, now) + + for _, sub := range list { + _, more := <-sub + assert.False(t, more) + } + + _, more := <-list[0] + assert.False(t, more) } diff --git a/common/singledo/singledo.go b/common/singledo/singledo.go index 58a95434..0db56c1f 100644 --- a/common/singledo/singledo.go +++ b/common/singledo/singledo.go @@ -44,9 +44,12 @@ func (s *Single) Do(fn func() (interface{}, error)) (v interface{}, err error, s s.mux.Unlock() call.val, call.err = fn() call.wg.Done() + + s.mux.Lock() s.call = nil s.result = &Result{call.val, call.err} s.last = now + s.mux.Unlock() return call.val, call.err, false } diff --git a/common/singledo/singledo_test.go b/common/singledo/singledo_test.go index 637557c4..c9c58e58 100644 --- a/common/singledo/singledo_test.go +++ b/common/singledo/singledo_test.go @@ -2,6 +2,7 @@ package singledo import ( "sync" + "sync/atomic" "testing" "time" @@ -11,7 +12,7 @@ import ( func TestBasic(t *testing.T) { single := NewSingle(time.Millisecond * 30) foo := 0 - shardCount := 0 + var shardCount int32 = 0 call := func() (interface{}, error) { foo++ time.Sleep(time.Millisecond * 5) @@ -25,7 +26,7 @@ func TestBasic(t *testing.T) { go func() { _, _, shard := single.Do(call) if shard { - shardCount++ + atomic.AddInt32(&shardCount, 1) } wg.Done() }() @@ -33,7 +34,7 @@ func TestBasic(t *testing.T) { wg.Wait() assert.Equal(t, 1, foo) - assert.Equal(t, 4, shardCount) + assert.Equal(t, int32(4), shardCount) } func TestTimer(t *testing.T) {