Fix: tunnel manager & tracker race condition (#1048)
This commit is contained in:
parent
b98e9ea202
commit
87e4d94290
7 changed files with 82 additions and 68 deletions
|
@ -6,11 +6,12 @@ import (
|
|||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Dreamacro/clash/common/queue"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type Base struct {
|
||||
|
@ -95,11 +96,11 @@ func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn {
|
|||
type Proxy struct {
|
||||
C.ProxyAdapter
|
||||
history *queue.Queue
|
||||
alive uint32
|
||||
alive *atomic.Bool
|
||||
}
|
||||
|
||||
func (p *Proxy) Alive() bool {
|
||||
return atomic.LoadUint32(&p.alive) > 0
|
||||
return p.alive.Load()
|
||||
}
|
||||
|
||||
func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
|
||||
|
@ -111,7 +112,7 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
|
|||
func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
|
||||
conn, err := p.ProxyAdapter.DialContext(ctx, metadata)
|
||||
if err != nil {
|
||||
atomic.StoreUint32(&p.alive, 0)
|
||||
p.alive.Store(false)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
@ -128,7 +129,7 @@ func (p *Proxy) DelayHistory() []C.DelayHistory {
|
|||
// LastDelay return last history record. if proxy is not alive, return the max value of uint16.
|
||||
func (p *Proxy) LastDelay() (delay uint16) {
|
||||
var max uint16 = 0xffff
|
||||
if atomic.LoadUint32(&p.alive) == 0 {
|
||||
if !p.alive.Load() {
|
||||
return max
|
||||
}
|
||||
|
||||
|
@ -159,11 +160,7 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
|
|||
// URLTest get the delay for the specified URL
|
||||
func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
atomic.StoreUint32(&p.alive, 1)
|
||||
} else {
|
||||
atomic.StoreUint32(&p.alive, 0)
|
||||
}
|
||||
p.alive.Store(err == nil)
|
||||
record := C.DelayHistory{Time: time.Now()}
|
||||
if err == nil {
|
||||
record.Delay = t
|
||||
|
@ -219,5 +216,5 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
|
|||
}
|
||||
|
||||
func NewProxy(adapter C.ProxyAdapter) *Proxy {
|
||||
return &Proxy{adapter, queue.New(10), 1}
|
||||
return &Proxy{adapter, queue.New(10), atomic.NewBool(true)}
|
||||
}
|
||||
|
|
|
@ -2,11 +2,11 @@ package observable
|
|||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
func iterator(item []interface{}) chan interface{} {
|
||||
|
@ -33,25 +33,25 @@ func TestObservable(t *testing.T) {
|
|||
assert.Equal(t, count, 5)
|
||||
}
|
||||
|
||||
func TestObservable_MutilSubscribe(t *testing.T) {
|
||||
func TestObservable_MultiSubscribe(t *testing.T) {
|
||||
iter := iterator([]interface{}{1, 2, 3, 4, 5})
|
||||
src := NewObservable(iter)
|
||||
ch1, _ := src.Subscribe()
|
||||
ch2, _ := src.Subscribe()
|
||||
var count int32
|
||||
var count = atomic.NewInt32(0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
waitCh := func(ch <-chan interface{}) {
|
||||
for range ch {
|
||||
atomic.AddInt32(&count, 1)
|
||||
count.Inc()
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
go waitCh(ch1)
|
||||
go waitCh(ch2)
|
||||
wg.Wait()
|
||||
assert.Equal(t, int32(10), count)
|
||||
assert.Equal(t, int32(10), count.Load())
|
||||
}
|
||||
|
||||
func TestObservable_UnSubscribe(t *testing.T) {
|
||||
|
|
|
@ -2,17 +2,17 @@ package singledo
|
|||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
func TestBasic(t *testing.T) {
|
||||
single := NewSingle(time.Millisecond * 30)
|
||||
foo := 0
|
||||
var shardCount int32 = 0
|
||||
var shardCount = atomic.NewInt32(0)
|
||||
call := func() (interface{}, error) {
|
||||
foo++
|
||||
time.Sleep(time.Millisecond * 5)
|
||||
|
@ -26,7 +26,7 @@ func TestBasic(t *testing.T) {
|
|||
go func() {
|
||||
_, _, shard := single.Do(call)
|
||||
if shard {
|
||||
atomic.AddInt32(&shardCount, 1)
|
||||
shardCount.Inc()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
@ -34,7 +34,7 @@ func TestBasic(t *testing.T) {
|
|||
|
||||
wg.Wait()
|
||||
assert.Equal(t, 1, foo)
|
||||
assert.Equal(t, int32(4), shardCount)
|
||||
assert.Equal(t, int32(4), shardCount.Load())
|
||||
}
|
||||
|
||||
func TestTimer(t *testing.T) {
|
||||
|
|
1
go.mod
1
go.mod
|
@ -13,6 +13,7 @@ require (
|
|||
github.com/oschwald/geoip2-golang v1.4.0
|
||||
github.com/sirupsen/logrus v1.7.0
|
||||
github.com/stretchr/testify v1.6.1
|
||||
go.uber.org/atomic v1.7.0
|
||||
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897
|
||||
golang.org/x/net v0.0.0-20201020065357-d65d470038a5
|
||||
golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520
|
||||
|
|
3
go.sum
3
go.sum
|
@ -25,9 +25,12 @@ github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM
|
|||
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
|
|
|
@ -2,26 +2,34 @@ package tunnel
|
|||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
var DefaultManager *Manager
|
||||
|
||||
func init() {
|
||||
DefaultManager = &Manager{}
|
||||
DefaultManager = &Manager{
|
||||
uploadTemp: atomic.NewInt64(0),
|
||||
downloadTemp: atomic.NewInt64(0),
|
||||
uploadBlip: atomic.NewInt64(0),
|
||||
downloadBlip: atomic.NewInt64(0),
|
||||
uploadTotal: atomic.NewInt64(0),
|
||||
downloadTotal: atomic.NewInt64(0),
|
||||
}
|
||||
|
||||
go DefaultManager.handle()
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
connections sync.Map
|
||||
uploadTemp int64
|
||||
downloadTemp int64
|
||||
uploadBlip int64
|
||||
downloadBlip int64
|
||||
uploadTotal int64
|
||||
downloadTotal int64
|
||||
uploadTemp *atomic.Int64
|
||||
downloadTemp *atomic.Int64
|
||||
uploadBlip *atomic.Int64
|
||||
downloadBlip *atomic.Int64
|
||||
uploadTotal *atomic.Int64
|
||||
downloadTotal *atomic.Int64
|
||||
}
|
||||
|
||||
func (m *Manager) Join(c tracker) {
|
||||
|
@ -33,17 +41,17 @@ func (m *Manager) Leave(c tracker) {
|
|||
}
|
||||
|
||||
func (m *Manager) PushUploaded(size int64) {
|
||||
atomic.AddInt64(&m.uploadTemp, size)
|
||||
atomic.AddInt64(&m.uploadTotal, size)
|
||||
m.uploadTemp.Add(size)
|
||||
m.uploadTotal.Add(size)
|
||||
}
|
||||
|
||||
func (m *Manager) PushDownloaded(size int64) {
|
||||
atomic.AddInt64(&m.downloadTemp, size)
|
||||
atomic.AddInt64(&m.downloadTotal, size)
|
||||
m.downloadTemp.Add(size)
|
||||
m.downloadTotal.Add(size)
|
||||
}
|
||||
|
||||
func (m *Manager) Now() (up int64, down int64) {
|
||||
return atomic.LoadInt64(&m.uploadBlip), atomic.LoadInt64(&m.downloadBlip)
|
||||
return m.uploadBlip.Load(), m.downloadBlip.Load()
|
||||
}
|
||||
|
||||
func (m *Manager) Snapshot() *Snapshot {
|
||||
|
@ -54,29 +62,29 @@ func (m *Manager) Snapshot() *Snapshot {
|
|||
})
|
||||
|
||||
return &Snapshot{
|
||||
UploadTotal: atomic.LoadInt64(&m.uploadTotal),
|
||||
DownloadTotal: atomic.LoadInt64(&m.downloadTotal),
|
||||
UploadTotal: m.uploadTotal.Load(),
|
||||
DownloadTotal: m.downloadTotal.Load(),
|
||||
Connections: connections,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) ResetStatistic() {
|
||||
m.uploadTemp = 0
|
||||
m.uploadBlip = 0
|
||||
m.uploadTotal = 0
|
||||
m.downloadTemp = 0
|
||||
m.downloadBlip = 0
|
||||
m.downloadTotal = 0
|
||||
m.uploadTemp.Store(0)
|
||||
m.uploadBlip.Store(0)
|
||||
m.uploadTotal.Store(0)
|
||||
m.downloadTemp.Store(0)
|
||||
m.downloadBlip.Store(0)
|
||||
m.downloadTotal.Store(0)
|
||||
}
|
||||
|
||||
func (m *Manager) handle() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
|
||||
for range ticker.C {
|
||||
atomic.StoreInt64(&m.uploadBlip, atomic.LoadInt64(&m.uploadTemp))
|
||||
atomic.StoreInt64(&m.uploadTemp, 0)
|
||||
atomic.StoreInt64(&m.downloadBlip, atomic.LoadInt64(&m.downloadTemp))
|
||||
atomic.StoreInt64(&m.downloadTemp, 0)
|
||||
m.uploadBlip.Store(m.uploadTemp.Load())
|
||||
m.uploadTemp.Store(0)
|
||||
m.downloadBlip.Store(m.downloadTemp.Load())
|
||||
m.downloadTemp.Store(0)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,11 +2,12 @@ package tunnel
|
|||
|
||||
import (
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type tracker interface {
|
||||
|
@ -15,14 +16,14 @@ type tracker interface {
|
|||
}
|
||||
|
||||
type trackerInfo struct {
|
||||
UUID uuid.UUID `json:"id"`
|
||||
Metadata *C.Metadata `json:"metadata"`
|
||||
UploadTotal int64 `json:"upload"`
|
||||
DownloadTotal int64 `json:"download"`
|
||||
Start time.Time `json:"start"`
|
||||
Chain C.Chain `json:"chains"`
|
||||
Rule string `json:"rule"`
|
||||
RulePayload string `json:"rulePayload"`
|
||||
UUID uuid.UUID `json:"id"`
|
||||
Metadata *C.Metadata `json:"metadata"`
|
||||
UploadTotal *atomic.Int64 `json:"upload"`
|
||||
DownloadTotal *atomic.Int64 `json:"download"`
|
||||
Start time.Time `json:"start"`
|
||||
Chain C.Chain `json:"chains"`
|
||||
Rule string `json:"rule"`
|
||||
RulePayload string `json:"rulePayload"`
|
||||
}
|
||||
|
||||
type tcpTracker struct {
|
||||
|
@ -39,7 +40,7 @@ func (tt *tcpTracker) Read(b []byte) (int, error) {
|
|||
n, err := tt.Conn.Read(b)
|
||||
download := int64(n)
|
||||
tt.manager.PushDownloaded(download)
|
||||
atomic.AddInt64(&tt.DownloadTotal, download)
|
||||
tt.DownloadTotal.Add(download)
|
||||
return n, err
|
||||
}
|
||||
|
||||
|
@ -47,7 +48,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
|
|||
n, err := tt.Conn.Write(b)
|
||||
upload := int64(n)
|
||||
tt.manager.PushUploaded(upload)
|
||||
atomic.AddInt64(&tt.UploadTotal, upload)
|
||||
tt.UploadTotal.Add(upload)
|
||||
return n, err
|
||||
}
|
||||
|
||||
|
@ -63,11 +64,13 @@ func newTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R
|
|||
Conn: conn,
|
||||
manager: manager,
|
||||
trackerInfo: &trackerInfo{
|
||||
UUID: uuid,
|
||||
Start: time.Now(),
|
||||
Metadata: metadata,
|
||||
Chain: conn.Chains(),
|
||||
Rule: "",
|
||||
UUID: uuid,
|
||||
Start: time.Now(),
|
||||
Metadata: metadata,
|
||||
Chain: conn.Chains(),
|
||||
Rule: "",
|
||||
UploadTotal: atomic.NewInt64(0),
|
||||
DownloadTotal: atomic.NewInt64(0),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -94,7 +97,7 @@ func (ut *udpTracker) ReadFrom(b []byte) (int, net.Addr, error) {
|
|||
n, addr, err := ut.PacketConn.ReadFrom(b)
|
||||
download := int64(n)
|
||||
ut.manager.PushDownloaded(download)
|
||||
atomic.AddInt64(&ut.DownloadTotal, download)
|
||||
ut.DownloadTotal.Add(download)
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
|
@ -102,7 +105,7 @@ func (ut *udpTracker) WriteTo(b []byte, addr net.Addr) (int, error) {
|
|||
n, err := ut.PacketConn.WriteTo(b, addr)
|
||||
upload := int64(n)
|
||||
ut.manager.PushUploaded(upload)
|
||||
atomic.AddInt64(&ut.UploadTotal, upload)
|
||||
ut.UploadTotal.Add(upload)
|
||||
return n, err
|
||||
}
|
||||
|
||||
|
@ -118,11 +121,13 @@ func newUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, ru
|
|||
PacketConn: conn,
|
||||
manager: manager,
|
||||
trackerInfo: &trackerInfo{
|
||||
UUID: uuid,
|
||||
Start: time.Now(),
|
||||
Metadata: metadata,
|
||||
Chain: conn.Chains(),
|
||||
Rule: "",
|
||||
UUID: uuid,
|
||||
Start: time.Now(),
|
||||
Metadata: metadata,
|
||||
Chain: conn.Chains(),
|
||||
Rule: "",
|
||||
UploadTotal: atomic.NewInt64(0),
|
||||
DownloadTotal: atomic.NewInt64(0),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue