fix: let doh/dot server follow hosts and can remotely resolve itself ip

This commit is contained in:
wwqgtxx 2022-12-07 20:01:44 +08:00
parent e03fcd24dd
commit a6f7e1472b
7 changed files with 76 additions and 81 deletions

View file

@ -5,6 +5,7 @@ import (
"net" "net"
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
) )
@ -14,7 +15,7 @@ type Direct struct {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
opts = append(opts, dialer.WithDirect()) opts = append(opts, dialer.WithResolver(resolver.DefaultResolver))
c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -25,7 +26,7 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
opts = append(opts, dialer.WithDirect()) opts = append(opts, dialer.WithResolver(resolver.DefaultResolver))
pc, err := dialer.ListenPacket(ctx, "udp", "", d.Base.DialOptions(opts...)...) pc, err := dialer.ListenPacket(ctx, "udp", "", d.Base.DialOptions(opts...)...)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -4,12 +4,14 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic"
"net" "net"
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
"github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic"
) )
var ( var (
@ -149,7 +151,7 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
results := make(chan dialResult) results := make(chan dialResult)
var primary, fallback dialResult var primary, fallback dialResult
startRacer := func(ctx context.Context, network, host string, direct bool, ipv6 bool) { startRacer := func(ctx context.Context, network, host string, r resolver.Resolver, ipv6 bool) {
result := dialResult{ipv6: ipv6, done: true} result := dialResult{ipv6: ipv6, done: true}
defer func() { defer func() {
select { select {
@ -163,16 +165,16 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
var ip netip.Addr var ip netip.Addr
if ipv6 { if ipv6 {
if !direct { if r == nil {
ip, result.error = resolver.ResolveIPv6ProxyServerHost(ctx, host) ip, result.error = resolver.ResolveIPv6ProxyServerHost(ctx, host)
} else { } else {
ip, result.error = resolver.ResolveIPv6(ctx, host) ip, result.error = resolver.ResolveIPv6WithResolver(ctx, host, r)
} }
} else { } else {
if !direct { if r == nil {
ip, result.error = resolver.ResolveIPv4ProxyServerHost(ctx, host) ip, result.error = resolver.ResolveIPv4ProxyServerHost(ctx, host)
} else { } else {
ip, result.error = resolver.ResolveIPv4(ctx, host) ip, result.error = resolver.ResolveIPv4WithResolver(ctx, host, r)
} }
} }
if result.error != nil { if result.error != nil {
@ -183,8 +185,8 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
result.Conn, result.error = dialContext(ctx, network, ip, port, opt) result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
} }
go startRacer(ctx, network+"4", host, opt.direct, false) go startRacer(ctx, network+"4", host, opt.resolver, false)
go startRacer(ctx, network+"6", host, opt.direct, true) go startRacer(ctx, network+"6", host, opt.resolver, true)
count := 2 count := 2
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
@ -230,8 +232,8 @@ func concurrentDualStackDialContext(ctx context.Context, network, address string
} }
var ips []netip.Addr var ips []netip.Addr
if opt.direct { if opt.resolver != nil {
ips, err = resolver.LookupIP(ctx, host) ips, err = resolver.LookupIPWithResolver(ctx, host, opt.resolver)
} else { } else {
ips, err = resolver.LookupIPProxyServerHost(ctx, host) ips, err = resolver.LookupIPProxyServerHost(ctx, host)
} }
@ -363,16 +365,16 @@ func singleDialContext(ctx context.Context, network string, address string, opt
var ip netip.Addr var ip netip.Addr
switch network { switch network {
case "tcp4", "udp4": case "tcp4", "udp4":
if !opt.direct { if opt.resolver == nil {
ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host) ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host)
} else { } else {
ip, err = resolver.ResolveIPv4(ctx, host) ip, err = resolver.ResolveIPv4WithResolver(ctx, host, opt.resolver)
} }
default: default:
if !opt.direct { if opt.resolver == nil {
ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host) ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host)
} else { } else {
ip, err = resolver.ResolveIPv6(ctx, host) ip, err = resolver.ResolveIPv6WithResolver(ctx, host, opt.resolver)
} }
} }
if err != nil { if err != nil {
@ -398,10 +400,10 @@ func concurrentIPv4DialContext(ctx context.Context, network, address string, opt
} }
var ips []netip.Addr var ips []netip.Addr
if !opt.direct { if opt.resolver == nil {
ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host) ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host)
} else { } else {
ips, err = resolver.LookupIPv4(ctx, host) ips, err = resolver.LookupIPv4WithResolver(ctx, host, opt.resolver)
} }
if err != nil { if err != nil {
@ -418,10 +420,10 @@ func concurrentIPv6DialContext(ctx context.Context, network, address string, opt
} }
var ips []netip.Addr var ips []netip.Addr
if !opt.direct { if opt.resolver == nil {
ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host) ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host)
} else { } else {
ips, err = resolver.LookupIPv6(ctx, host) ips, err = resolver.LookupIPv6WithResolver(ctx, host, opt.resolver)
} }
if err != nil { if err != nil {

View file

@ -1,6 +1,8 @@
package dialer package dialer
import ( import (
"github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic" "go.uber.org/atomic"
) )
@ -14,9 +16,9 @@ type option struct {
interfaceName string interfaceName string
addrReuse bool addrReuse bool
routingMark int routingMark int
direct bool
network int network int
prefer int prefer int
resolver resolver.Resolver
} }
type Option func(opt *option) type Option func(opt *option)
@ -39,9 +41,9 @@ func WithRoutingMark(mark int) Option {
} }
} }
func WithDirect() Option { func WithResolver(r resolver.Resolver) Option {
return func(opt *option) { return func(opt *option) {
opt.direct = true opt.resolver = r
} }
} }

View file

@ -60,13 +60,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error)
options = append(options, dialer.WithInterface(c.iface.Load())) options = append(options, dialer.WithInterface(c.iface.Load()))
} }
var conn net.Conn conn, err := getDialHandler(c.r, c.proxyAdapter, options...)(ctx, network, net.JoinHostPort(ip.String(), c.port))
if c.proxyAdapter != "" {
conn, err = dialContextExtra(ctx, c.proxyAdapter, network, ip, c.port, options...)
} else {
conn, err = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), c.port), options...)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -536,7 +536,7 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.
return nil, err return nil, err
} }
} else { } else {
if wrapConn, err := dialContextExtra(ctx, doh.proxyAdapter, "udp", udpAddr.AddrPort().Addr(), port); err == nil { if wrapConn, err := dialContextExtra(ctx, doh.proxyAdapter, "udp", addr, doh.r); err == nil {
if pc, ok := wrapConn.(*wrapPacketConn); ok { if pc, ok := wrapConn.(*wrapPacketConn); ok {
conn = pc conn = pc
} else { } else {

View file

@ -7,7 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"strconv" "strconv"
"sync" "sync"
@ -41,8 +40,6 @@ const (
DefaultTimeout = time.Second * 5 DefaultTimeout = time.Second * 5
) )
type dialHandler func(ctx context.Context, network, addr string) (net.Conn, error)
// dnsOverQUIC is a struct that implements the Upstream interface for the // dnsOverQUIC is a struct that implements the Upstream interface for the
// DNS-over-QUIC protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html). // DNS-over-QUIC protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html).
type dnsOverQUIC struct { type dnsOverQUIC struct {
@ -345,12 +342,7 @@ func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connectio
return nil, err return nil, err
} }
} else { } else {
ipAddr, err := netip.ParseAddr(ip) conn, err := dialContextExtra(ctx, doq.proxyAdapter, "udp", addr, doq.r)
if err != nil {
return nil, err
}
conn, err := dialContextExtra(ctx, doq.proxyAdapter, "udp", ipAddr, port)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -498,21 +490,3 @@ func isQUICRetryError(err error) (ok bool) {
return false return false
} }
func getDialHandler(r *Resolver, proxyAdapter string) dialHandler {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
ip, err := r.ResolveIP(ctx, host)
if err != nil {
return nil, err
}
if len(proxyAdapter) == 0 {
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port), dialer.WithDirect())
} else {
return dialContextExtra(ctx, proxyAdapter, network, ip.Unmap(), port)
}
}
}

View file

@ -160,27 +160,54 @@ func (wpc *wrapPacketConn) LocalAddr() net.Addr {
} }
} }
func dialContextExtra(ctx context.Context, adapterName string, network string, dstIP netip.Addr, port string, opts ...dialer.Option) (net.Conn, error) { type dialHandler func(ctx context.Context, network, addr string) (net.Conn, error)
networkType := C.TCP
if network == "udp" {
networkType = C.UDP func getDialHandler(r *Resolver, proxyAdapter string, opts ...dialer.Option) dialHandler {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
if len(proxyAdapter) == 0 {
opts = append(opts, dialer.WithResolver(r))
return dialer.DialContext(ctx, network, addr, opts...)
} else {
return dialContextExtra(ctx, proxyAdapter, network, addr, r, opts...)
}
} }
}
metadata := &C.Metadata{ func dialContextExtra(ctx context.Context, adapterName string, network string, addr string, r *Resolver, opts ...dialer.Option) (net.Conn, error) {
NetWork: networkType, host, port, err := net.SplitHostPort(addr)
Host: "", if err != nil {
DstIP: dstIP, return nil, err
DstPort: port,
} }
adapter, ok := tunnel.Proxies()[adapterName] adapter, ok := tunnel.Proxies()[adapterName]
if !ok { if !ok {
opts = append(opts, dialer.WithInterface(adapterName)) opts = append(opts, dialer.WithInterface(adapterName))
if C.TCP == networkType { }
return dialer.DialContext(ctx, network, dstIP.String()+":"+port, opts...) if strings.Contains(network, "tcp") {
} else { // tcp can resolve host by remote
packetConn, err := dialer.ListenPacket(ctx, network, dstIP.String()+":"+port, opts...) metadata := &C.Metadata{
NetWork: C.TCP,
Host: host,
DstPort: port,
}
if ok {
return adapter.DialContext(ctx, metadata, opts...)
}
opts = append(opts, dialer.WithResolver(r))
return dialer.DialContext(ctx, network, addr, opts...)
} else {
// udp must resolve host first
dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r)
if err != nil {
return nil, err
}
metadata := &C.Metadata{
NetWork: C.UDP,
Host: "",
DstIP: dstIP,
DstPort: port,
}
if !ok {
packetConn, err := dialer.ListenPacket(ctx, network, metadata.RemoteAddress(), opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -189,15 +216,12 @@ func dialContextExtra(ctx context.Context, adapterName string, network string, d
PacketConn: packetConn, PacketConn: packetConn,
rAddr: metadata.UDPAddr(), rAddr: metadata.UDPAddr(),
}, nil }, nil
} }
}
if networkType == C.UDP && !adapter.SupportUDP() { if !adapter.SupportUDP() {
return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", adapterName) return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", adapterName)
} }
if networkType == C.UDP {
packetConn, err := adapter.ListenPacketContext(ctx, metadata, opts...) packetConn, err := adapter.ListenPacketContext(ctx, metadata, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -208,8 +232,6 @@ func dialContextExtra(ctx context.Context, adapterName string, network string, d
rAddr: metadata.UDPAddr(), rAddr: metadata.UDPAddr(),
}, nil }, nil
} }
return adapter.DialContext(ctx, metadata, opts...)
} }
func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {