diff --git a/adapter/outbound/hysteria.go b/adapter/outbound/hysteria.go index c38946c8..9dc34b44 100644 --- a/adapter/outbound/hysteria.go +++ b/adapter/outbound/hysteria.go @@ -58,7 +58,6 @@ func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts . remoteAddr: func(addr string) (net.Addr, error) { return resolveUDPAddrWithPrefer(ctx, "udp", addr, h.prefer) }, - network: "udp", } tcpConn, err := h.client.DialTCP(metadata.RemoteAddress(), &hdc) @@ -78,7 +77,6 @@ func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata remoteAddr: func(addr string) (net.Addr, error) { return resolveUDPAddrWithPrefer(ctx, "udp", addr, h.prefer) }, - network: "udp", } udpConn, err := h.client.DialUDP(&hdc) if err != nil { @@ -331,11 +329,16 @@ type hyDialerWithContext struct { hyDialer func(network string) (net.PacketConn, error) ctx context.Context remoteAddr func(host string) (net.Addr, error) - network string } -func (h *hyDialerWithContext) ListenPacket() (net.PacketConn, error) { - return h.hyDialer(h.network) +func (h *hyDialerWithContext) ListenPacket(rAddr net.Addr) (net.PacketConn, error) { + network := "udp" + if addrPort, err := netip.ParseAddrPort(rAddr.String()); err == nil { + if addrPort.Addr().Is6() { + network = "udp6" + } + } + return h.hyDialer(network) } func (h *hyDialerWithContext) Context() context.Context { @@ -343,16 +346,5 @@ func (h *hyDialerWithContext) Context() context.Context { } func (h *hyDialerWithContext) RemoteAddr(host string) (net.Addr, error) { - addr, err := h.remoteAddr(host) - if err != nil { - return nil, err - } - if addrPort, err := netip.ParseAddrPort(addr.String()); err != nil { - if addrPort.Addr().Is6() { - h.network = "udp6" - } else { - h.network = "udp" - } - } - return addr, nil + return h.remoteAddr(host) } diff --git a/transport/hysteria/conns/udp/hop.go b/transport/hysteria/conns/udp/hop.go index e4958821..4097b692 100644 --- a/transport/hysteria/conns/udp/hop.go +++ b/transport/hysteria/conns/udp/hop.go @@ -90,7 +90,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf }, }, } - curConn, err := dialer.ListenPacket() + curConn, err := dialer.ListenPacket(ip) if err != nil { return nil, err } @@ -100,7 +100,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf conn.currentConn = curConn } go conn.recvRoutine(conn.currentConn) - go conn.hopRoutine(dialer) + go conn.hopRoutine(dialer, ip) return conn, nil } @@ -120,26 +120,26 @@ func (c *ObfsUDPHopClientPacketConn) recvRoutine(conn net.PacketConn) { } } -func (c *ObfsUDPHopClientPacketConn) hopRoutine(dialer utils.PacketDialer) { +func (c *ObfsUDPHopClientPacketConn) hopRoutine(dialer utils.PacketDialer, rAddr net.Addr) { ticker := time.NewTicker(c.hopInterval) defer ticker.Stop() for { select { case <-ticker.C: - c.hop(dialer) + c.hop(dialer, rAddr) case <-c.closeChan: return } } } -func (c *ObfsUDPHopClientPacketConn) hop(dialer utils.PacketDialer) { +func (c *ObfsUDPHopClientPacketConn) hop(dialer utils.PacketDialer, rAddr net.Addr) { c.connMutex.Lock() defer c.connMutex.Unlock() if c.closed { return } - newConn, err := dialer.ListenPacket() + newConn, err := dialer.ListenPacket(rAddr) if err != nil { // Skip this hop if failed to listen return diff --git a/transport/hysteria/transport/client.go b/transport/hysteria/transport/client.go index c30377a3..af0c3ea2 100644 --- a/transport/hysteria/transport/client.go +++ b/transport/hysteria/transport/client.go @@ -20,9 +20,10 @@ type ClientTransport struct { Dialer *net.Dialer } -func (ct *ClientTransport) quicPacketConn(proto string, server string, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (net.PacketConn, error) { +func (ct *ClientTransport) quicPacketConn(proto string, rAddr net.Addr, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (net.PacketConn, error) { + server := rAddr.String() if len(proto) == 0 || proto == "udp" { - conn, err := dialer.ListenPacket() + conn, err := dialer.ListenPacket(rAddr) if err != nil { return nil, err } @@ -39,7 +40,7 @@ func (ct *ClientTransport) quicPacketConn(proto string, server string, obfs obfs return conn, nil } } else if proto == "wechat-video" { - conn, err := dialer.ListenPacket() + conn, err := dialer.ListenPacket(rAddr) if err != nil { return nil, err } @@ -70,7 +71,7 @@ func (ct *ClientTransport) QUICDial(proto string, server string, tlsConfig *tls. return nil, err } - pktConn, err := ct.quicPacketConn(proto, serverUDPAddr.String(), obfs, hopInterval, dialer) + pktConn, err := ct.quicPacketConn(proto, serverUDPAddr, obfs, hopInterval, dialer) if err != nil { return nil, err } diff --git a/transport/hysteria/utils/misc.go b/transport/hysteria/utils/misc.go index 5d5159fc..670d737c 100644 --- a/transport/hysteria/utils/misc.go +++ b/transport/hysteria/utils/misc.go @@ -43,7 +43,7 @@ func last(s string, b byte) int { } type PacketDialer interface { - ListenPacket() (net.PacketConn, error) + ListenPacket(rAddr net.Addr) (net.PacketConn, error) Context() context.Context RemoteAddr(host string) (net.Addr, error) }