From 02cb4e1c926d2b2c63d955c7a89da4744150516a Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Wed, 30 Aug 2023 17:07:49 +0800 Subject: [PATCH] chore: use WaitGroup in dualStackDialContext --- component/dialer/dialer.go | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 2bc5b66c..0cfa1b6c 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -169,10 +169,15 @@ func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, preferIPVersion := opt.prefer fallbackTicker := time.NewTicker(fallbackTimeout) defer fallbackTicker.Stop() + results := make(chan dialResult) returned := make(chan struct{}) defer close(returned) + + var wg sync.WaitGroup + racer := func(ips []netip.Addr, isPrimary bool) { + defer wg.Done() result := dialResult{isPrimary: isPrimary} defer func() { select { @@ -186,27 +191,35 @@ func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, result.Conn, result.error = dialFn(ctx, network, ips, port, opt) } - var wait int if len(ipv4s) != 0 { - wait++ + wg.Add(1) go racer(ipv4s, preferIPVersion != 6) } if len(ipv6s) != 0 { - wait++ + wg.Add(1) go racer(ipv6s, preferIPVersion != 4) } + go func() { + wg.Wait() + close(results) + }() + var fallback dialResult var errs []error - for i := 0; i < wait; { + +loop: + for { select { case <-fallbackTicker.C: if fallback.error == nil && fallback.Conn != nil { return fallback.Conn, nil } - case res := <-results: - i++ + case res, ok := <-results: + if !ok { + break loop + } if res.error == nil { if res.isPrimary { return res.Conn, nil @@ -221,6 +234,7 @@ func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, } } } + if fallback.error == nil && fallback.Conn != nil { return fallback.Conn, nil }