From 9a62b1081d79cb00562ce25ebe73cb1c8dbca407 Mon Sep 17 00:00:00 2001 From: uchuhimo Date: Wed, 28 Oct 2020 22:35:02 +0800 Subject: [PATCH] Feature: support round-robin strategy for load-balance group (#1044) --- adapters/outboundgroup/loadbalance.go | 87 +++++++++++++++++++++------ adapters/outboundgroup/parser.go | 3 +- 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/adapters/outboundgroup/loadbalance.go b/adapters/outboundgroup/loadbalance.go index 8a7fe974..de634967 100644 --- a/adapters/outboundgroup/loadbalance.go +++ b/adapters/outboundgroup/loadbalance.go @@ -3,6 +3,8 @@ package outboundgroup import ( "context" "encoding/json" + "errors" + "fmt" "net" "github.com/Dreamacro/clash/adapters/outbound" @@ -14,11 +16,24 @@ import ( "golang.org/x/net/publicsuffix" ) +type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy + type LoadBalance struct { *outbound.Base - single *singledo.Single - maxRetry int - providers []provider.ProxyProvider + single *singledo.Single + providers []provider.ProxyProvider + strategyFn strategyFn +} + +var errStrategy = errors.New("unsupported strategy") + +func parseStrategy(config map[string]interface{}) string { + if elm, ok := config["strategy"]; ok { + if strategy, ok := elm.(string); ok { + return strategy + } + } + return "consistent-hashing" } func getKey(metadata *C.Metadata) string { @@ -81,19 +96,42 @@ func (lb *LoadBalance) SupportUDP() bool { return true } -func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy { - key := uint64(murmur3.Sum32([]byte(getKey(metadata)))) - proxies := lb.proxies() - buckets := int32(len(proxies)) - for i := 0; i < lb.maxRetry; i, key = i+1, key+1 { - idx := jumpHash(key, buckets) - proxy := proxies[idx] - if proxy.Alive() { - return proxy +func strategyRoundRobin() strategyFn { + idx := 0 + return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy { + length := len(proxies) + for i := 0; i < length; i++ { + idx = (idx + 1) % length + proxy := proxies[idx] + if proxy.Alive() { + return proxy + } } - } - return proxies[0] + return proxies[0] + } +} + +func strategyConsistentHashing() strategyFn { + maxRetry := 5 + return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy { + key := uint64(murmur3.Sum32([]byte(getKey(metadata)))) + buckets := int32(len(proxies)) + for i := 0; i < maxRetry; i, key = i+1, key+1 { + idx := jumpHash(key, buckets) + proxy := proxies[idx] + if proxy.Alive() { + return proxy + } + } + + return proxies[0] + } +} + +func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy { + proxies := lb.proxies() + return lb.strategyFn(proxies, metadata) } func (lb *LoadBalance) proxies() []C.Proxy { @@ -115,11 +153,20 @@ func (lb *LoadBalance) MarshalJSON() ([]byte, error) { }) } -func NewLoadBalance(name string, providers []provider.ProxyProvider) *LoadBalance { - return &LoadBalance{ - Base: outbound.NewBase(name, "", C.LoadBalance, false), - single: singledo.NewSingle(defaultGetProxiesDuration), - maxRetry: 3, - providers: providers, +func NewLoadBalance(name string, providers []provider.ProxyProvider, strategy string) (lb *LoadBalance, err error) { + var strategyFn strategyFn + switch strategy { + case "consistent-hashing": + strategyFn = strategyConsistentHashing() + case "round-robin": + strategyFn = strategyRoundRobin() + default: + return nil, fmt.Errorf("%w: %s", errStrategy, strategy) } + return &LoadBalance{ + Base: outbound.NewBase(name, "", C.LoadBalance, false), + single: singledo.NewSingle(defaultGetProxiesDuration), + providers: providers, + strategyFn: strategyFn, + }, nil } diff --git a/adapters/outboundgroup/parser.go b/adapters/outboundgroup/parser.go index f0fa54c9..9a4a1681 100644 --- a/adapters/outboundgroup/parser.go +++ b/adapters/outboundgroup/parser.go @@ -111,7 +111,8 @@ func ParseProxyGroup(config map[string]interface{}, proxyMap map[string]C.Proxy, case "fallback": group = NewFallback(groupName, providers) case "load-balance": - group = NewLoadBalance(groupName, providers) + strategy := parseStrategy(config) + return NewLoadBalance(groupName, providers, strategy) case "relay": group = NewRelay(groupName, providers) default: