From 545a79d406515809690f6b111388188a3e46501b Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 6 Mar 2023 23:23:05 +0800 Subject: [PATCH] chore: cleanup dialer's code --- component/dialer/dialer.go | 98 +++++++++++++------------------------- hub/executor/executor.go | 6 +-- hub/route/configs.go | 2 +- 3 files changed, 37 insertions(+), 69 deletions(-) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 478e9f19..f53435fb 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -13,6 +13,8 @@ import ( "github.com/Dreamacro/clash/component/resolver" ) +type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) + var ( dialMux sync.Mutex actualSingleStackDialContext = serialSingleStackDialContext @@ -51,11 +53,16 @@ func DialContext(ctx context.Context, network, address string, options ...Option network = fmt.Sprintf("%s%d", network, opt.network) } + ips, port, err := parseAddr(ctx, network, address, opt.resolver) + if err != nil { + return nil, err + } + switch network { case "tcp4", "tcp6", "udp4", "udp6": - return actualSingleStackDialContext(ctx, network, address, opt) + return actualSingleStackDialContext(ctx, network, ips, port, opt) case "tcp", "udp": - return actualDualStackDialContext(ctx, network, address, opt) + return actualDualStackDialContext(ctx, network, ips, port, opt) default: return nil, ErrorInvalidedNetworkStack } @@ -82,8 +89,9 @@ func ListenPacket(ctx context.Context, network, address string, options ...Optio return lc.ListenPacket(ctx, network, address) } -func SetDial(concurrent bool) { +func SetTcpConcurrent(concurrent bool) { dialMux.Lock() + defer dialMux.Unlock() tcpConcurrent = concurrent if concurrent { actualSingleStackDialContext = concurrentSingleStackDialContext @@ -92,11 +100,11 @@ func SetDial(concurrent bool) { actualSingleStackDialContext = serialSingleStackDialContext actualDualStackDialContext = serialDualStackDialContext } - - dialMux.Unlock() } -func GetDial() bool { +func GetTcpConcurrent() bool { + dialMux.Lock() + defer dialMux.Unlock() return tcpConcurrent } @@ -118,71 +126,35 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po return dialer.DialContext(ctx, network, address) } -func serialSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { - ips, port, err := parseAddr(ctx, network, address, opt.resolver) - if err != nil { - return nil, err - } +func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) } -func serialDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { - ips, port, err := parseAddr(ctx, network, address, opt.resolver) - if err != nil { - return nil, err - } - ipv4s, ipv6s := sortationAddr(ips) - return dualStackDialContext( - ctx, - func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ipv4s, port, opt) }, - func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ipv6s, port, opt) }, - opt.prefer) +func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { + return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt) } -func concurrentSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { - ips, port, err := parseAddr(ctx, network, address, opt.resolver) - if err != nil { - return nil, err - } - - if conn, err := parallelDialContext(ctx, network, ips, port, opt); err != nil { - return nil, err - } else { - return conn, nil - } +func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { + return parallelDialContext(ctx, network, ips, port, opt) } -func concurrentDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { - ips, port, err := parseAddr(ctx, network, address, opt.resolver) - if err != nil { - return nil, err - } +func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { if opt.prefer != 4 && opt.prefer != 6 { return parallelDialContext(ctx, network, ips, port, opt) } - ipv4s, ipv6s := sortationAddr(ips) - return dualStackDialContext( - ctx, - func(ctx context.Context) (net.Conn, error) { - return parallelDialContext(ctx, network, ipv4s, port, opt) - }, - func(ctx context.Context) (net.Conn, error) { - return parallelDialContext(ctx, network, ipv6s, port, opt) - }, - opt.prefer) + return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt) } -func dualStackDialContext( - ctx context.Context, - ipv4DialFn func(ctx context.Context) (net.Conn, error), - ipv6DialFn func(ctx context.Context) (net.Conn, error), - preferIPVersion int) (net.Conn, error) { +func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { + ipv4s, ipv6s := sortationAddr(ips) + preferIPVersion := opt.prefer + fallbackTicker := time.NewTicker(fallbackTimeout) defer fallbackTicker.Stop() results := make(chan dialResult) returned := make(chan struct{}) defer close(returned) - racer := func(dial func(ctx context.Context) (net.Conn, error), isPrimary bool) { + racer := func(ips []netip.Addr, isPrimary bool) { result := dialResult{isPrimary: isPrimary} defer func() { select { @@ -193,10 +165,10 @@ func dualStackDialContext( } } }() - result.Conn, result.error = dial(ctx) + result.Conn, result.error = dialFn(ctx, network, ips, port, opt) } - go racer(ipv4DialFn, preferIPVersion != 6) - go racer(ipv6DialFn, preferIPVersion != 4) + go racer(ipv4s, preferIPVersion != 6) + go racer(ipv6s, preferIPVersion != 4) var fallback dialResult var err error for { @@ -230,7 +202,7 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, results := make(chan dialResult) returned := make(chan struct{}) defer close(returned) - tcpRacer := func(ctx context.Context, ip netip.Addr, port string) { + racer := func(ctx context.Context, ip netip.Addr) { result := dialResult{isPrimary: true} defer func() { select { @@ -246,7 +218,7 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, } for _, ip := range ips { - go tcpRacer(ctx, ip, port) + go racer(ctx, ip) } var err error for { @@ -272,13 +244,9 @@ func serialDialContext(ctx context.Context, network string, ips []netip.Addr, po if len(ips) == 0 { return nil, ErrorNoIpAddress } - var ( - conn net.Conn - err error - errs []error - ) + var errs []error for _, ip := range ips { - if conn, err = dialContext(ctx, network, ip, port, opt); err == nil { + if conn, err := dialContext(ctx, network, ip, port, opt); err == nil { return conn, nil } else { errs = append(errs, err) diff --git a/hub/executor/executor.go b/hub/executor/executor.go index c55201cd..d6ff7851 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -128,7 +128,7 @@ func GetGeneral() *config.General { GeodataLoader: G.LoaderName(), Interface: dialer.DefaultInterface.Load(), Sniffing: tunnel.IsSniffing(), - TCPConcurrent: dialer.GetDial(), + TCPConcurrent: dialer.GetTcpConcurrent(), } return general @@ -331,10 +331,10 @@ func updateTunnels(tunnels []LC.Tunnel) { func updateGeneral(general *config.General) { tunnel.SetMode(general.Mode) tunnel.SetFindProcessMode(general.FindProcessMode) - resolver.DisableIPv6 =!general.IPv6 + resolver.DisableIPv6 = !general.IPv6 if general.TCPConcurrent { - dialer.SetDial(general.TCPConcurrent) + dialer.SetTcpConcurrent(general.TCPConcurrent) log.Infoln("Use tcp concurrent") } diff --git a/hub/route/configs.go b/hub/route/configs.go index 9e630b29..50e3cd13 100644 --- a/hub/route/configs.go +++ b/hub/route/configs.go @@ -228,7 +228,7 @@ func patchConfigs(w http.ResponseWriter, r *http.Request) { } if general.TcpConcurrent != nil { - dialer.SetDial(*general.TcpConcurrent) + dialer.SetTcpConcurrent(*general.TcpConcurrent) } if general.InterfaceName != nil {