diff --git a/dns/doh.go b/dns/doh.go index d2bfeb82..34685578 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -528,14 +528,8 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls. IP: net.ParseIP(ip), Port: portInt, } - var conn net.PacketConn - if wrapConn, err := dialContextExtra(ctx, doh.proxyAdapter, "udp", addr, doh.r); err == nil { - if pc, ok := wrapConn.(*wrapPacketConn); ok { - conn = pc - } else { - return nil, fmt.Errorf("conn isn't wrapPacketConn") - } - } else { + conn, err := listenPacket(ctx, doh.proxyAdapter, "udp", addr, doh.r) + if err != nil { return nil, err } return quic.DialEarlyContext(ctx, conn, &udpAddr, doh.url.Host, tlsCfg, cfg) @@ -556,20 +550,10 @@ func (doh *dnsOverHTTPS) probeH3( if err != nil { return "", fmt.Errorf("failed to dial: %w", err) } + addr = rawConn.RemoteAddr().String() // It's never actually used. _ = rawConn.Close() - udpConn, ok := rawConn.(*net.UDPConn) - if !ok { - if packetConn, ok := rawConn.(*wrapPacketConn); !ok { - return "", fmt.Errorf("not a UDP connection to %s", doh.Address()) - } else { - addr = packetConn.RemoteAddr().String() - } - } else { - addr = udpConn.RemoteAddr().String() - } - // Avoid spending time on probing if this upstream only supports HTTP/3. if doh.supportsH3() && !doh.supportsHTTP() { return addr, nil diff --git a/dns/doq.go b/dns/doq.go index 3f2b7d07..1c5956af 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -313,19 +313,9 @@ func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connectio if err != nil { return nil, fmt.Errorf("failed to open a QUIC connection: %w", err) } + addr := rawConn.RemoteAddr().String() // It's never actually used _ = rawConn.Close() - var addr string - udpConn, ok := rawConn.(*net.UDPConn) - if !ok { - if packetConn, ok := rawConn.(*wrapPacketConn); !ok { - return nil, fmt.Errorf("failed to open connection to %s", doq.addr) - } else { - addr = packetConn.RemoteAddr().String() - } - } else { - addr = udpConn.RemoteAddr().String() - } ip, port, err := net.SplitHostPort(addr) if err != nil { @@ -334,14 +324,8 @@ func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connectio p, err := strconv.Atoi(port) udpAddr := net.UDPAddr{IP: net.ParseIP(ip), Port: p} - var udp net.PacketConn - if wrapConn, err := dialContextExtra(ctx, doq.proxyAdapter, "udp", addr, doq.r); err == nil { - if pc, ok := wrapConn.(*wrapPacketConn); ok { - udp = pc - } else { - return nil, fmt.Errorf("quic create packet failed") - } - } else { + udp, err := listenPacket(ctx, doq.proxyAdapter, "udp", addr, doq.r) + if err != nil { return nil, err } diff --git a/dns/util.go b/dns/util.go index f6b9c090..dfd2bafd 100644 --- a/dns/util.go +++ b/dns/util.go @@ -168,70 +168,90 @@ func getDialHandler(r *Resolver, proxyAdapter string, opts ...dialer.Option) dia opts = append(opts, dialer.WithResolver(r)) return dialer.DialContext(ctx, network, addr, opts...) } else { - return dialContextExtra(ctx, proxyAdapter, network, addr, r, opts...) + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + adapter, ok := tunnel.Proxies()[proxyAdapter] + if !ok { + opts = append(opts, dialer.WithInterface(proxyAdapter)) + } + if strings.Contains(network, "tcp") { + // tcp can resolve host by remote + metadata := &C.Metadata{ + NetWork: C.TCP, + Host: host, + DstPort: port, + } + if ok { + return adapter.DialContext(ctx, metadata, opts...) + } + opts = append(opts, dialer.WithResolver(r)) + return dialer.DialContext(ctx, network, addr, opts...) + } else { + // udp must resolve host first + dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r) + if err != nil { + return nil, err + } + metadata := &C.Metadata{ + NetWork: C.UDP, + Host: "", + DstIP: dstIP, + DstPort: port, + } + if !ok { + return dialer.DialContext(ctx, network, addr, opts...) + } + + if !adapter.SupportUDP() { + return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", proxyAdapter) + } + + packetConn, err := adapter.ListenPacketContext(ctx, metadata, opts...) + if err != nil { + return nil, err + } + + return &wrapPacketConn{ + PacketConn: packetConn, + rAddr: metadata.UDPAddr(), + }, nil + } } } } -func dialContextExtra(ctx context.Context, adapterName string, network string, addr string, r *Resolver, opts ...dialer.Option) (net.Conn, error) { +func listenPacket(ctx context.Context, proxyAdapter string, network string, addr string, r *Resolver, opts ...dialer.Option) (net.PacketConn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err } - adapter, ok := tunnel.Proxies()[adapterName] + adapter, ok := tunnel.Proxies()[proxyAdapter] + if !ok && len(proxyAdapter) != 0 { + opts = append(opts, dialer.WithInterface(proxyAdapter)) + } + + // udp must resolve host first + dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r) + if err != nil { + return nil, err + } + metadata := &C.Metadata{ + NetWork: C.UDP, + Host: "", + DstIP: dstIP, + DstPort: port, + } if !ok { - opts = append(opts, dialer.WithInterface(adapterName)) + return dialer.ListenPacket(ctx, dialer.ParseNetwork(network, dstIP), "", opts...) } - if strings.Contains(network, "tcp") { - // tcp can resolve host by remote - metadata := &C.Metadata{ - NetWork: C.TCP, - Host: host, - DstPort: port, - } - if ok { - return adapter.DialContext(ctx, metadata, opts...) - } - opts = append(opts, dialer.WithResolver(r)) - return dialer.DialContext(ctx, network, addr, opts...) - } else { - // udp must resolve host first - dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r) - if err != nil { - return nil, err - } - metadata := &C.Metadata{ - NetWork: C.UDP, - Host: "", - DstIP: dstIP, - DstPort: port, - } - if !ok { - packetConn, err := dialer.ListenPacket(ctx, dialer.ParseNetwork(network, dstIP), "", opts...) - if err != nil { - return nil, err - } - return &wrapPacketConn{ - PacketConn: packetConn, - rAddr: metadata.UDPAddr(), - }, nil - } - - if !adapter.SupportUDP() { - return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", adapterName) - } - - packetConn, err := adapter.ListenPacketContext(ctx, metadata, opts...) - if err != nil { - return nil, err - } - - return &wrapPacketConn{ - PacketConn: packetConn, - rAddr: metadata.UDPAddr(), - }, nil + if !adapter.SupportUDP() { + return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", proxyAdapter) } + + return adapter.ListenPacketContext(ctx, metadata, opts...) } func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {