chore: use WaitGroup in dualStackDialContext

This commit is contained in:
gVisor bot 2023-08-30 17:07:49 +08:00
parent c87597408e
commit 02cb4e1c92

View file

@ -169,10 +169,15 @@ func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string,
preferIPVersion := opt.prefer preferIPVersion := opt.prefer
fallbackTicker := time.NewTicker(fallbackTimeout) fallbackTicker := time.NewTicker(fallbackTimeout)
defer fallbackTicker.Stop() defer fallbackTicker.Stop()
results := make(chan dialResult) results := make(chan dialResult)
returned := make(chan struct{}) returned := make(chan struct{})
defer close(returned) defer close(returned)
var wg sync.WaitGroup
racer := func(ips []netip.Addr, isPrimary bool) { racer := func(ips []netip.Addr, isPrimary bool) {
defer wg.Done()
result := dialResult{isPrimary: isPrimary} result := dialResult{isPrimary: isPrimary}
defer func() { defer func() {
select { select {
@ -186,27 +191,35 @@ func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string,
result.Conn, result.error = dialFn(ctx, network, ips, port, opt) result.Conn, result.error = dialFn(ctx, network, ips, port, opt)
} }
var wait int
if len(ipv4s) != 0 { if len(ipv4s) != 0 {
wait++ wg.Add(1)
go racer(ipv4s, preferIPVersion != 6) go racer(ipv4s, preferIPVersion != 6)
} }
if len(ipv6s) != 0 { if len(ipv6s) != 0 {
wait++ wg.Add(1)
go racer(ipv6s, preferIPVersion != 4) go racer(ipv6s, preferIPVersion != 4)
} }
go func() {
wg.Wait()
close(results)
}()
var fallback dialResult var fallback dialResult
var errs []error var errs []error
for i := 0; i < wait; {
loop:
for {
select { select {
case <-fallbackTicker.C: case <-fallbackTicker.C:
if fallback.error == nil && fallback.Conn != nil { if fallback.error == nil && fallback.Conn != nil {
return fallback.Conn, nil return fallback.Conn, nil
} }
case res := <-results: case res, ok := <-results:
i++ if !ok {
break loop
}
if res.error == nil { if res.error == nil {
if res.isPrimary { if res.isPrimary {
return res.Conn, nil return res.Conn, nil
@ -221,6 +234,7 @@ func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string,
} }
} }
} }
if fallback.error == nil && fallback.Conn != nil { if fallback.error == nil && fallback.Conn != nil {
return fallback.Conn, nil return fallback.Conn, nil
} }