From 97e14337e355a289b21227b8dcc9ba1d6d7e44e4 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sun, 26 Feb 2023 11:24:49 +0800 Subject: [PATCH] refactor: tcp dial (#412) Non-concurrent support to try to connect in turn fix: serial dual stack dial --- component/dialer/dialer.go | 413 +++++++++++++++---------------------- hub/executor/executor.go | 6 +- 2 files changed, 170 insertions(+), 249 deletions(-) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 9ac9d719..ab2fe047 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -8,20 +8,19 @@ import ( "net/netip" "strings" "sync" + "time" "github.com/Dreamacro/clash/component/resolver" - - "go.uber.org/atomic" ) var ( - dialMux sync.Mutex - actualSingleDialContext = singleDialContext - actualDualStackDialContext = dualStackDialContext - tcpConcurrent = false - DisableIPv6 = false - ErrorInvalidedNetworkStack = errors.New("invalided network stack") - ErrorDisableIPv6 = errors.New("IPv6 is disabled, dialer cancel") + dialMux sync.Mutex + actualSingleStackDialContext = serialSingleStackDialContext + actualDualStackDialContext = serialDualStackDialContext + tcpConcurrent = false + ErrorInvalidedNetworkStack = errors.New("invalided network stack") + ErrorConnTimeout = errors.New("connect timeout") + fallbackTimeout = 300 * time.Millisecond ) func applyOptions(options ...Option) *option { @@ -56,7 +55,7 @@ func DialContext(ctx context.Context, network, address string, options ...Option switch network { case "tcp4", "tcp6", "udp4", "udp6": - return actualSingleDialContext(ctx, network, address, opt) + return actualSingleStackDialContext(ctx, network, address, opt) case "tcp", "udp": return actualDualStackDialContext(ctx, network, address, opt) default: @@ -89,11 +88,11 @@ func SetDial(concurrent bool) { dialMux.Lock() tcpConcurrent = concurrent if concurrent { - actualSingleDialContext = concurrentSingleDialContext + actualSingleStackDialContext = concurrentSingleStackDialContext actualDualStackDialContext = concurrentDualStackDialContext } else { - actualSingleDialContext = singleDialContext - actualDualStackDialContext = dualStackDialContext + actualSingleStackDialContext = serialSingleStackDialContext + actualDualStackDialContext = serialDualStackDialContext } dialMux.Unlock() @@ -114,10 +113,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po bindMarkToDialer(opt.routingMark, dialer, network, destination) } - if DisableIPv6 && destination.Is6() { - return nil, ErrorDisableIPv6 - } - address := net.JoinHostPort(destination.String(), port) if opt.tfo { return dialTFO(ctx, *dialer, network, address) @@ -125,146 +120,74 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po return dialer.DialContext(ctx, network, address) } -func singleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) +func serialSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { + ips, port, err := parseAddr(ctx, network, address, opt.resolver) if err != nil { return nil, err } - - var ip netip.Addr - switch network { - case "tcp4", "udp4": - if opt.resolver == nil { - ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host) - } else { - ip, err = resolver.ResolveIPv4WithResolver(ctx, host, opt.resolver) - } - default: - if opt.resolver == nil { - ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host) - } else { - ip, err = resolver.ResolveIPv6WithResolver(ctx, host, opt.resolver) - } - } - if err != nil { - err = fmt.Errorf("dns resolve failed:%w", err) - return nil, err - } - - return dialContext(ctx, network, ip, port, opt) + return serialDialContext(ctx, network, ips, port, opt) } -func dualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) +func serialDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { + ips, port, err := parseAddr(ctx, network, address, opt.resolver) + if err != nil { + return nil, err + } + if opt.prefer != 4 && opt.prefer != 6 { + return serialDialContext(ctx, network, ips, port, opt) + } + return dualStackDialContext( + ctx, + func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, + func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, + opt.prefer == 4) +} + +func concurrentSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { + ips, port, err := parseAddr(ctx, network, address, opt.resolver) if err != nil { return nil, err } - returned := make(chan struct{}) - defer close(returned) - - type dialResult struct { - net.Conn - error - resolved bool - ipv6 bool - done bool - } - results := make(chan dialResult) - var primary, fallback dialResult - - startRacer := func(ctx context.Context, network, host string, r resolver.Resolver, ipv6 bool) { - result := dialResult{ipv6: ipv6, done: true} - defer func() { - select { - case results <- result: - case <-returned: - if result.Conn != nil { - _ = result.Conn.Close() - } - } - }() - - var ip netip.Addr - if ipv6 { - if r == nil { - ip, result.error = resolver.ResolveIPv6ProxyServerHost(ctx, host) - } else { - ip, result.error = resolver.ResolveIPv6WithResolver(ctx, host, r) - } - } else { - if r == nil { - ip, result.error = resolver.ResolveIPv4ProxyServerHost(ctx, host) - } else { - ip, result.error = resolver.ResolveIPv4WithResolver(ctx, host, r) - } - } - if result.error != nil { - result.error = fmt.Errorf("dns resolve failed:%w", result.error) - return - } - result.resolved = true - - result.Conn, result.error = dialContext(ctx, network, ip, port, opt) - } - - go startRacer(ctx, network+"4", host, opt.resolver, false) - go startRacer(ctx, network+"6", host, opt.resolver, true) - - count := 2 - for i := 0; i < count; i++ { - select { - case res := <-results: - if res.error == nil { - return res.Conn, nil - } - - if !res.ipv6 { - primary = res - } else { - fallback = res - } - - if primary.done && fallback.done { - if primary.resolved { - return nil, primary.error - } else if fallback.resolved { - return nil, fallback.error - } else { - return nil, primary.error - } - } - case <-ctx.Done(): - err = ctx.Err() - break - } - } - - if err == nil { - err = fmt.Errorf("dual stack dial failed") + if conn, err := parallelDialContext(ctx, network, ips, port, opt); err != nil { + return nil, err } else { - err = fmt.Errorf("dual stack dial failed:%w", err) + return conn, nil } - return nil, err } -func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { +func concurrentDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { + ips, port, err := parseAddr(ctx, network, address, opt.resolver) + if err != nil { + return nil, err + } + if opt.prefer != 4 && opt.prefer != 6 { + return parallelDialContext(ctx, network, ips, port, opt) + } + ipv4s, ipv6s := sortationAddr(ips) + return dualStackDialContext( + ctx, + func(ctx context.Context) (net.Conn, error) { + return parallelDialContext(ctx, network, ipv4s, port, opt) + }, + func(ctx context.Context) (net.Conn, error) { + return parallelDialContext(ctx, network, ipv6s, port, opt) + }, + opt.prefer == 4) +} + +func dualStackDialContext( + ctx context.Context, + ipv4DialFn func(ctx context.Context) (net.Conn, error), + ipv6DialFn func(ctx context.Context) (net.Conn, error), + preferIPv4 bool) (net.Conn, error) { + fallbackTicker := time.NewTicker(fallbackTimeout) + defer fallbackTicker.Stop() + results := make(chan dialResult) returned := make(chan struct{}) defer close(returned) - - type dialResult struct { - ip netip.Addr - net.Conn - error - isPrimary bool - done bool - } - - preferCount := atomic.NewInt32(0) - results := make(chan dialResult) - tcpRacer := func(ctx context.Context, ip netip.Addr) { - result := dialResult{ip: ip, done: true} - + racer := func(dial func(ctx context.Context) (net.Conn, error), isPrimary bool) { + result := dialResult{isPrimary: isPrimary} defer func() { select { case results <- result: @@ -274,140 +197,142 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr } } }() - if strings.Contains(network, "tcp") { - network = "tcp" - } else { - network = "udp" - } - - if ip.Is6() { - network += "6" - if opt.prefer != 4 { - result.isPrimary = true - } - } - - if ip.Is4() { - network += "4" - if opt.prefer != 6 { - result.isPrimary = true - } - } - - if result.isPrimary { - preferCount.Add(1) - } - - result.Conn, result.error = dialContext(ctx, network, ip, port, opt) + result.Conn, result.error = dial(ctx) } - - for _, ip := range ips { - go tcpRacer(ctx, ip) - } - - connCount := len(ips) + go racer(ipv4DialFn, preferIPv4) + go racer(ipv6DialFn, !preferIPv4) var fallback dialResult - var primaryError error - var finalError error - for i := 0; i < connCount; i++ { + var err error + for { select { + case <-ctx.Done(): + if fallback.error == nil && fallback.Conn != nil { + return fallback.Conn, nil + } + return nil, fmt.Errorf("dual stack connect failed: %w", err) + case <-fallbackTicker.C: + if fallback.error == nil && fallback.Conn != nil { + return fallback.Conn, nil + } case res := <-results: if res.error == nil { if res.isPrimary { return res.Conn, nil - } else { - if !fallback.done || fallback.error != nil { - fallback = res - } - } - } else { - if res.isPrimary { - primaryError = res.error - preferCount.Add(-1) - if preferCount.Load() == 0 && fallback.done && fallback.error == nil { - return fallback.Conn, nil - } } + fallback = res } - case <-ctx.Done(): - if fallback.done && fallback.error == nil { - return fallback.Conn, nil - } - finalError = ctx.Err() - break + err = res.error } } - - if fallback.done && fallback.error == nil { - return fallback.Conn, nil - } - - if primaryError != nil { - return nil, primaryError - } - - if fallback.error != nil { - return nil, fallback.error - } - - if finalError == nil { - finalError = fmt.Errorf("all ips %v tcp shake hands failed", ips) - } else { - finalError = fmt.Errorf("concurrent dial failed:%w", finalError) - } - - return nil, finalError } -func concurrentSingleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { +func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { + results := make(chan dialResult) + returned := make(chan struct{}) + defer close(returned) + tcpRacer := func(ctx context.Context, ip netip.Addr, port string) { + result := dialResult{isPrimary: true} + defer func() { + select { + case results <- result: + case <-returned: + if result.Conn != nil { + _ = result.Conn.Close() + } + } + }() + result.ip = ip + result.Conn, result.error = dialContext(ctx, network, ip, port, opt) + } + + for _, ip := range ips { + go tcpRacer(ctx, ip, port) + } + var err error + for { + select { + case <-ctx.Done(): + if err != nil { + return nil, err + } + if ctx.Err() == context.DeadlineExceeded { + return nil, ErrorConnTimeout + } + return nil, ctx.Err() + case res := <-results: + if res.error == nil { + return res.Conn, nil + } + err = res.error + } + } +} + +func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { + var ( + conn net.Conn + err error + errs []error + ) + for _, ip := range ips { + if conn, err = dialContext(ctx, network, ip, port, opt); err == nil { + return conn, nil + } else { + errs = append(errs, err) + } + } + return nil, errors.Join(errs...) +} + +type dialResult struct { + ip netip.Addr + net.Conn + error + isPrimary bool +} + +func parseAddr(ctx context.Context, network, address string, preferResolver resolver.Resolver) ([]netip.Addr, string, error) { host, port, err := net.SplitHostPort(address) if err != nil { - return nil, err + return nil, "-1", err } var ips []netip.Addr switch network { case "tcp4", "udp4": - if opt.resolver == nil { + if preferResolver == nil { ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host) } else { - ips, err = resolver.LookupIPv4WithResolver(ctx, host, opt.resolver) + ips, err = resolver.LookupIPv4WithResolver(ctx, host, preferResolver) } - default: - if opt.resolver == nil { + case "tcp6", "udp6": + if preferResolver == nil { ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host) } else { - ips, err = resolver.LookupIPv6WithResolver(ctx, host, opt.resolver) + ips, err = resolver.LookupIPv6WithResolver(ctx, host, preferResolver) + } + default: + if preferResolver == nil { + ips, err = resolver.LookupIP(ctx, host) + } else { + ips, err = resolver.LookupIPWithResolver(ctx, host, preferResolver) } } - if err != nil { - err = fmt.Errorf("dns resolve failed:%w", err) - return nil, err + return nil, "-1", fmt.Errorf("dns resolve failed: %w", err) } - - return concurrentDialContext(ctx, network, ips, port, opt) + return ips, port, nil } -func concurrentDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - return nil, err +func sortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) { + for _, v := range ips { + if v.Is4() || v.Is4In6() { + ipv4s = append(ipv4s, v) + } else { + ipv6s = append(ipv6s, v) + } } - - var ips []netip.Addr - if opt.resolver != nil { - ips, err = resolver.LookupIPWithResolver(ctx, host, opt.resolver) - } else { - ips, err = resolver.LookupIPProxyServerHost(ctx, host) - } - - if err != nil { - err = fmt.Errorf("dns resolve failed:%w", err) - return nil, err - } - - return concurrentDialContext(ctx, network, ips, port, opt) + return } type Dialer struct { diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 34f0f1a1..c55201cd 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -331,11 +331,7 @@ func updateTunnels(tunnels []LC.Tunnel) { func updateGeneral(general *config.General) { tunnel.SetMode(general.Mode) tunnel.SetFindProcessMode(general.FindProcessMode) - dialer.DisableIPv6 = !general.IPv6 - if !dialer.DisableIPv6 { - log.Infoln("Use IPv6") - } - resolver.DisableIPv6 = dialer.DisableIPv6 + resolver.DisableIPv6 =!general.IPv6 if general.TCPConcurrent { dialer.SetDial(general.TCPConcurrent)