diff --git a/constant/metadata.go b/constant/metadata.go index 1c344d5d..198c447d 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -205,15 +205,16 @@ func (m *Metadata) Pure() *Metadata { return m } +func (m *Metadata) AddrPort() netip.AddrPort { + port, _ := strconv.ParseUint(m.DstPort, 10, 16) + return netip.AddrPortFrom(m.DstIP.Unmap(), uint16(port)) +} + func (m *Metadata) UDPAddr() *net.UDPAddr { if m.NetWork != UDP || !m.DstIP.IsValid() { return nil } - port, _ := strconv.ParseUint(m.DstPort, 10, 16) - return &net.UDPAddr{ - IP: m.DstIP.AsSlice(), - Port: int(port), - } + return net.UDPAddrFromAddrPort(m.AddrPort()) } func (m *Metadata) String() string { diff --git a/tunnel/connection.go b/tunnel/connection.go index 321c7d06..b130f79a 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -26,7 +26,7 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata return nil } -func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oAddr, fAddr netip.Addr) { +func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) { defer func() { _ = pc.Close() closeAllLocalCoon(key) @@ -40,18 +40,23 @@ func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oA return } - fromUDPAddr := from.(*net.UDPAddr) - _fromUDPAddr := *fromUDPAddr - fromUDPAddr = &_fromUDPAddr // make a copy - if fromAddr, ok := netip.AddrFromSlice(fromUDPAddr.IP); ok { - fromAddr = fromAddr.Unmap() - if fAddr.IsValid() && (oAddr.Unmap() == fromAddr) { - fromAddr = fAddr.Unmap() - } - fromUDPAddr.IP = fromAddr.AsSlice() - if fromAddr.Is4() { - fromUDPAddr.Zone = "" // only ipv6 can have the zone + fromUDPAddr, isUDPAddr := from.(*net.UDPAddr) + if isUDPAddr { + _fromUDPAddr := *fromUDPAddr + fromUDPAddr = &_fromUDPAddr // make a copy + if fromAddr, ok := netip.AddrFromSlice(fromUDPAddr.IP); ok { + fromAddr = fromAddr.Unmap() + if fAddr.IsValid() && (oAddrPort.Addr() == fromAddr) { // oAddrPort was Unmapped + fromAddr = fAddr.Unmap() + } + fromUDPAddr.IP = fromAddr.AsSlice() + if fromAddr.Is4() { + fromUDPAddr.Zone = "" // only ipv6 can have the zone + } } + } else { + fromUDPAddr = net.UDPAddrFromAddrPort(oAddrPort) // oAddrPort was Unmapped + log.Warnln("server return a [%T](%s) which isn't a *net.UDPAddr, force replace to (%s), this may be caused by a wrongly implemented server", from, from, oAddrPort) } _, err = packet.WriteBack(data, fromUDPAddr) diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index a4a473e9..4e00aca2 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -383,10 +383,10 @@ func handleUDPConn(packet C.PacketAdapter) { log.Infoln("[UDP] %s --> %s doesn't match any rule using DIRECT", metadata.SourceDetail(), metadata.RemoteAddress()) } - oAddr := metadata.DstIP + oAddrPort := metadata.AddrPort() natTable.Set(key, pc) - go handleUDPToLocal(packet, pc, key, oAddr, fAddr) + go handleUDPToLocal(packet, pc, key, oAddrPort, fAddr) handle() }()