From 8acadd2932000175af26bd54d83ca1a74c03848d Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Sat, 14 Sep 2019 20:00:40 +0800 Subject: [PATCH] Fix: tcp dual stack dial --- adapters/outbound/util.go | 72 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/adapters/outbound/util.go b/adapters/outbound/util.go index 0cccf5d9..b3670b3e 100644 --- a/adapters/outbound/util.go +++ b/adapters/outbound/util.go @@ -2,6 +2,7 @@ package adapters import ( "bytes" + "context" "crypto/tls" "fmt" "net" @@ -104,12 +105,75 @@ func dialTimeout(network, address string, timeout time.Duration) (net.Conn, erro return nil, err } - ip, err := dns.ResolveIP(host) - if err != nil { - return nil, err + dialer := net.Dialer{Timeout: timeout} + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + returned := make(chan struct{}) + defer close(returned) + + type dialResult struct { + net.Conn + error + ipv6 bool + done bool + } + results := make(chan dialResult) + var primary, fallback dialResult + + startRacer := func(ctx context.Context, host string, ipv6 bool) { + var err error + + var ip net.IP + if ipv6 { + ip, err = dns.ResolveIPv6(host) + } else { + ip, err = dns.ResolveIPv4(host) + } + if err != nil { + return + } + + var c net.Conn + if ipv6 { + c, err = dialer.DialContext(ctx, "tcp6", net.JoinHostPort(ip.String(), port)) + } else { + c, err = dialer.DialContext(ctx, "tcp4", net.JoinHostPort(ip.String(), port)) + } + if err != nil { + return + } + + select { + case results <- dialResult{Conn: c, error: err, ipv6: ipv6}: + case <-returned: + if c != nil { + c.Close() + } + } } - return net.DialTimeout(network, net.JoinHostPort(ip.String(), port), timeout) + go startRacer(ctx, host, false) + go startRacer(ctx, host, true) + + for { + select { + case res := <-results: + if res.error == nil { + return res.Conn, nil + } + + if res.ipv6 { + primary = res + } else { + fallback = res + } + + if primary.done && fallback.done { + return nil, primary.error + } + } + } } func resolveUDPAddr(network, address string) (*net.UDPAddr, error) {