diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index fdbbcb62..d7ffd478 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -5,6 +5,7 @@ import ( "net" "github.com/Dreamacro/clash/component/dialer" + "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" ) @@ -14,7 +15,7 @@ type Direct struct { // DialContext implements C.ProxyAdapter func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { - opts = append(opts, dialer.WithDirect()) + opts = append(opts, dialer.WithResolver(resolver.DefaultResolver)) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) if err != nil { return nil, err @@ -25,7 +26,7 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ... // ListenPacketContext implements C.ProxyAdapter func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { - opts = append(opts, dialer.WithDirect()) + opts = append(opts, dialer.WithResolver(resolver.DefaultResolver)) pc, err := dialer.ListenPacket(ctx, "udp", "", d.Base.DialOptions(opts...)...) if err != nil { return nil, err diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index cb87061c..44fa2e90 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -4,12 +4,14 @@ import ( "context" "errors" "fmt" - "github.com/Dreamacro/clash/component/resolver" - "go.uber.org/atomic" "net" "net/netip" "strings" "sync" + + "github.com/Dreamacro/clash/component/resolver" + + "go.uber.org/atomic" ) var ( @@ -149,7 +151,7 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt results := make(chan dialResult) var primary, fallback dialResult - startRacer := func(ctx context.Context, network, host string, direct bool, ipv6 bool) { + startRacer := func(ctx context.Context, network, host string, r resolver.Resolver, ipv6 bool) { result := dialResult{ipv6: ipv6, done: true} defer func() { select { @@ -163,16 +165,16 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt var ip netip.Addr if ipv6 { - if !direct { + if r == nil { ip, result.error = resolver.ResolveIPv6ProxyServerHost(ctx, host) } else { - ip, result.error = resolver.ResolveIPv6(ctx, host) + ip, result.error = resolver.ResolveIPv6WithResolver(ctx, host, r) } } else { - if !direct { + if r == nil { ip, result.error = resolver.ResolveIPv4ProxyServerHost(ctx, host) } else { - ip, result.error = resolver.ResolveIPv4(ctx, host) + ip, result.error = resolver.ResolveIPv4WithResolver(ctx, host, r) } } if result.error != nil { @@ -183,8 +185,8 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt result.Conn, result.error = dialContext(ctx, network, ip, port, opt) } - go startRacer(ctx, network+"4", host, opt.direct, false) - go startRacer(ctx, network+"6", host, opt.direct, true) + go startRacer(ctx, network+"4", host, opt.resolver, false) + go startRacer(ctx, network+"6", host, opt.resolver, true) count := 2 for i := 0; i < count; i++ { @@ -230,8 +232,8 @@ func concurrentDualStackDialContext(ctx context.Context, network, address string } var ips []netip.Addr - if opt.direct { - ips, err = resolver.LookupIP(ctx, host) + if opt.resolver != nil { + ips, err = resolver.LookupIPWithResolver(ctx, host, opt.resolver) } else { ips, err = resolver.LookupIPProxyServerHost(ctx, host) } @@ -363,16 +365,16 @@ func singleDialContext(ctx context.Context, network string, address string, opt var ip netip.Addr switch network { case "tcp4", "udp4": - if !opt.direct { + if opt.resolver == nil { ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host) } else { - ip, err = resolver.ResolveIPv4(ctx, host) + ip, err = resolver.ResolveIPv4WithResolver(ctx, host, opt.resolver) } default: - if !opt.direct { + if opt.resolver == nil { ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host) } else { - ip, err = resolver.ResolveIPv6(ctx, host) + ip, err = resolver.ResolveIPv6WithResolver(ctx, host, opt.resolver) } } if err != nil { @@ -398,10 +400,10 @@ func concurrentIPv4DialContext(ctx context.Context, network, address string, opt } var ips []netip.Addr - if !opt.direct { + if opt.resolver == nil { ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host) } else { - ips, err = resolver.LookupIPv4(ctx, host) + ips, err = resolver.LookupIPv4WithResolver(ctx, host, opt.resolver) } if err != nil { @@ -418,10 +420,10 @@ func concurrentIPv6DialContext(ctx context.Context, network, address string, opt } var ips []netip.Addr - if !opt.direct { + if opt.resolver == nil { ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host) } else { - ips, err = resolver.LookupIPv6(ctx, host) + ips, err = resolver.LookupIPv6WithResolver(ctx, host, opt.resolver) } if err != nil { diff --git a/component/dialer/options.go b/component/dialer/options.go index ce911035..98d0b8bd 100644 --- a/component/dialer/options.go +++ b/component/dialer/options.go @@ -1,6 +1,8 @@ package dialer import ( + "github.com/Dreamacro/clash/component/resolver" + "go.uber.org/atomic" ) @@ -14,9 +16,9 @@ type option struct { interfaceName string addrReuse bool routingMark int - direct bool network int prefer int + resolver resolver.Resolver } type Option func(opt *option) @@ -39,9 +41,9 @@ func WithRoutingMark(mark int) Option { } } -func WithDirect() Option { +func WithResolver(r resolver.Resolver) Option { return func(opt *option) { - opt.direct = true + opt.resolver = r } } diff --git a/dns/client.go b/dns/client.go index 0a13469c..a7bf5eb3 100644 --- a/dns/client.go +++ b/dns/client.go @@ -60,13 +60,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) options = append(options, dialer.WithInterface(c.iface.Load())) } - var conn net.Conn - if c.proxyAdapter != "" { - conn, err = dialContextExtra(ctx, c.proxyAdapter, network, ip, c.port, options...) - } else { - conn, err = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), c.port), options...) - } - + conn, err := getDialHandler(c.r, c.proxyAdapter, options...)(ctx, network, net.JoinHostPort(ip.String(), c.port)) if err != nil { return nil, err } diff --git a/dns/doh.go b/dns/doh.go index fc32a212..ac5d7486 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -536,7 +536,7 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls. return nil, err } } else { - if wrapConn, err := dialContextExtra(ctx, doh.proxyAdapter, "udp", udpAddr.AddrPort().Addr(), port); err == nil { + if wrapConn, err := dialContextExtra(ctx, doh.proxyAdapter, "udp", addr, doh.r); err == nil { if pc, ok := wrapConn.(*wrapPacketConn); ok { conn = pc } else { diff --git a/dns/doq.go b/dns/doq.go index d4fbb037..16a32a3a 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "net" - "net/netip" "runtime" "strconv" "sync" @@ -41,8 +40,6 @@ const ( DefaultTimeout = time.Second * 5 ) -type dialHandler func(ctx context.Context, network, addr string) (net.Conn, error) - // dnsOverQUIC is a struct that implements the Upstream interface for the // DNS-over-QUIC protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html). type dnsOverQUIC struct { @@ -345,12 +342,7 @@ func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connectio return nil, err } } else { - ipAddr, err := netip.ParseAddr(ip) - if err != nil { - return nil, err - } - - conn, err := dialContextExtra(ctx, doq.proxyAdapter, "udp", ipAddr, port) + conn, err := dialContextExtra(ctx, doq.proxyAdapter, "udp", addr, doq.r) if err != nil { return nil, err } @@ -498,21 +490,3 @@ func isQUICRetryError(err error) (ok bool) { return false } - -func getDialHandler(r *Resolver, proxyAdapter string) dialHandler { - return func(ctx context.Context, network, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - ip, err := r.ResolveIP(ctx, host) - if err != nil { - return nil, err - } - if len(proxyAdapter) == 0 { - return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port), dialer.WithDirect()) - } else { - return dialContextExtra(ctx, proxyAdapter, network, ip.Unmap(), port) - } - } -} diff --git a/dns/util.go b/dns/util.go index 5bf09b8f..fb42a9e7 100644 --- a/dns/util.go +++ b/dns/util.go @@ -160,27 +160,54 @@ func (wpc *wrapPacketConn) LocalAddr() net.Addr { } } -func dialContextExtra(ctx context.Context, adapterName string, network string, dstIP netip.Addr, port string, opts ...dialer.Option) (net.Conn, error) { - networkType := C.TCP - if network == "udp" { +type dialHandler func(ctx context.Context, network, addr string) (net.Conn, error) - networkType = C.UDP +func getDialHandler(r *Resolver, proxyAdapter string, opts ...dialer.Option) dialHandler { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + if len(proxyAdapter) == 0 { + opts = append(opts, dialer.WithResolver(r)) + return dialer.DialContext(ctx, network, addr, opts...) + } else { + return dialContextExtra(ctx, proxyAdapter, network, addr, r, opts...) + } } +} - metadata := &C.Metadata{ - NetWork: networkType, - Host: "", - DstIP: dstIP, - DstPort: port, +func dialContextExtra(ctx context.Context, adapterName string, network string, addr string, r *Resolver, opts ...dialer.Option) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err } - adapter, ok := tunnel.Proxies()[adapterName] if !ok { opts = append(opts, dialer.WithInterface(adapterName)) - if C.TCP == networkType { - return dialer.DialContext(ctx, network, dstIP.String()+":"+port, opts...) - } else { - packetConn, err := dialer.ListenPacket(ctx, network, dstIP.String()+":"+port, opts...) + } + if strings.Contains(network, "tcp") { + // tcp can resolve host by remote + metadata := &C.Metadata{ + NetWork: C.TCP, + Host: host, + DstPort: port, + } + if ok { + return adapter.DialContext(ctx, metadata, opts...) + } + opts = append(opts, dialer.WithResolver(r)) + return dialer.DialContext(ctx, network, addr, opts...) + } else { + // udp must resolve host first + dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r) + if err != nil { + return nil, err + } + metadata := &C.Metadata{ + NetWork: C.UDP, + Host: "", + DstIP: dstIP, + DstPort: port, + } + if !ok { + packetConn, err := dialer.ListenPacket(ctx, network, metadata.RemoteAddress(), opts...) if err != nil { return nil, err } @@ -189,15 +216,12 @@ func dialContextExtra(ctx context.Context, adapterName string, network string, d PacketConn: packetConn, rAddr: metadata.UDPAddr(), }, nil - } - } - if networkType == C.UDP && !adapter.SupportUDP() { - return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", adapterName) - } + if !adapter.SupportUDP() { + return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", adapterName) + } - if networkType == C.UDP { packetConn, err := adapter.ListenPacketContext(ctx, metadata, opts...) if err != nil { return nil, err @@ -208,8 +232,6 @@ func dialContextExtra(ctx context.Context, adapterName string, network string, d rAddr: metadata.UDPAddr(), }, nil } - - return adapter.DialContext(ctx, metadata, opts...) } func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {