chore: cleanup dialer's code

This commit is contained in:
gVisor bot 2023-03-06 23:23:05 +08:00
parent 3d832bc54f
commit 08c113b079
3 changed files with 37 additions and 69 deletions

View file

@ -13,6 +13,8 @@ import (
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
) )
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error)
var ( var (
dialMux sync.Mutex dialMux sync.Mutex
actualSingleStackDialContext = serialSingleStackDialContext actualSingleStackDialContext = serialSingleStackDialContext
@ -51,11 +53,16 @@ func DialContext(ctx context.Context, network, address string, options ...Option
network = fmt.Sprintf("%s%d", network, opt.network) network = fmt.Sprintf("%s%d", network, opt.network)
} }
ips, port, err := parseAddr(ctx, network, address, opt.resolver)
if err != nil {
return nil, err
}
switch network { switch network {
case "tcp4", "tcp6", "udp4", "udp6": case "tcp4", "tcp6", "udp4", "udp6":
return actualSingleStackDialContext(ctx, network, address, opt) return actualSingleStackDialContext(ctx, network, ips, port, opt)
case "tcp", "udp": case "tcp", "udp":
return actualDualStackDialContext(ctx, network, address, opt) return actualDualStackDialContext(ctx, network, ips, port, opt)
default: default:
return nil, ErrorInvalidedNetworkStack return nil, ErrorInvalidedNetworkStack
} }
@ -82,8 +89,9 @@ func ListenPacket(ctx context.Context, network, address string, options ...Optio
return lc.ListenPacket(ctx, network, address) return lc.ListenPacket(ctx, network, address)
} }
func SetDial(concurrent bool) { func SetTcpConcurrent(concurrent bool) {
dialMux.Lock() dialMux.Lock()
defer dialMux.Unlock()
tcpConcurrent = concurrent tcpConcurrent = concurrent
if concurrent { if concurrent {
actualSingleStackDialContext = concurrentSingleStackDialContext actualSingleStackDialContext = concurrentSingleStackDialContext
@ -92,11 +100,11 @@ func SetDial(concurrent bool) {
actualSingleStackDialContext = serialSingleStackDialContext actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = serialDualStackDialContext actualDualStackDialContext = serialDualStackDialContext
} }
dialMux.Unlock()
} }
func GetDial() bool { func GetTcpConcurrent() bool {
dialMux.Lock()
defer dialMux.Unlock()
return tcpConcurrent return tcpConcurrent
} }
@ -118,71 +126,35 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
return dialer.DialContext(ctx, network, address) return dialer.DialContext(ctx, network, address)
} }
func serialSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
ips, port, err := parseAddr(ctx, network, address, opt.resolver)
if err != nil {
return nil, err
}
return serialDialContext(ctx, network, ips, port, opt) return serialDialContext(ctx, network, ips, port, opt)
} }
func serialDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
ips, port, err := parseAddr(ctx, network, address, opt.resolver) return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
if err != nil {
return nil, err
}
ipv4s, ipv6s := sortationAddr(ips)
return dualStackDialContext(
ctx,
func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ipv4s, port, opt) },
func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ipv6s, port, opt) },
opt.prefer)
} }
func concurrentSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
ips, port, err := parseAddr(ctx, network, address, opt.resolver) return parallelDialContext(ctx, network, ips, port, opt)
if err != nil {
return nil, err
} }
if conn, err := parallelDialContext(ctx, network, ips, port, opt); err != nil { func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
return nil, err
} else {
return conn, nil
}
}
func concurrentDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
ips, port, err := parseAddr(ctx, network, address, opt.resolver)
if err != nil {
return nil, err
}
if opt.prefer != 4 && opt.prefer != 6 { if opt.prefer != 4 && opt.prefer != 6 {
return parallelDialContext(ctx, network, ips, port, opt) return parallelDialContext(ctx, network, ips, port, opt)
} }
ipv4s, ipv6s := sortationAddr(ips) return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
return dualStackDialContext(
ctx,
func(ctx context.Context) (net.Conn, error) {
return parallelDialContext(ctx, network, ipv4s, port, opt)
},
func(ctx context.Context) (net.Conn, error) {
return parallelDialContext(ctx, network, ipv6s, port, opt)
},
opt.prefer)
} }
func dualStackDialContext( func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
ctx context.Context, ipv4s, ipv6s := sortationAddr(ips)
ipv4DialFn func(ctx context.Context) (net.Conn, error), preferIPVersion := opt.prefer
ipv6DialFn func(ctx context.Context) (net.Conn, error),
preferIPVersion int) (net.Conn, error) {
fallbackTicker := time.NewTicker(fallbackTimeout) fallbackTicker := time.NewTicker(fallbackTimeout)
defer fallbackTicker.Stop() defer fallbackTicker.Stop()
results := make(chan dialResult) results := make(chan dialResult)
returned := make(chan struct{}) returned := make(chan struct{})
defer close(returned) defer close(returned)
racer := func(dial func(ctx context.Context) (net.Conn, error), isPrimary bool) { racer := func(ips []netip.Addr, isPrimary bool) {
result := dialResult{isPrimary: isPrimary} result := dialResult{isPrimary: isPrimary}
defer func() { defer func() {
select { select {
@ -193,10 +165,10 @@ func dualStackDialContext(
} }
} }
}() }()
result.Conn, result.error = dial(ctx) result.Conn, result.error = dialFn(ctx, network, ips, port, opt)
} }
go racer(ipv4DialFn, preferIPVersion != 6) go racer(ipv4s, preferIPVersion != 6)
go racer(ipv6DialFn, preferIPVersion != 4) go racer(ipv6s, preferIPVersion != 4)
var fallback dialResult var fallback dialResult
var err error var err error
for { for {
@ -230,7 +202,7 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr,
results := make(chan dialResult) results := make(chan dialResult)
returned := make(chan struct{}) returned := make(chan struct{})
defer close(returned) defer close(returned)
tcpRacer := func(ctx context.Context, ip netip.Addr, port string) { racer := func(ctx context.Context, ip netip.Addr) {
result := dialResult{isPrimary: true} result := dialResult{isPrimary: true}
defer func() { defer func() {
select { select {
@ -246,7 +218,7 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr,
} }
for _, ip := range ips { for _, ip := range ips {
go tcpRacer(ctx, ip, port) go racer(ctx, ip)
} }
var err error var err error
for { for {
@ -272,13 +244,9 @@ func serialDialContext(ctx context.Context, network string, ips []netip.Addr, po
if len(ips) == 0 { if len(ips) == 0 {
return nil, ErrorNoIpAddress return nil, ErrorNoIpAddress
} }
var ( var errs []error
conn net.Conn
err error
errs []error
)
for _, ip := range ips { for _, ip := range ips {
if conn, err = dialContext(ctx, network, ip, port, opt); err == nil { if conn, err := dialContext(ctx, network, ip, port, opt); err == nil {
return conn, nil return conn, nil
} else { } else {
errs = append(errs, err) errs = append(errs, err)

View file

@ -128,7 +128,7 @@ func GetGeneral() *config.General {
GeodataLoader: G.LoaderName(), GeodataLoader: G.LoaderName(),
Interface: dialer.DefaultInterface.Load(), Interface: dialer.DefaultInterface.Load(),
Sniffing: tunnel.IsSniffing(), Sniffing: tunnel.IsSniffing(),
TCPConcurrent: dialer.GetDial(), TCPConcurrent: dialer.GetTcpConcurrent(),
} }
return general return general
@ -334,7 +334,7 @@ func updateGeneral(general *config.General) {
resolver.DisableIPv6 = !general.IPv6 resolver.DisableIPv6 = !general.IPv6
if general.TCPConcurrent { if general.TCPConcurrent {
dialer.SetDial(general.TCPConcurrent) dialer.SetTcpConcurrent(general.TCPConcurrent)
log.Infoln("Use tcp concurrent") log.Infoln("Use tcp concurrent")
} }

View file

@ -228,7 +228,7 @@ func patchConfigs(w http.ResponseWriter, r *http.Request) {
} }
if general.TcpConcurrent != nil { if general.TcpConcurrent != nil {
dialer.SetDial(*general.TcpConcurrent) dialer.SetTcpConcurrent(*general.TcpConcurrent)
} }
if general.InterfaceName != nil { if general.InterfaceName != nil {