From e2e0fd4ebaafc168e5e1708b042b624d8acba1f1 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Wed, 9 Aug 2023 13:51:02 +0800 Subject: [PATCH] chore: using uint16 for ports in Metadata --- adapter/adapter.go | 7 ++++++- adapter/inbound/socket.go | 5 ++++- adapter/inbound/util.go | 26 ++++++++++++++++++-------- adapter/outbound/snell.go | 6 ++---- adapter/outbound/util.go | 3 +-- adapter/outbound/vless.go | 5 ++--- adapter/outbound/wireguard.go | 6 ++---- component/sniffer/dispatcher.go | 31 ++++++++++++------------------- constant/metadata.go | 22 +++++++++++++--------- dns/util.go | 15 ++++++++++++--- rules/common/port.go | 9 +-------- rules/logic_test/logic_test.go | 4 ++-- test/clash_test.go | 12 ++++++------ transport/tuic/v4/protocol.go | 4 +--- transport/tuic/v5/protocol.go | 4 +--- tunnel/tunnel.go | 4 +--- 16 files changed, 84 insertions(+), 79 deletions(-) diff --git a/adapter/adapter.go b/adapter/adapter.go index 20de5f29..6cc79c3a 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -9,6 +9,7 @@ import ( "net/http" "net/netip" "net/url" + "strconv" "time" "github.com/Dreamacro/clash/common/atomic" @@ -327,11 +328,15 @@ func urlToMetadata(rawURL string) (addr C.Metadata, err error) { return } } + uintPort, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return + } addr = C.Metadata{ Host: u.Hostname(), DstIP: netip.Addr{}, - DstPort: port, + DstPort: uint16(uintPort), } return } diff --git a/adapter/inbound/socket.go b/adapter/inbound/socket.go index e41ee925..d75901f1 100644 --- a/adapter/inbound/socket.go +++ b/adapter/inbound/socket.go @@ -3,6 +3,7 @@ package inbound import ( "net" "net/netip" + "strconv" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/context" @@ -37,7 +38,9 @@ func NewInner(conn net.Conn, address string) *context.ConnContext { metadata.DNSMode = C.DNSNormal metadata.Process = C.ClashName if h, port, err := net.SplitHostPort(address); err == nil { - metadata.DstPort = port + if port, err := strconv.ParseUint(port, 10, 16); err == nil { + metadata.DstPort = uint16(port) + } if ip, err := netip.ParseAddr(h); err == nil { metadata.DstIP = ip } else { diff --git a/adapter/inbound/util.go b/adapter/inbound/util.go index 88e989f9..626687c0 100644 --- a/adapter/inbound/util.go +++ b/adapter/inbound/util.go @@ -20,14 +20,14 @@ func parseSocksAddr(target socks5.Addr) *C.Metadata { case socks5.AtypDomainName: // trim for FQDN metadata.Host = strings.TrimRight(string(target[2:2+target[1]]), ".") - metadata.DstPort = strconv.Itoa((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1])) + metadata.DstPort = uint16((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1])) case socks5.AtypIPv4: metadata.DstIP = nnip.IpToAddr(net.IP(target[1 : 1+net.IPv4len])) - metadata.DstPort = strconv.Itoa((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1])) + metadata.DstPort = uint16((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1])) case socks5.AtypIPv6: ip6, _ := netip.AddrFromSlice(target[1 : 1+net.IPv6len]) metadata.DstIP = ip6.Unmap() - metadata.DstPort = strconv.Itoa((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1])) + metadata.DstPort = uint16((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1])) } return metadata @@ -43,11 +43,16 @@ func parseHTTPAddr(request *http.Request) *C.Metadata { // trim FQDN (#737) host = strings.TrimRight(host, ".") + var uint16Port uint16 + if port, err := strconv.ParseUint(port, 10, 16); err == nil { + uint16Port = uint16(port) + } + metadata := &C.Metadata{ NetWork: C.TCP, Host: host, DstIP: netip.Addr{}, - DstPort: port, + DstPort: uint16Port, } ip, err := netip.ParseAddr(host) @@ -58,10 +63,10 @@ func parseHTTPAddr(request *http.Request) *C.Metadata { return metadata } -func parseAddr(addr net.Addr) (netip.Addr, string, error) { +func parseAddr(addr net.Addr) (netip.Addr, uint16, error) { // Filter when net.Addr interface is nil if addr == nil { - return netip.Addr{}, "", errors.New("nil addr") + return netip.Addr{}, 0, errors.New("nil addr") } if rawAddr, ok := addr.(interface{ RawAddr() net.Addr }); ok { ip, port, err := parseAddr(rawAddr.RawAddr()) @@ -72,9 +77,14 @@ func parseAddr(addr net.Addr) (netip.Addr, string, error) { addrStr := addr.String() host, port, err := net.SplitHostPort(addrStr) if err != nil { - return netip.Addr{}, "", err + return netip.Addr{}, 0, err + } + + var uint16Port uint16 + if port, err := strconv.ParseUint(port, 10, 16); err == nil { + uint16Port = uint16(port) } ip, err := netip.ParseAddr(host) - return ip, port, err + return ip, uint16Port, err } diff --git a/adapter/outbound/snell.go b/adapter/outbound/snell.go index fc1f4eb3..e542d84d 100644 --- a/adapter/outbound/snell.go +++ b/adapter/outbound/snell.go @@ -59,8 +59,7 @@ func (s *Snell) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M err := snell.WriteUDPHeader(c, s.version) return c, err } - port, _ := strconv.ParseUint(metadata.DstPort, 10, 16) - err := snell.WriteHeader(c, metadata.String(), uint(port), s.version) + err := snell.WriteHeader(c, metadata.String(), uint(metadata.DstPort), s.version) return c, err } @@ -72,8 +71,7 @@ func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d return nil, err } - port, _ := strconv.ParseUint(metadata.DstPort, 10, 16) - if err = snell.WriteHeader(c, metadata.String(), uint(port), s.version); err != nil { + if err = snell.WriteHeader(c, metadata.String(), uint(metadata.DstPort), s.version); err != nil { c.Close() return nil, err } diff --git a/adapter/outbound/util.go b/adapter/outbound/util.go index 0504d005..7f3ec4c3 100644 --- a/adapter/outbound/util.go +++ b/adapter/outbound/util.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "net" "net/netip" - "strconv" "sync" "time" @@ -38,7 +37,7 @@ func serializesSocksAddr(metadata *C.Metadata) []byte { var buf [][]byte addrType := metadata.AddrType() aType := uint8(addrType) - p, _ := strconv.ParseUint(metadata.DstPort, 10, 16) + p := uint(metadata.DstPort) port := []byte{uint8(p >> 8), uint8(p & 0xff)} switch addrType { case socks5.AtypDomainName: diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 6423eb29..44d05ba6 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -179,7 +179,7 @@ func (v *Vless) streamConn(c net.Conn, metadata *C.Metadata) (conn net.Conn, err metadata = &C.Metadata{ NetWork: C.UDP, Host: packetaddr.SeqPacketMagicAddress, - DstPort: "443", + DstPort: 443, } } else { metadata = &C.Metadata{ // a clear metadata only contains ip @@ -399,12 +399,11 @@ func parseVlessAddr(metadata *C.Metadata, xudp bool) *vless.DstAddr { copy(addr[1:], metadata.Host) } - port, _ := strconv.ParseUint(metadata.DstPort, 10, 16) return &vless.DstAddr{ UDP: metadata.NetWork == C.UDP, AddrType: addrType, Addr: addr, - Port: uint16(port), + Port: metadata.DstPort, Mux: metadata.NetWork == C.UDP && xudp, } } diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index c12321f3..e6738596 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -374,8 +374,7 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice})) conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress()) } else { - port, _ := strconv.Atoi(metadata.DstPort) - conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, uint16(port)).Unwrap()) + conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap()) } if err != nil { return nil, err @@ -412,8 +411,7 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat } metadata.DstIP = ip } - port, _ := strconv.Atoi(metadata.DstPort) - pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, uint16(port)).Unwrap()) + pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap()) if err != nil { return nil, err } diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index fa1c6827..f813eec2 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "net/netip" - "strconv" "sync" "time" @@ -26,29 +25,23 @@ var ( var Dispatcher *SnifferDispatcher type SnifferDispatcher struct { - enable bool - sniffers map[sniffer.Sniffer]SnifferConfig - forceDomain *trie.DomainSet - skipSNI *trie.DomainSet - skipList *cache.LruCache[string, uint8] - rwMux sync.RWMutex - forceDnsMapping bool - parsePureIp bool + enable bool + sniffers map[sniffer.Sniffer]SnifferConfig + forceDomain *trie.DomainSet + skipSNI *trie.DomainSet + skipList *cache.LruCache[string, uint8] + rwMux sync.RWMutex + forceDnsMapping bool + parsePureIp bool } func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) { if (metadata.Host == "" && sd.parsePureIp) || sd.forceDomain.Has(metadata.Host) || (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping) { - port, err := strconv.ParseUint(metadata.DstPort, 10, 16) - if err != nil { - log.Debugln("[Sniffer] Dst port is error") - return - } - inWhitelist := false overrideDest := false for sniffer, config := range sd.sniffers { if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet { - inWhitelist = sniffer.SupportPort(uint16(port)) + inWhitelist = sniffer.SupportPort(metadata.DstPort) if inWhitelist { overrideDest = config.OverrideDest break @@ -61,7 +54,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata } sd.rwMux.RLock() - dst := fmt.Sprintf("%s:%s", metadata.DstIP, metadata.DstPort) + dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort) if count, ok := sd.skipList.Get(dst); ok && count > 5 { log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst) defer sd.rwMux.RUnlock() @@ -71,7 +64,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata if host, err := sd.sniffDomain(conn, metadata); err != nil { sd.cacheSniffFailed(metadata) - log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort) + log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%d] to [%s:%d]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort) return } else { if sd.skipSNI.Has(host) { @@ -149,7 +142,7 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) { sd.rwMux.Lock() - dst := fmt.Sprintf("%s:%s", metadata.DstIP, metadata.DstPort) + dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort) count, _ := sd.skipList.Get(dst) if count <= 5 { count++ diff --git a/constant/metadata.go b/constant/metadata.go index de26a05f..dbd31fd8 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -128,10 +128,10 @@ type Metadata struct { Type Type `json:"type"` SrcIP netip.Addr `json:"sourceIP"` DstIP netip.Addr `json:"destinationIP"` - SrcPort string `json:"sourcePort"` - DstPort string `json:"destinationPort"` + SrcPort uint16 `json:"sourcePort,string"` // `,string` is used to compatible with old version json output + DstPort uint16 `json:"destinationPort,string"` // `,string` is used to compatible with old version json output InIP netip.Addr `json:"inboundIP"` - InPort string `json:"inboundPort"` + InPort uint16 `json:"inboundPort,string"` // `,string` is used to compatible with old version json output InName string `json:"inboundName"` InUser string `json:"inboundUser"` Host string `json:"host"` @@ -147,11 +147,11 @@ type Metadata struct { } func (m *Metadata) RemoteAddress() string { - return net.JoinHostPort(m.String(), m.DstPort) + return net.JoinHostPort(m.String(), strconv.FormatUint(uint64(m.DstPort), 10)) } func (m *Metadata) SourceAddress() string { - return net.JoinHostPort(m.SrcIP.String(), m.SrcPort) + return net.JoinHostPort(m.SrcIP.String(), strconv.FormatUint(uint64(m.SrcPort), 10)) } func (m *Metadata) SourceDetail() string { @@ -172,7 +172,7 @@ func (m *Metadata) SourceDetail() string { } func (m *Metadata) SourceValid() bool { - return m.SrcPort != "" && m.SrcIP.IsValid() + return m.SrcPort != 0 && m.SrcIP.IsValid() } func (m *Metadata) AddrType() int { @@ -211,8 +211,7 @@ func (m *Metadata) Pure() *Metadata { } func (m *Metadata) AddrPort() netip.AddrPort { - port, _ := strconv.ParseUint(m.DstPort, 10, 16) - return netip.AddrPortFrom(m.DstIP.Unmap(), uint16(port)) + return netip.AddrPortFrom(m.DstIP.Unmap(), m.DstPort) } func (m *Metadata) UDPAddr() *net.UDPAddr { @@ -242,6 +241,11 @@ func (m *Metadata) SetRemoteAddress(rawAddress string) error { return err } + var uint16Port uint16 + if port, err := strconv.ParseUint(port, 10, 16); err == nil { + uint16Port = uint16(port) + } + if ip, err := netip.ParseAddr(host); err != nil { m.Host = host m.DstIP = netip.Addr{} @@ -249,7 +253,7 @@ func (m *Metadata) SetRemoteAddress(rawAddress string) error { m.Host = "" m.DstIP = ip.Unmap() } - m.DstPort = port + m.DstPort = uint16Port return nil } diff --git a/dns/util.go b/dns/util.go index 739fd16b..77f677cb 100644 --- a/dns/util.go +++ b/dns/util.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/netip" + "strconv" "strings" "time" @@ -193,6 +194,10 @@ func getDialHandler(r *Resolver, proxyAdapter C.ProxyAdapter, proxyName string, if err != nil { return nil, err } + uintPort, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil, err + } if proxyAdapter == nil { var ok bool proxyAdapter, ok = tunnel.Proxies()[proxyName] @@ -206,7 +211,7 @@ func getDialHandler(r *Resolver, proxyAdapter C.ProxyAdapter, proxyName string, metadata := &C.Metadata{ NetWork: C.TCP, Host: host, - DstPort: port, + DstPort: uint16(uintPort), } if proxyAdapter != nil { if proxyAdapter.IsL3Protocol(metadata) { // L3 proxy should resolve domain before to avoid loopback @@ -231,7 +236,7 @@ func getDialHandler(r *Resolver, proxyAdapter C.ProxyAdapter, proxyName string, NetWork: C.UDP, Host: "", DstIP: dstIP, - DstPort: port, + DstPort: uint16(uintPort), } if proxyAdapter == nil { return dialer.DialContext(ctx, network, addr, opts...) @@ -257,6 +262,10 @@ func listenPacket(ctx context.Context, proxyAdapter C.ProxyAdapter, proxyName st if err != nil { return nil, err } + uintPort, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil, err + } if proxyAdapter == nil { var ok bool proxyAdapter, ok = tunnel.Proxies()[proxyName] @@ -274,7 +283,7 @@ func listenPacket(ctx context.Context, proxyAdapter C.ProxyAdapter, proxyName st NetWork: C.UDP, Host: "", DstIP: dstIP, - DstPort: port, + DstPort: uint16(uintPort), } if proxyAdapter == nil { return dialer.ListenPacket(ctx, dialer.ParseNetwork(network, dstIP), "", opts...) diff --git a/rules/common/port.go b/rules/common/port.go index aeacc4dc..334d083f 100644 --- a/rules/common/port.go +++ b/rules/common/port.go @@ -2,7 +2,6 @@ package common import ( "fmt" - "strconv" "github.com/Dreamacro/clash/common/utils" C "github.com/Dreamacro/clash/constant" @@ -28,7 +27,7 @@ func (p *Port) Match(metadata *C.Metadata) (bool, string) { case C.SrcPort: targetPort = metadata.SrcPort } - return p.matchPortReal(targetPort), p.adapter + return p.portRanges.Check(targetPort), p.adapter } func (p *Port) Adapter() string { @@ -39,12 +38,6 @@ func (p *Port) Payload() string { return p.port } -func (p *Port) matchPortReal(portRef string) bool { - port, _ := strconv.Atoi(portRef) - - return p.portRanges.Check(uint16(port)) -} - func NewPort(port string, adapter string, ruleType C.RuleType) (*Port, error) { portRanges, err := utils.NewIntRanges[uint16](port) if err != nil { diff --git a/rules/logic_test/logic_test.go b/rules/logic_test/logic_test.go index de5ae569..52318b3f 100644 --- a/rules/logic_test/logic_test.go +++ b/rules/logic_test/logic_test.go @@ -20,7 +20,7 @@ func TestAND(t *testing.T) { m, _ := and.Match(&C.Metadata{ Host: "baidu.com", NetWork: C.TCP, - DstPort: "20000", + DstPort: 20000, }) assert.Equal(t, true, m) @@ -35,7 +35,7 @@ func TestNOT(t *testing.T) { not, err := NewNOT("((DST-PORT,6000-6500))", "REJECT", ParseRule) assert.Equal(t, nil, err) m, _ := not.Match(&C.Metadata{ - DstPort: "6100", + DstPort: 6100, }) assert.Equal(t, false, m) diff --git a/test/clash_test.go b/test/clash_test.go index 3fdca5d0..60b99791 100644 --- a/test/clash_test.go +++ b/test/clash_test.go @@ -556,7 +556,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) { assert.NoError(t, testPingPongWithConn(t, func() net.Conn { conn, err := proxy.DialContext(context.Background(), &C.Metadata{ Host: localIP.String(), - DstPort: "10001", + DstPort: 10001, }) require.NoError(t, err) return conn @@ -565,7 +565,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) { assert.NoError(t, testLargeDataWithConn(t, func() net.Conn { conn, err := proxy.DialContext(context.Background(), &C.Metadata{ Host: localIP.String(), - DstPort: "10001", + DstPort: 10001, }) require.NoError(t, err) return conn @@ -578,7 +578,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) { pc, err := proxy.ListenPacketContext(context.Background(), &C.Metadata{ NetWork: C.UDP, DstIP: localIP, - DstPort: "10001", + DstPort: 10001, }) require.NoError(t, err) defer pc.Close() @@ -588,7 +588,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) { pc, err = proxy.ListenPacketContext(context.Background(), &C.Metadata{ NetWork: C.UDP, DstIP: localIP, - DstPort: "10001", + DstPort: 10001, }) require.NoError(t, err) defer pc.Close() @@ -598,7 +598,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) { pc, err = proxy.ListenPacketContext(context.Background(), &C.Metadata{ NetWork: C.UDP, DstIP: localIP, - DstPort: "10001", + DstPort: 10001, }) require.NoError(t, err) defer pc.Close() @@ -635,7 +635,7 @@ func benchmarkProxy(b *testing.B, proxy C.ProxyAdapter) { conn, err := proxy.DialContext(context.Background(), &C.Metadata{ Host: localIP.String(), - DstPort: "10001", + DstPort: 10001, }) require.NoError(b, err) diff --git a/transport/tuic/v4/protocol.go b/transport/tuic/v4/protocol.go index 11ac3b4e..bbdca67c 100644 --- a/transport/tuic/v4/protocol.go +++ b/transport/tuic/v4/protocol.go @@ -457,12 +457,10 @@ func NewAddress(metadata *C.Metadata) Address { copy(addr[1:], metadata.Host) } - port, _ := strconv.ParseUint(metadata.DstPort, 10, 16) - return Address{ TYPE: addrType, ADDR: addr, - PORT: uint16(port), + PORT: metadata.DstPort, } } diff --git a/transport/tuic/v5/protocol.go b/transport/tuic/v5/protocol.go index 83b44146..964401e1 100644 --- a/transport/tuic/v5/protocol.go +++ b/transport/tuic/v5/protocol.go @@ -436,12 +436,10 @@ func NewAddress(metadata *C.Metadata) Address { copy(addr[1:], metadata.Host) } - port, _ := strconv.ParseUint(metadata.DstPort, 10, 16) - return Address{ TYPE: addrType, ADDR: addr, - PORT: uint16(port), + PORT: metadata.DstPort, } } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index e375f656..d4c15a87 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -7,7 +7,6 @@ import ( "net/netip" "path/filepath" "runtime" - "strconv" "sync" "time" @@ -566,8 +565,7 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { if attemptProcessLookup && !findProcessMode.Off() && (findProcessMode.Always() || rule.ShouldFindProcess()) { attemptProcessLookup = false - srcPort, _ := strconv.ParseUint(metadata.SrcPort, 10, 16) - uid, path, err := P.FindProcessName(metadata.NetWork.String(), metadata.SrcIP, int(srcPort)) + uid, path, err := P.FindProcessName(metadata.NetWork.String(), metadata.SrcIP, int(metadata.SrcPort)) if err != nil { log.Debugln("[Process] find process %s: %v", metadata.String(), err) } else {