fix: loadbalance panic

This commit is contained in:
wwqgtxx 2023-03-01 14:04:42 +08:00
parent 685fd49dd7
commit e7613e4f8b
2 changed files with 20 additions and 17 deletions

View file

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
"github.com/Dreamacro/clash/adapter/outbound" "github.com/Dreamacro/clash/adapter/outbound"
@ -20,7 +21,7 @@ import (
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
) )
type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata, touch bool) C.Proxy
type LoadBalance struct { type LoadBalance struct {
*GroupBase *GroupBase
@ -127,21 +128,23 @@ func (lb *LoadBalance) SupportUDP() bool {
} }
func strategyRoundRobin() strategyFn { func strategyRoundRobin() strategyFn {
flag := true
idx := 0 idx := 0
return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy { idxMutex := sync.Mutex{}
return func(proxies []C.Proxy, metadata *C.Metadata, touch bool) C.Proxy {
id := idx // value could be wrong due to no lock, but don't care if we don't touch
if touch {
idxMutex.Lock()
defer idxMutex.Unlock()
id = idx // get again by lock's protect, so it must be right
defer func() {
idx = id
}()
}
length := len(proxies) length := len(proxies)
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
flag = !flag id = (id + 1) % length
if flag { proxy := proxies[id]
idx = (idx - 1) % length
} else {
idx = (idx + 2) % length
}
if idx < 0 {
idx = idx + length
}
proxy := proxies[idx]
if proxy.Alive() { if proxy.Alive() {
return proxy return proxy
} }
@ -153,7 +156,7 @@ func strategyRoundRobin() strategyFn {
func strategyConsistentHashing() strategyFn { func strategyConsistentHashing() strategyFn {
maxRetry := 5 maxRetry := 5
return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy { return func(proxies []C.Proxy, metadata *C.Metadata, touch bool) C.Proxy {
key := uint64(murmur3.Sum32([]byte(getKey(metadata)))) key := uint64(murmur3.Sum32([]byte(getKey(metadata))))
buckets := int32(len(proxies)) buckets := int32(len(proxies))
for i := 0; i < maxRetry; i, key = i+1, key+1 { for i := 0; i < maxRetry; i, key = i+1, key+1 {
@ -181,7 +184,7 @@ func strategyStickySessions() strategyFn {
lruCache := cache.New[uint64, int]( lruCache := cache.New[uint64, int](
cache.WithAge[uint64, int](int64(ttl.Seconds())), cache.WithAge[uint64, int](int64(ttl.Seconds())),
cache.WithSize[uint64, int](1000)) cache.WithSize[uint64, int](1000))
return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy { return func(proxies []C.Proxy, metadata *C.Metadata, touch bool) C.Proxy {
key := uint64(murmur3.Sum32([]byte(getKeyWithSrcAndDst(metadata)))) key := uint64(murmur3.Sum32([]byte(getKeyWithSrcAndDst(metadata))))
length := len(proxies) length := len(proxies)
idx, has := lruCache.Get(key) idx, has := lruCache.Get(key)
@ -213,7 +216,7 @@ func strategyStickySessions() strategyFn {
// Unwrap implements C.ProxyAdapter // Unwrap implements C.ProxyAdapter
func (lb *LoadBalance) Unwrap(metadata *C.Metadata, touch bool) C.Proxy { func (lb *LoadBalance) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
proxies := lb.GetProxies(touch) proxies := lb.GetProxies(touch)
return lb.strategyFn(proxies, metadata) return lb.strategyFn(proxies, metadata, true)
} }
// MarshalJSON implements C.ProxyAdapter // MarshalJSON implements C.ProxyAdapter

View file

@ -176,7 +176,7 @@ func (r *Relay) proxies(metadata *C.Metadata, touch bool) ([]C.Proxy, []C.Proxy)
} }
func (r *Relay) Addr() string { func (r *Relay) Addr() string {
proxies, _ := r.proxies(nil, true) proxies, _ := r.proxies(nil, false)
return proxies[len(proxies)-1].Addr() return proxies[len(proxies)-1].Addr()
} }