refactor: tcp dial (#412)

Non-concurrent support to try to connect in turn

fix: serial dual stack dial
This commit is contained in:
gVisor bot 2023-02-26 11:24:49 +08:00
parent f6911e21ea
commit bdb4aa3c1f
2 changed files with 170 additions and 249 deletions

View file

@ -8,20 +8,19 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
"time"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic"
) )
var ( var (
dialMux sync.Mutex dialMux sync.Mutex
actualSingleDialContext = singleDialContext actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = dualStackDialContext actualDualStackDialContext = serialDualStackDialContext
tcpConcurrent = false tcpConcurrent = false
DisableIPv6 = false ErrorInvalidedNetworkStack = errors.New("invalided network stack")
ErrorInvalidedNetworkStack = errors.New("invalided network stack") ErrorConnTimeout = errors.New("connect timeout")
ErrorDisableIPv6 = errors.New("IPv6 is disabled, dialer cancel") fallbackTimeout = 300 * time.Millisecond
) )
func applyOptions(options ...Option) *option { func applyOptions(options ...Option) *option {
@ -56,7 +55,7 @@ func DialContext(ctx context.Context, network, address string, options ...Option
switch network { switch network {
case "tcp4", "tcp6", "udp4", "udp6": case "tcp4", "tcp6", "udp4", "udp6":
return actualSingleDialContext(ctx, network, address, opt) return actualSingleStackDialContext(ctx, network, address, opt)
case "tcp", "udp": case "tcp", "udp":
return actualDualStackDialContext(ctx, network, address, opt) return actualDualStackDialContext(ctx, network, address, opt)
default: default:
@ -89,11 +88,11 @@ func SetDial(concurrent bool) {
dialMux.Lock() dialMux.Lock()
tcpConcurrent = concurrent tcpConcurrent = concurrent
if concurrent { if concurrent {
actualSingleDialContext = concurrentSingleDialContext actualSingleStackDialContext = concurrentSingleStackDialContext
actualDualStackDialContext = concurrentDualStackDialContext actualDualStackDialContext = concurrentDualStackDialContext
} else { } else {
actualSingleDialContext = singleDialContext actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = dualStackDialContext actualDualStackDialContext = serialDualStackDialContext
} }
dialMux.Unlock() dialMux.Unlock()
@ -114,10 +113,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
bindMarkToDialer(opt.routingMark, dialer, network, destination) bindMarkToDialer(opt.routingMark, dialer, network, destination)
} }
if DisableIPv6 && destination.Is6() {
return nil, ErrorDisableIPv6
}
address := net.JoinHostPort(destination.String(), port) address := net.JoinHostPort(destination.String(), port)
if opt.tfo { if opt.tfo {
return dialTFO(ctx, *dialer, network, address) return dialTFO(ctx, *dialer, network, address)
@ -125,146 +120,74 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
return dialer.DialContext(ctx, network, address) return dialer.DialContext(ctx, network, address)
} }
func singleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { func serialSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address) ips, port, err := parseAddr(ctx, network, address, opt.resolver)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return serialDialContext(ctx, network, ips, port, opt)
var ip netip.Addr
switch network {
case "tcp4", "udp4":
if opt.resolver == nil {
ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host)
} else {
ip, err = resolver.ResolveIPv4WithResolver(ctx, host, opt.resolver)
}
default:
if opt.resolver == nil {
ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host)
} else {
ip, err = resolver.ResolveIPv6WithResolver(ctx, host, opt.resolver)
}
}
if err != nil {
err = fmt.Errorf("dns resolve failed:%w", err)
return nil, err
}
return dialContext(ctx, network, ip, port, opt)
} }
func dualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { func serialDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address) ips, port, err := parseAddr(ctx, network, address, opt.resolver)
if err != nil {
return nil, err
}
if opt.prefer != 4 && opt.prefer != 6 {
return serialDialContext(ctx, network, ips, port, opt)
}
return dualStackDialContext(
ctx,
func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) },
func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) },
opt.prefer == 4)
}
func concurrentSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) {
ips, port, err := parseAddr(ctx, network, address, opt.resolver)
if err != nil { if err != nil {
return nil, err return nil, err
} }
returned := make(chan struct{}) if conn, err := parallelDialContext(ctx, network, ips, port, opt); err != nil {
defer close(returned) return nil, err
type dialResult struct {
net.Conn
error
resolved bool
ipv6 bool
done bool
}
results := make(chan dialResult)
var primary, fallback dialResult
startRacer := func(ctx context.Context, network, host string, r resolver.Resolver, ipv6 bool) {
result := dialResult{ipv6: ipv6, done: true}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil {
_ = result.Conn.Close()
}
}
}()
var ip netip.Addr
if ipv6 {
if r == nil {
ip, result.error = resolver.ResolveIPv6ProxyServerHost(ctx, host)
} else {
ip, result.error = resolver.ResolveIPv6WithResolver(ctx, host, r)
}
} else {
if r == nil {
ip, result.error = resolver.ResolveIPv4ProxyServerHost(ctx, host)
} else {
ip, result.error = resolver.ResolveIPv4WithResolver(ctx, host, r)
}
}
if result.error != nil {
result.error = fmt.Errorf("dns resolve failed:%w", result.error)
return
}
result.resolved = true
result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
}
go startRacer(ctx, network+"4", host, opt.resolver, false)
go startRacer(ctx, network+"6", host, opt.resolver, true)
count := 2
for i := 0; i < count; i++ {
select {
case res := <-results:
if res.error == nil {
return res.Conn, nil
}
if !res.ipv6 {
primary = res
} else {
fallback = res
}
if primary.done && fallback.done {
if primary.resolved {
return nil, primary.error
} else if fallback.resolved {
return nil, fallback.error
} else {
return nil, primary.error
}
}
case <-ctx.Done():
err = ctx.Err()
break
}
}
if err == nil {
err = fmt.Errorf("dual stack dial failed")
} else { } else {
err = fmt.Errorf("dual stack dial failed:%w", err) return conn, nil
} }
return nil, err
} }
func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { func concurrentDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
ips, port, err := parseAddr(ctx, network, address, opt.resolver)
if err != nil {
return nil, err
}
if opt.prefer != 4 && opt.prefer != 6 {
return parallelDialContext(ctx, network, ips, port, opt)
}
ipv4s, ipv6s := sortationAddr(ips)
return dualStackDialContext(
ctx,
func(ctx context.Context) (net.Conn, error) {
return parallelDialContext(ctx, network, ipv4s, port, opt)
},
func(ctx context.Context) (net.Conn, error) {
return parallelDialContext(ctx, network, ipv6s, port, opt)
},
opt.prefer == 4)
}
func dualStackDialContext(
ctx context.Context,
ipv4DialFn func(ctx context.Context) (net.Conn, error),
ipv6DialFn func(ctx context.Context) (net.Conn, error),
preferIPv4 bool) (net.Conn, error) {
fallbackTicker := time.NewTicker(fallbackTimeout)
defer fallbackTicker.Stop()
results := make(chan dialResult)
returned := make(chan struct{}) returned := make(chan struct{})
defer close(returned) defer close(returned)
racer := func(dial func(ctx context.Context) (net.Conn, error), isPrimary bool) {
type dialResult struct { result := dialResult{isPrimary: isPrimary}
ip netip.Addr
net.Conn
error
isPrimary bool
done bool
}
preferCount := atomic.NewInt32(0)
results := make(chan dialResult)
tcpRacer := func(ctx context.Context, ip netip.Addr) {
result := dialResult{ip: ip, done: true}
defer func() { defer func() {
select { select {
case results <- result: case results <- result:
@ -274,140 +197,142 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr
} }
} }
}() }()
if strings.Contains(network, "tcp") { result.Conn, result.error = dial(ctx)
network = "tcp"
} else {
network = "udp"
}
if ip.Is6() {
network += "6"
if opt.prefer != 4 {
result.isPrimary = true
}
}
if ip.Is4() {
network += "4"
if opt.prefer != 6 {
result.isPrimary = true
}
}
if result.isPrimary {
preferCount.Add(1)
}
result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
} }
go racer(ipv4DialFn, preferIPv4)
for _, ip := range ips { go racer(ipv6DialFn, !preferIPv4)
go tcpRacer(ctx, ip)
}
connCount := len(ips)
var fallback dialResult var fallback dialResult
var primaryError error var err error
var finalError error for {
for i := 0; i < connCount; i++ {
select { select {
case <-ctx.Done():
if fallback.error == nil && fallback.Conn != nil {
return fallback.Conn, nil
}
return nil, fmt.Errorf("dual stack connect failed: %w", err)
case <-fallbackTicker.C:
if fallback.error == nil && fallback.Conn != nil {
return fallback.Conn, nil
}
case res := <-results: case res := <-results:
if res.error == nil { if res.error == nil {
if res.isPrimary { if res.isPrimary {
return res.Conn, nil return res.Conn, nil
} else {
if !fallback.done || fallback.error != nil {
fallback = res
}
}
} else {
if res.isPrimary {
primaryError = res.error
preferCount.Add(-1)
if preferCount.Load() == 0 && fallback.done && fallback.error == nil {
return fallback.Conn, nil
}
} }
fallback = res
} }
case <-ctx.Done(): err = res.error
if fallback.done && fallback.error == nil {
return fallback.Conn, nil
}
finalError = ctx.Err()
break
} }
} }
if fallback.done && fallback.error == nil {
return fallback.Conn, nil
}
if primaryError != nil {
return nil, primaryError
}
if fallback.error != nil {
return nil, fallback.error
}
if finalError == nil {
finalError = fmt.Errorf("all ips %v tcp shake hands failed", ips)
} else {
finalError = fmt.Errorf("concurrent dial failed:%w", finalError)
}
return nil, finalError
} }
func concurrentSingleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
results := make(chan dialResult)
returned := make(chan struct{})
defer close(returned)
tcpRacer := func(ctx context.Context, ip netip.Addr, port string) {
result := dialResult{isPrimary: true}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil {
_ = result.Conn.Close()
}
}
}()
result.ip = ip
result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
}
for _, ip := range ips {
go tcpRacer(ctx, ip, port)
}
var err error
for {
select {
case <-ctx.Done():
if err != nil {
return nil, err
}
if ctx.Err() == context.DeadlineExceeded {
return nil, ErrorConnTimeout
}
return nil, ctx.Err()
case res := <-results:
if res.error == nil {
return res.Conn, nil
}
err = res.error
}
}
}
func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
var (
conn net.Conn
err error
errs []error
)
for _, ip := range ips {
if conn, err = dialContext(ctx, network, ip, port, opt); err == nil {
return conn, nil
} else {
errs = append(errs, err)
}
}
return nil, errors.Join(errs...)
}
type dialResult struct {
ip netip.Addr
net.Conn
error
isPrimary bool
}
func parseAddr(ctx context.Context, network, address string, preferResolver resolver.Resolver) ([]netip.Addr, string, error) {
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return nil, err return nil, "-1", err
} }
var ips []netip.Addr var ips []netip.Addr
switch network { switch network {
case "tcp4", "udp4": case "tcp4", "udp4":
if opt.resolver == nil { if preferResolver == nil {
ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host) ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host)
} else { } else {
ips, err = resolver.LookupIPv4WithResolver(ctx, host, opt.resolver) ips, err = resolver.LookupIPv4WithResolver(ctx, host, preferResolver)
} }
default: case "tcp6", "udp6":
if opt.resolver == nil { if preferResolver == nil {
ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host) ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host)
} else { } else {
ips, err = resolver.LookupIPv6WithResolver(ctx, host, opt.resolver) ips, err = resolver.LookupIPv6WithResolver(ctx, host, preferResolver)
}
default:
if preferResolver == nil {
ips, err = resolver.LookupIP(ctx, host)
} else {
ips, err = resolver.LookupIPWithResolver(ctx, host, preferResolver)
} }
} }
if err != nil { if err != nil {
err = fmt.Errorf("dns resolve failed:%w", err) return nil, "-1", fmt.Errorf("dns resolve failed: %w", err)
return nil, err
} }
return ips, port, nil
return concurrentDialContext(ctx, network, ips, port, opt)
} }
func concurrentDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { func sortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) {
host, port, err := net.SplitHostPort(address) for _, v := range ips {
if err != nil { if v.Is4() || v.Is4In6() {
return nil, err ipv4s = append(ipv4s, v)
} else {
ipv6s = append(ipv6s, v)
}
} }
return
var ips []netip.Addr
if opt.resolver != nil {
ips, err = resolver.LookupIPWithResolver(ctx, host, opt.resolver)
} else {
ips, err = resolver.LookupIPProxyServerHost(ctx, host)
}
if err != nil {
err = fmt.Errorf("dns resolve failed:%w", err)
return nil, err
}
return concurrentDialContext(ctx, network, ips, port, opt)
} }
type Dialer struct { type Dialer struct {

View file

@ -331,11 +331,7 @@ func updateTunnels(tunnels []LC.Tunnel) {
func updateGeneral(general *config.General) { func updateGeneral(general *config.General) {
tunnel.SetMode(general.Mode) tunnel.SetMode(general.Mode)
tunnel.SetFindProcessMode(general.FindProcessMode) tunnel.SetFindProcessMode(general.FindProcessMode)
dialer.DisableIPv6 = !general.IPv6 resolver.DisableIPv6 =!general.IPv6
if !dialer.DisableIPv6 {
log.Infoln("Use IPv6")
}
resolver.DisableIPv6 = dialer.DisableIPv6
if general.TCPConcurrent { if general.TCPConcurrent {
dialer.SetDial(general.TCPConcurrent) dialer.SetDial(general.TCPConcurrent)