Fix: refactor observable to solve a small-probability crash

This commit is contained in:
gVisor bot 2019-10-30 15:43:55 +08:00
parent 2a75f9aa1e
commit b9883e8fad
2 changed files with 31 additions and 45 deletions

View file

@ -7,60 +7,58 @@ import (
type Observable struct { type Observable struct {
iterable Iterable iterable Iterable
listener *sync.Map listener map[Subscription]*Subscriber
mux sync.Mutex
done bool done bool
doneLock sync.RWMutex
} }
func (o *Observable) process() { func (o *Observable) process() {
for item := range o.iterable { for item := range o.iterable {
o.listener.Range(func(key, value interface{}) bool { o.mux.Lock()
elm := value.(*Subscriber) for _, sub := range o.listener {
elm.Emit(item) sub.Emit(item)
return true }
}) o.mux.Unlock()
} }
o.close() o.close()
} }
func (o *Observable) close() { func (o *Observable) close() {
o.doneLock.Lock() o.mux.Lock()
o.done = true defer o.mux.Unlock()
o.doneLock.Unlock()
o.listener.Range(func(key, value interface{}) bool { o.done = true
elm := value.(*Subscriber) for _, sub := range o.listener {
elm.Close() sub.Close()
return true }
})
} }
func (o *Observable) Subscribe() (Subscription, error) { func (o *Observable) Subscribe() (Subscription, error) {
o.doneLock.RLock() o.mux.Lock()
done := o.done defer o.mux.Unlock()
o.doneLock.RUnlock() if o.done {
if done == true {
return nil, errors.New("Observable is closed") return nil, errors.New("Observable is closed")
} }
subscriber := newSubscriber() subscriber := newSubscriber()
o.listener.Store(subscriber.Out(), subscriber) o.listener[subscriber.Out()] = subscriber
return subscriber.Out(), nil return subscriber.Out(), nil
} }
func (o *Observable) UnSubscribe(sub Subscription) { func (o *Observable) UnSubscribe(sub Subscription) {
elm, exist := o.listener.Load(sub) o.mux.Lock()
defer o.mux.Unlock()
subscriber, exist := o.listener[sub]
if !exist { if !exist {
return return
} }
subscriber := elm.(*Subscriber) delete(o.listener, sub)
o.listener.Delete(subscriber.Out())
subscriber.Close() subscriber.Close()
} }
func NewObservable(any Iterable) *Observable { func NewObservable(any Iterable) *Observable {
observable := &Observable{ observable := &Observable{
iterable: any, iterable: any,
listener: &sync.Map{}, listener: map[Subscription]*Subscriber{},
} }
go observable.process() go observable.process()
return observable return observable

View file

@ -5,6 +5,8 @@ import (
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func iterator(item []interface{}) chan interface{} { func iterator(item []interface{}) chan interface{} {
@ -23,16 +25,12 @@ func TestObservable(t *testing.T) {
iter := iterator([]interface{}{1, 2, 3, 4, 5}) iter := iterator([]interface{}{1, 2, 3, 4, 5})
src := NewObservable(iter) src := NewObservable(iter)
data, err := src.Subscribe() data, err := src.Subscribe()
if err != nil { assert.Nil(t, err)
t.Error(err)
}
count := 0 count := 0
for range data { for range data {
count++ count++
} }
if count != 5 { assert.Equal(t, count, 5)
t.Error("Revc number error")
}
} }
func TestObservable_MutilSubscribe(t *testing.T) { func TestObservable_MutilSubscribe(t *testing.T) {
@ -53,23 +51,17 @@ func TestObservable_MutilSubscribe(t *testing.T) {
go waitCh(ch1) go waitCh(ch1)
go waitCh(ch2) go waitCh(ch2)
wg.Wait() wg.Wait()
if count != 10 { assert.Equal(t, count, 10)
t.Error("Revc number error")
}
} }
func TestObservable_UnSubscribe(t *testing.T) { func TestObservable_UnSubscribe(t *testing.T) {
iter := iterator([]interface{}{1, 2, 3, 4, 5}) iter := iterator([]interface{}{1, 2, 3, 4, 5})
src := NewObservable(iter) src := NewObservable(iter)
data, err := src.Subscribe() data, err := src.Subscribe()
if err != nil { assert.Nil(t, err)
t.Error(err)
}
src.UnSubscribe(data) src.UnSubscribe(data)
_, open := <-data _, open := <-data
if open { assert.False(t, open)
t.Error("Revc number error")
}
} }
func TestObservable_SubscribeClosedSource(t *testing.T) { func TestObservable_SubscribeClosedSource(t *testing.T) {
@ -79,9 +71,7 @@ func TestObservable_SubscribeClosedSource(t *testing.T) {
<-data <-data
_, closed := src.Subscribe() _, closed := src.Subscribe()
if closed == nil { assert.NotNil(t, closed)
t.Error("Observable should be closed")
}
} }
func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) { func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) {
@ -118,7 +108,5 @@ func TestObservable_SubscribeGoroutineLeak(t *testing.T) {
} }
wg.Wait() wg.Wait()
now := runtime.NumGoroutine() now := runtime.NumGoroutine()
if init != now { assert.Equal(t, init, now)
t.Errorf("Goroutine Leak: init %d now %d", init, now)
}
} }