From f979491013c58068385531c507778ccb90abc5b6 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sat, 25 Jun 2022 09:00:35 +0800 Subject: [PATCH] fix: tcp concurrent force close when context done --- component/dialer/dialer.go | 60 ++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index e0401d42..bfba0079 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -4,11 +4,10 @@ import ( "context" "errors" "fmt" + "github.com/Dreamacro/clash/component/resolver" "net" "net/netip" "sync" - - "github.com/Dreamacro/clash/component/resolver" ) var ( @@ -171,25 +170,31 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt go startRacer(ctx, network+"4", host, opt.direct, false) go startRacer(ctx, network+"6", host, opt.direct, true) - for res := range results { - if res.error == nil { - return res.Conn, nil - } - - if !res.ipv6 { - primary = res - } else { - fallback = res - } - - if primary.done && fallback.done { - if primary.resolved { - return nil, primary.error - } else if fallback.resolved { - return nil, fallback.error - } else { - return nil, primary.error + count := 2 + for i := 0; i < count; i++ { + select { + case res := <-results: + if res.error == nil { + return res.Conn, nil } + + if !res.ipv6 { + primary = res + } else { + fallback = res + } + + if primary.done && fallback.done { + if primary.resolved { + return nil, primary.error + } else if fallback.resolved { + return nil, fallback.error + } else { + return nil, primary.error + } + } + case <-ctx.Done(): + break } } @@ -225,7 +230,6 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr } results := make(chan dialResult) - tcpRacer := func(ctx context.Context, ip netip.Addr) { result := dialResult{ip: ip} @@ -252,13 +256,13 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr } connCount := len(ips) - for res := range results { - connCount-- - if res.error == nil { - return res.Conn, nil - } - - if connCount == 0 { + for i := 0; i < connCount; i++ { + select { + case res := <-results: + if res.error == nil { + return res.Conn, nil + } + case <-ctx.Done(): break } }