diff --git a/adapter/outboundgroup/loadbalance.go b/adapter/outboundgroup/loadbalance.go index 198bf60d..93542c9d 100644 --- a/adapter/outboundgroup/loadbalance.go +++ b/adapter/outboundgroup/loadbalance.go @@ -5,7 +5,9 @@ import ( "encoding/json" "errors" "fmt" + "math/rand" "net" + "time" "github.com/Dreamacro/clash/adapter/outbound" "github.com/Dreamacro/clash/common/murmur3" @@ -137,6 +139,60 @@ func strategyConsistentHashing() strategyFn { } } +func strategyStickySessions() strategyFn { + timeout := int64(600) + type Session struct { + idx int + time time.Time + } + Sessions := make(map[string]map[string]Session) + go func() { + for true { + time.Sleep(time.Second * 60) + now := time.Now().Unix() + for _, subMap := range Sessions { + for dest, session := range subMap { + if now-session.time.Unix() > timeout { + delete(subMap, dest) + } + } + } + } + }() + return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy { + src := metadata.SrcIP.String() + dest := getKey(metadata) + now := time.Now() + length := len(proxies) + if Sessions[src] == nil { + Sessions[src] = make(map[string]Session) + } + session, ok := Sessions[src][dest] + if !ok || now.Unix()-session.time.Unix() > timeout { + session.idx = rand.Intn(length) + } + session.time = now + + var i int + var res C.Proxy + for i := 0; i < length; i++ { + idx := (session.idx + i) % length + proxy := proxies[idx] + if proxy.Alive() { + session.idx = idx + res = proxy + break + } + } + if i == length { + session.idx = 0 + res = proxies[0] + } + Sessions[src][dest] = session + return res + } +} + // Unwrap implements C.ProxyAdapter func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy { proxies := lb.GetProxies(true) @@ -145,7 +201,7 @@ func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy { // MarshalJSON implements C.ProxyAdapter func (lb *LoadBalance) MarshalJSON() ([]byte, error) { - all := []string{} + var all []string for _, proxy := range lb.GetProxies(false) { all = append(all, proxy.Name()) } @@ -162,6 +218,8 @@ func NewLoadBalance(option *GroupCommonOption, providers []provider.ProxyProvide strategyFn = strategyConsistentHashing() case "round-robin": strategyFn = strategyRoundRobin() + case "sticky-sessions": + strategyFn = strategyStickySessions() default: return nil, fmt.Errorf("%w: %s", errStrategy, strategy) }