diff --git a/adapter/outbound/hysteria.go b/adapter/outbound/hysteria.go index 490b14e1..4fa40f3a 100644 --- a/adapter/outbound/hysteria.go +++ b/adapter/outbound/hysteria.go @@ -56,7 +56,7 @@ func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts . return dialer.ListenPacket(ctx, "udp", "", h.Base.DialOptions(opts...)...) }, remoteAddr: func(addr string) (net.Addr, error) { - return resolveUDPAddrWithPrefer("udp", addr, h.prefer) + return resolveUDPAddrWithPrefer(ctx, "udp", addr, h.prefer) }, } @@ -75,7 +75,7 @@ func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata return dialer.ListenPacket(ctx, "udp", "", h.Base.DialOptions(opts...)...) }, remoteAddr: func(addr string) (net.Addr, error) { - return resolveUDPAddrWithPrefer("udp", addr, h.prefer) + return resolveUDPAddrWithPrefer(ctx, "udp", addr, h.prefer) }, } udpConn, err := h.client.DialUDP(&hdc) diff --git a/adapter/outbound/shadowsocks.go b/adapter/outbound/shadowsocks.go index 6eeacf45..cb0ea4fc 100644 --- a/adapter/outbound/shadowsocks.go +++ b/adapter/outbound/shadowsocks.go @@ -109,7 +109,7 @@ func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Meta return nil, err } - addr, err := resolveUDPAddrWithPrefer("udp", ss.addr, ss.prefer) + addr, err := resolveUDPAddrWithPrefer(ctx, "udp", ss.addr, ss.prefer) if err != nil { pc.Close() return nil, err diff --git a/adapter/outbound/shadowsocksr.go b/adapter/outbound/shadowsocksr.go index 6b6b9a98..8bc9ef65 100644 --- a/adapter/outbound/shadowsocksr.go +++ b/adapter/outbound/shadowsocksr.go @@ -79,7 +79,7 @@ func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Me return nil, err } - addr, err := resolveUDPAddrWithPrefer("udp", ssr.addr, ssr.prefer) + addr, err := resolveUDPAddrWithPrefer(ctx, "udp", ssr.addr, ssr.prefer) if err != nil { pc.Close() return nil, err diff --git a/adapter/outbound/socks5.go b/adapter/outbound/socks5.go index 43900b1e..915e192e 100644 --- a/adapter/outbound/socks5.go +++ b/adapter/outbound/socks5.go @@ -129,7 +129,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata, err = errors.New("invalid UDP bind address") return } else if bindUDPAddr.IP.IsUnspecified() { - serverAddr, err := resolveUDPAddr("udp", ss.Addr()) + serverAddr, err := resolveUDPAddr(ctx, "udp", ss.Addr()) if err != nil { return nil, err } diff --git a/adapter/outbound/util.go b/adapter/outbound/util.go index 06c81868..a3d88a4e 100644 --- a/adapter/outbound/util.go +++ b/adapter/outbound/util.go @@ -2,6 +2,7 @@ package outbound import ( "bytes" + "context" "crypto/tls" xtls "github.com/xtls/go" "net" @@ -63,20 +64,20 @@ func serializesSocksAddr(metadata *C.Metadata) []byte { return bytes.Join(buf, nil) } -func resolveUDPAddr(network, address string) (*net.UDPAddr, error) { +func resolveUDPAddr(ctx context.Context, network, address string) (*net.UDPAddr, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, err } - ip, err := resolver.ResolveProxyServerHost(host) + ip, err := resolver.ResolveProxyServerHost(ctx, host) if err != nil { return nil, err } return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port)) } -func resolveUDPAddrWithPrefer(network, address string, prefer C.DNSPrefer) (*net.UDPAddr, error) { +func resolveUDPAddrWithPrefer(ctx context.Context, network, address string, prefer C.DNSPrefer) (*net.UDPAddr, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, err @@ -84,12 +85,12 @@ func resolveUDPAddrWithPrefer(network, address string, prefer C.DNSPrefer) (*net var ip netip.Addr switch prefer { case C.IPv4Only: - ip, err = resolver.ResolveIPv4ProxyServerHost(host) + ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host) case C.IPv6Only: - ip, err = resolver.ResolveIPv6ProxyServerHost(host) + ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host) case C.IPv6Prefer: var ips []netip.Addr - ips, err = resolver.ResolveAllIPProxyServerHost(host) + ips, err = resolver.LookupIPProxyServerHost(ctx, host) var fallback netip.Addr if err == nil { for _, addr := range ips { @@ -107,7 +108,7 @@ func resolveUDPAddrWithPrefer(network, address string, prefer C.DNSPrefer) (*net default: // C.IPv4Prefer, C.DualStack and other var ips []netip.Addr - ips, err = resolver.ResolveAllIPProxyServerHost(host) + ips, err = resolver.LookupIPProxyServerHost(ctx, host) var fallback netip.Addr if err == nil { for _, addr := range ips { diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 5eb081c1..b86a9d6f 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -233,7 +233,7 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { // vless use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ip, err := resolver.ResolveIP(ctx, metadata.Host) if err != nil { return nil, errors.New("can't resolve ip") } diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index a66ee40f..a8777bf7 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -243,7 +243,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ip, err := resolver.ResolveIP(ctx, metadata.Host) if err != nil { return nil, errors.New("can't resolve ip") } diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index a6808078..eb48c0c0 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -205,7 +205,7 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts } if !metadata.Resolved() { var addrs []netip.Addr - addrs, err = resolver.ResolveAllIP(metadata.Host) + addrs, err = resolver.LookupIP(ctx, metadata.Host) if err != nil { return nil, err } @@ -229,7 +229,7 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat return nil, err } if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ip, err := resolver.ResolveIP(ctx, metadata.Host) if err != nil { return nil, errors.New("can't resolve ip") } diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index ea4d2ece..330dcf81 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -158,15 +158,15 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt var ip netip.Addr if ipv6 { if !direct { - ip, result.error = resolver.ResolveIPv6ProxyServerHost(host) + ip, result.error = resolver.ResolveIPv6ProxyServerHost(ctx, host) } else { - ip, result.error = resolver.ResolveIPv6(host) + ip, result.error = resolver.ResolveIPv6(ctx, host) } } else { if !direct { - ip, result.error = resolver.ResolveIPv4ProxyServerHost(host) + ip, result.error = resolver.ResolveIPv4ProxyServerHost(ctx, host) } else { - ip, result.error = resolver.ResolveIPv4(host) + ip, result.error = resolver.ResolveIPv4(ctx, host) } } if result.error != nil { @@ -219,9 +219,9 @@ func concurrentDualStackDialContext(ctx context.Context, network, address string var ips []netip.Addr if opt.direct { - ips, err = resolver.ResolveAllIP(host) + ips, err = resolver.LookupIP(ctx, host) } else { - ips, err = resolver.ResolveAllIPProxyServerHost(host) + ips, err = resolver.LookupIPProxyServerHost(ctx, host) } if err != nil { @@ -344,15 +344,15 @@ func singleDialContext(ctx context.Context, network string, address string, opt switch network { case "tcp4", "udp4": if !opt.direct { - ip, err = resolver.ResolveIPv4ProxyServerHost(host) + ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host) } else { - ip, err = resolver.ResolveIPv4(host) + ip, err = resolver.ResolveIPv4(ctx, host) } default: if !opt.direct { - ip, err = resolver.ResolveIPv6ProxyServerHost(host) + ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host) } else { - ip, err = resolver.ResolveIPv6(host) + ip, err = resolver.ResolveIPv6(ctx, host) } } if err != nil { @@ -379,9 +379,9 @@ func concurrentIPv4DialContext(ctx context.Context, network, address string, opt var ips []netip.Addr if !opt.direct { - ips, err = resolver.ResolveAllIPv4ProxyServerHost(host) + ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host) } else { - ips, err = resolver.ResolveAllIPv4(host) + ips, err = resolver.LookupIPv4(ctx, host) } if err != nil { @@ -399,9 +399,9 @@ func concurrentIPv6DialContext(ctx context.Context, network, address string, opt var ips []netip.Addr if !opt.direct { - ips, err = resolver.ResolveAllIPv6ProxyServerHost(host) + ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host) } else { - ips, err = resolver.ResolveAllIPv6(host) + ips, err = resolver.LookupIPv6(ctx, host) } if err != nil { diff --git a/component/resolver/local.go b/component/resolver/local.go index be84e693..e8505118 100644 --- a/component/resolver/local.go +++ b/component/resolver/local.go @@ -1,17 +1,21 @@ package resolver -import D "github.com/miekg/dns" +import ( + "context" + + D "github.com/miekg/dns" +) var DefaultLocalServer LocalServer type LocalServer interface { - ServeMsg(msg *D.Msg) (*D.Msg, error) + ServeMsg(ctx context.Context, msg *D.Msg) (*D.Msg, error) } // ServeMsg with a dns.Msg, return resolve dns.Msg -func ServeMsg(msg *D.Msg) (*D.Msg, error) { +func ServeMsg(ctx context.Context, msg *D.Msg) (*D.Msg, error) { if server := DefaultLocalServer; server != nil { - return server.ServeMsg(msg) + return server.ServeMsg(ctx, msg) } return nil, ErrIPNotFound diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index 51df094f..0c09d23c 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -37,21 +37,21 @@ var ( ) type Resolver interface { - ResolveIP(host string) (ip netip.Addr, err error) - ResolveIPv4(host string) (ip netip.Addr, err error) - ResolveIPv6(host string) (ip netip.Addr, err error) - ResolveAllIP(host string) (ip []netip.Addr, err error) - ResolveAllIPv4(host string) (ips []netip.Addr, err error) - ResolveAllIPv6(host string) (ips []netip.Addr, err error) + LookupIP(ctx context.Context, host string) (ips []netip.Addr, err error) + LookupIPv4(ctx context.Context, host string) (ips []netip.Addr, err error) + LookupIPv6(ctx context.Context, host string) (ips []netip.Addr, err error) + ResolveIP(ctx context.Context, host string) (ip netip.Addr, err error) + ResolveIPv4(ctx context.Context, host string) (ip netip.Addr, err error) + ResolveIPv6(ctx context.Context, host string) (ip netip.Addr, err error) } // ResolveIPv4 with a host, return ipv4 -func ResolveIPv4(host string) (netip.Addr, error) { - return ResolveIPv4WithResolver(host, DefaultResolver) +func ResolveIPv4(ctx context.Context, host string) (netip.Addr, error) { + return ResolveIPv4WithResolver(ctx, host, DefaultResolver) } -func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) { - if ips, err := ResolveAllIPv4WithResolver(host, r); err == nil { +func ResolveIPv4WithResolver(ctx context.Context, host string, r Resolver) (netip.Addr, error) { + if ips, err := LookupIPv4WithResolver(ctx, host, r); err == nil { return ips[rand.Intn(len(ips))], nil } else { return netip.Addr{}, err @@ -59,12 +59,12 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) { } // ResolveIPv6 with a host, return ipv6 -func ResolveIPv6(host string) (netip.Addr, error) { - return ResolveIPv6WithResolver(host, DefaultResolver) +func ResolveIPv6(ctx context.Context, host string) (netip.Addr, error) { + return ResolveIPv6WithResolver(ctx, host, DefaultResolver) } -func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) { - if ips, err := ResolveAllIPv6WithResolver(host, r); err == nil { +func ResolveIPv6WithResolver(ctx context.Context, host string, r Resolver) (netip.Addr, error) { + if ips, err := LookupIPv6WithResolver(ctx, host, r); err == nil { return ips[rand.Intn(len(ips))], nil } else { return netip.Addr{}, err @@ -72,56 +72,56 @@ func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) { } // ResolveIPWithResolver same as ResolveIP, but with a resolver -func ResolveIPWithResolver(host string, r Resolver) (netip.Addr, error) { - if ip, err := ResolveIPv4WithResolver(host, r); err == nil { +func ResolveIPWithResolver(ctx context.Context, host string, r Resolver) (netip.Addr, error) { + if ip, err := ResolveIPv4WithResolver(ctx, host, r); err == nil { return ip, nil } else { - return ResolveIPv6WithResolver(host, r) + return ResolveIPv6WithResolver(ctx, host, r) } } // ResolveIP with a host, return ip -func ResolveIP(host string) (netip.Addr, error) { - return ResolveIPWithResolver(host, DefaultResolver) +func ResolveIP(ctx context.Context, host string) (netip.Addr, error) { + return ResolveIPWithResolver(ctx, host, DefaultResolver) } // ResolveIPv4ProxyServerHost proxies server host only -func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) { +func ResolveIPv4ProxyServerHost(ctx context.Context, host string) (netip.Addr, error) { if ProxyServerHostResolver != nil { - if ip, err := ResolveIPv4WithResolver(host, ProxyServerHostResolver); err != nil { - return ResolveIPv4(host) + if ip, err := ResolveIPv4WithResolver(ctx, host, ProxyServerHostResolver); err != nil { + return ResolveIPv4(ctx, host) } else { return ip, nil } } - return ResolveIPv4(host) + return ResolveIPv4(ctx, host) } // ResolveIPv6ProxyServerHost proxies server host only -func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) { +func ResolveIPv6ProxyServerHost(ctx context.Context, host string) (netip.Addr, error) { if ProxyServerHostResolver != nil { - if ip, err := ResolveIPv6WithResolver(host, ProxyServerHostResolver); err != nil { - return ResolveIPv6(host) + if ip, err := ResolveIPv6WithResolver(ctx, host, ProxyServerHostResolver); err != nil { + return ResolveIPv6(ctx, host) } else { return ip, nil } } - return ResolveIPv6(host) + return ResolveIPv6(ctx, host) } // ResolveProxyServerHost proxies server host only -func ResolveProxyServerHost(host string) (netip.Addr, error) { +func ResolveProxyServerHost(ctx context.Context, host string) (netip.Addr, error) { if ProxyServerHostResolver != nil { - if ip, err := ResolveIPWithResolver(host, ProxyServerHostResolver); err != nil { - return ResolveIP(host) + if ip, err := ResolveIPWithResolver(ctx, host, ProxyServerHostResolver); err != nil { + return ResolveIP(ctx, host) } else { return ip, err } } - return ResolveIP(host) + return ResolveIP(ctx, host) } -func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) { +func LookupIPv6WithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) { if DisableIPv6 { return []netip.Addr{}, ErrIPv6Disabled } @@ -141,12 +141,10 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) { } if r != nil { - return r.ResolveAllIPv6(host) + return r.LookupIPv6(ctx, host) } if DefaultResolver == nil { - ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) - defer cancel() ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host) if err != nil { return []netip.Addr{}, err @@ -167,7 +165,7 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) { return []netip.Addr{}, ErrIPNotFound } -func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) { +func LookupIPv4WithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) { if node := DefaultHosts.Search(host); node != nil { if ip := node.Data(); ip.Is4() { return []netip.Addr{node.Data()}, nil @@ -183,12 +181,10 @@ func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) { } if r != nil { - return r.ResolveAllIPv4(host) + return r.LookupIPv4(ctx, host) } if DefaultResolver == nil { - ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) - defer cancel() ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host) if err != nil { return []netip.Addr{}, err @@ -209,7 +205,7 @@ func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) { return []netip.Addr{}, ErrIPNotFound } -func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) { +func LookupIPWithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) { if node := DefaultHosts.Search(host); node != nil { return []netip.Addr{node.Data()}, nil } @@ -221,16 +217,16 @@ func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) { if r != nil { if DisableIPv6 { - return r.ResolveAllIPv4(host) + return r.LookupIPv4(ctx, host) } - return r.ResolveAllIP(host) + return r.LookupIP(ctx, host) } else if DisableIPv6 { - return ResolveAllIPv4(host) + return LookupIPv4(ctx, host) } if DefaultResolver == nil { - ipAddrs, err := net.DefaultResolver.LookupIP(context.Background(), "ip", host) + ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip", host) if err != nil { return []netip.Addr{}, err } else if len(ipAddrs) == 0 { @@ -249,35 +245,35 @@ func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) { return []netip.Addr{}, ErrIPNotFound } -func ResolveAllIP(host string) ([]netip.Addr, error) { - return ResolveAllIPWithResolver(host, DefaultResolver) +func LookupIP(ctx context.Context, host string) ([]netip.Addr, error) { + return LookupIPWithResolver(ctx, host, DefaultResolver) } -func ResolveAllIPv4(host string) ([]netip.Addr, error) { - return ResolveAllIPv4WithResolver(host, DefaultResolver) +func LookupIPv4(ctx context.Context, host string) ([]netip.Addr, error) { + return LookupIPv4WithResolver(ctx, host, DefaultResolver) } -func ResolveAllIPv6(host string) ([]netip.Addr, error) { - return ResolveAllIPv6WithResolver(host, DefaultResolver) +func LookupIPv6(ctx context.Context, host string) ([]netip.Addr, error) { + return LookupIPv6WithResolver(ctx, host, DefaultResolver) } -func ResolveAllIPv6ProxyServerHost(host string) ([]netip.Addr, error) { +func LookupIPv6ProxyServerHost(ctx context.Context, host string) ([]netip.Addr, error) { if ProxyServerHostResolver != nil { - return ResolveAllIPv6WithResolver(host, ProxyServerHostResolver) + return LookupIPv6WithResolver(ctx, host, ProxyServerHostResolver) } - return ResolveAllIPv6(host) + return LookupIPv6(ctx, host) } -func ResolveAllIPv4ProxyServerHost(host string) ([]netip.Addr, error) { +func LookupIPv4ProxyServerHost(ctx context.Context, host string) ([]netip.Addr, error) { if ProxyServerHostResolver != nil { - return ResolveAllIPv4WithResolver(host, ProxyServerHostResolver) + return LookupIPv4WithResolver(ctx, host, ProxyServerHostResolver) } - return ResolveAllIPv4(host) + return LookupIPv4(ctx, host) } -func ResolveAllIPProxyServerHost(host string) ([]netip.Addr, error) { +func LookupIPProxyServerHost(ctx context.Context, host string) ([]netip.Addr, error) { if ProxyServerHostResolver != nil { - return ResolveAllIPWithResolver(host, ProxyServerHostResolver) + return LookupIPWithResolver(ctx, host, ProxyServerHostResolver) } - return ResolveAllIP(host) + return LookupIP(ctx, host) } diff --git a/context/dns.go b/context/dns.go index 0be4a1fc..59130961 100644 --- a/context/dns.go +++ b/context/dns.go @@ -1,6 +1,8 @@ package context import ( + "context" + "github.com/gofrs/uuid" "github.com/miekg/dns" ) @@ -12,14 +14,18 @@ const ( ) type DNSContext struct { + context.Context + id uuid.UUID msg *dns.Msg tp string } -func NewDNSContext(msg *dns.Msg) *DNSContext { +func NewDNSContext(ctx context.Context, msg *dns.Msg) *DNSContext { id, _ := uuid.NewV4() return &DNSContext{ + Context: ctx, + id: id, msg: msg, } diff --git a/dns/client.go b/dns/client.go index a377ee42..13b01422 100644 --- a/dns/client.go +++ b/dns/client.go @@ -38,7 +38,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) if c.r == nil { return nil, fmt.Errorf("dns %s not a valid ip", c.host) } else { - if ip, err = resolver.ResolveIPWithResolver(c.host, c.r); err != nil { + if ip, err = resolver.ResolveIPWithResolver(ctx, c.host, c.r); err != nil { return nil, fmt.Errorf("use default dns resolve failed: %w", err) } c.host = ip.String() diff --git a/dns/doq.go b/dns/doq.go index 734d26d0..07cceff6 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -347,7 +347,7 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { if err != nil { return nil, err } - + conn, err := dialContextExtra(ctx, doq.proxyAdapter, "udp", ipAddr, port) if err != nil { return nil, err @@ -505,7 +505,7 @@ func getDialHandler(r *Resolver, proxyAdapter string) dialHandler { if err != nil { return nil, err } - ip, err := r.ResolveIP(host) + ip, err := r.ResolveIP(ctx, host) if err != nil { return nil, err } diff --git a/dns/middleware.go b/dns/middleware.go index 0e1335f9..28ced849 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -156,7 +156,7 @@ func withResolver(resolver *Resolver) handler { return handleMsgWithEmptyAnswer(r), nil } - msg, err := resolver.Exchange(r) + msg, err := resolver.ExchangeContext(ctx, r) if err != nil { log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err) return msg, err diff --git a/dns/patch.go b/dns/patch.go index 76974243..37b5d41b 100644 --- a/dns/patch.go +++ b/dns/patch.go @@ -1,14 +1,18 @@ package dns -import D "github.com/miekg/dns" +import ( + "context" + + D "github.com/miekg/dns" +) type LocalServer struct { handler handler } // ServeMsg implement resolver.LocalServer ResolveMsg -func (s *LocalServer) ServeMsg(msg *D.Msg) (*D.Msg, error) { - return handlerWithContext(s.handler, msg) +func (s *LocalServer) ServeMsg(ctx context.Context, msg *D.Msg) (*D.Msg, error) { + return handlerWithContext(ctx, s.handler, msg) } func NewLocalServer(resolver *Resolver, mapper *ResolverEnhancer) *LocalServer { diff --git a/dns/resolver.go b/dns/resolver.go index 1184c2e7..eacb7dc4 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -44,18 +44,18 @@ type Resolver struct { proxyServer []dnsClient } -func (r *Resolver) ResolveAllIPPrimaryIPv4(host string) (ips []netip.Addr, err error) { +func (r *Resolver) LookupIPPrimaryIPv4(ctx context.Context, host string) (ips []netip.Addr, err error) { ch := make(chan []netip.Addr, 1) go func() { defer close(ch) - ip, err := r.resolveIP(host, D.TypeAAAA) + ip, err := r.resolveIP(ctx, host, D.TypeAAAA) if err != nil { return } ch <- ip }() - ips, err = r.resolveIP(host, D.TypeA) + ips, err = r.resolveIP(ctx, host, D.TypeA) if err == nil { return } @@ -68,11 +68,11 @@ func (r *Resolver) ResolveAllIPPrimaryIPv4(host string) (ips []netip.Addr, err e return ip, nil } -func (r *Resolver) ResolveAllIP(host string) (ips []netip.Addr, err error) { +func (r *Resolver) LookupIP(ctx context.Context, host string) (ips []netip.Addr, err error) { ch := make(chan []netip.Addr, 1) go func() { defer close(ch) - ip, err := r.resolveIP(host, D.TypeAAAA) + ip, err := r.resolveIP(ctx, host, D.TypeAAAA) if err != nil { return } @@ -80,7 +80,7 @@ func (r *Resolver) ResolveAllIP(host string) (ips []netip.Addr, err error) { ch <- ip }() - ips, err = r.resolveIP(host, D.TypeA) + ips, err = r.resolveIP(ctx, host, D.TypeA) select { case ipv6s, open := <-ch: @@ -95,17 +95,17 @@ func (r *Resolver) ResolveAllIP(host string) (ips []netip.Addr, err error) { return ips, nil } -func (r *Resolver) ResolveAllIPv4(host string) (ips []netip.Addr, err error) { - return r.resolveIP(host, D.TypeA) +func (r *Resolver) LookupIPv4(ctx context.Context, host string) (ips []netip.Addr, err error) { + return r.resolveIP(ctx, host, D.TypeA) } -func (r *Resolver) ResolveAllIPv6(host string) (ips []netip.Addr, err error) { - return r.resolveIP(host, D.TypeAAAA) +func (r *Resolver) LookupIPv6(ctx context.Context, host string) (ips []netip.Addr, err error) { + return r.resolveIP(ctx, host, D.TypeAAAA) } // ResolveIP request with TypeA and TypeAAAA, priority return TypeA -func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) { - if ips, err := r.ResolveAllIPPrimaryIPv4(host); err == nil { +func (r *Resolver) ResolveIP(ctx context.Context, host string) (ip netip.Addr, err error) { + if ips, err := r.LookupIPPrimaryIPv4(ctx, host); err == nil { return ips[rand.Intn(len(ips))], nil } else { return netip.Addr{}, err @@ -113,8 +113,8 @@ func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) { } // ResolveIPv4 request with TypeA -func (r *Resolver) ResolveIPv4(host string) (ip netip.Addr, err error) { - if ips, err := r.ResolveAllIPv4(host); err == nil { +func (r *Resolver) ResolveIPv4(ctx context.Context, host string) (ip netip.Addr, err error) { + if ips, err := r.LookupIPv4(ctx, host); err == nil { return ips[rand.Intn(len(ips))], nil } else { return netip.Addr{}, err @@ -122,8 +122,8 @@ func (r *Resolver) ResolveIPv4(host string) (ip netip.Addr, err error) { } // ResolveIPv6 request with TypeAAAA -func (r *Resolver) ResolveIPv6(host string) (ip netip.Addr, err error) { - if ips, err := r.ResolveAllIPv6(host); err == nil { +func (r *Resolver) ResolveIPv6(ctx context.Context, host string) (ip netip.Addr, err error) { + if ips, err := r.LookupIPv6(ctx, host); err == nil { return ips[rand.Intn(len(ips))], nil } else { return netip.Addr{}, err @@ -305,7 +305,7 @@ func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err er return } -func (r *Resolver) resolveIP(host string, dnsType uint16) (ips []netip.Addr, err error) { +func (r *Resolver) resolveIP(ctx context.Context, host string, dnsType uint16) (ips []netip.Addr, err error) { ip, err := netip.ParseAddr(host) if err == nil { isIPv4 := ip.Is4() @@ -321,7 +321,7 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ips []netip.Addr, err query := &D.Msg{} query.SetQuestion(D.Fqdn(host), dnsType) - msg, err := r.Exchange(query) + msg, err := r.ExchangeContext(ctx, query) if err != nil { return []netip.Addr{}, err } diff --git a/dns/server.go b/dns/server.go index 1fbde824..5c5970db 100644 --- a/dns/server.go +++ b/dns/server.go @@ -1,6 +1,7 @@ package dns import ( + stdContext "context" "errors" "net" @@ -25,7 +26,7 @@ type Server struct { // ServeDNS implement D.Handler ServeDNS func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) { - msg, err := handlerWithContext(s.handler, r) + msg, err := handlerWithContext(stdContext.Background(), s.handler, r) if err != nil { D.HandleFailed(w, r) return @@ -34,12 +35,12 @@ func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) { w.WriteMsg(msg) } -func handlerWithContext(handler handler, msg *D.Msg) (*D.Msg, error) { +func handlerWithContext(stdCtx stdContext.Context, handler handler, msg *D.Msg) (*D.Msg, error) { if len(msg.Question) == 0 { return nil, errors.New("at least one question is required") } - ctx := context.NewDNSContext(msg) + ctx := context.NewDNSContext(stdCtx, msg) return handler(ctx, msg) } diff --git a/listener/sing_tun/dns.go b/listener/sing_tun/dns.go index 39e2b1e5..21dee43c 100644 --- a/listener/sing_tun/dns.go +++ b/listener/sing_tun/dns.go @@ -23,6 +23,7 @@ import ( ) const DefaultDnsReadTimeout = time.Second * 10 +const DefaultDnsRelayTimeout = time.Second * 5 type ListenerHandler struct { sing.ListenerHandler @@ -69,8 +70,10 @@ func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, meta } err = func() error { + ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout) + defer cancel() inData := buff[:n] - msg, err := RelayDnsPacket(inData) + msg, err := RelayDnsPacket(ctx, inData) if err != nil { return err } @@ -117,8 +120,10 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. return err } go func() { + ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout) + defer cancel() inData := buff.Bytes() - msg, err := RelayDnsPacket(inData) + msg, err := RelayDnsPacket(ctx, inData) if err != nil { buff.Release() return @@ -146,13 +151,13 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. return h.ListenerHandler.NewPacketConnection(ctx, conn, metadata) } -func RelayDnsPacket(payload []byte) ([]byte, error) { +func RelayDnsPacket(ctx context.Context, payload []byte) ([]byte, error) { msg := &D.Msg{} if err := msg.Unpack(payload); err != nil { return nil, err } - r, err := resolver.ServeMsg(msg) + r, err := resolver.ServeMsg(ctx, msg) if err != nil { m := new(D.Msg) m.SetRcode(msg, D.RcodeServerFailure) diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 46b603a3..bdac6cce 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -225,7 +225,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) { // local resolve UDP dns if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ip, err := resolver.ResolveIP(context.Background(), metadata.Host) if err != nil { return } @@ -400,14 +400,18 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { for _, rule := range rules { if !resolved && shouldResolveIP(rule, metadata) { - ip, err := resolver.ResolveIP(metadata.Host) - if err != nil { - log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error()) - } else { - log.Debugln("[DNS] %s --> %s", metadata.Host, ip.String()) - metadata.DstIP = ip - } - resolved = true + func() { + ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) + defer cancel() + ip, err := resolver.ResolveIP(ctx, metadata.Host) + if err != nil { + log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error()) + } else { + log.Debugln("[DNS] %s --> %s", metadata.Host, ip.String()) + metadata.DstIP = ip + } + resolved = true + }() } if !processFound && (alwaysFindProcess || rule.ShouldFindProcess()) {