diff --git a/component/trie/ipcidr_trie.go b/component/trie/ipcidr_trie.go index a3a63f95..08edbbeb 100644 --- a/component/trie/ipcidr_trie.go +++ b/component/trie/ipcidr_trie.go @@ -1,8 +1,9 @@ package trie import ( - "github.com/Dreamacro/clash/log" "net" + + "github.com/Dreamacro/clash/log" ) type IPV6 bool @@ -47,11 +48,10 @@ func (trie *IpCidrTrie) AddIpCidrForString(ipCidr string) error { } func (trie *IpCidrTrie) IsContain(ip net.IP) bool { - ip, isIpv4 := checkAndConverterIp(ip) if ip == nil { return false } - + isIpv4 := len(ip) == net.IPv4len var groupValues []uint32 var ipCidrNode *IpCidrNode @@ -71,7 +71,13 @@ func (trie *IpCidrTrie) IsContain(ip net.IP) bool { } func (trie *IpCidrTrie) IsContainForString(ipString string) bool { - return trie.IsContain(net.ParseIP(ipString)) + ip := net.ParseIP(ipString) + // deal with 4in6 + actualIp := ip.To4() + if actualIp == nil { + actualIp = ip + } + return trie.IsContain(actualIp) } func ipCidrToSubIpCidr(ipNet *net.IPNet) ([]net.IP, int, bool, error) { @@ -82,9 +88,8 @@ func ipCidrToSubIpCidr(ipNet *net.IPNet) ([]net.IP, int, bool, error) { isIpv4 bool err error ) - - ip, isIpv4 := checkAndConverterIp(ipNet.IP) - ipList, newMaskSize, err = subIpCidr(ip, maskSize, isIpv4) + isIpv4 = len(ipNet.IP) == net.IPv4len + ipList, newMaskSize, err = subIpCidr(ipNet.IP, maskSize, isIpv4) return ipList, newMaskSize, isIpv4, err } @@ -238,18 +243,3 @@ func search(root *IpCidrNode, groupValues []uint32) *IpCidrNode { return nil } - -// return net.IP To4 or To16 and is ipv4 -func checkAndConverterIp(ip net.IP) (net.IP, bool) { - ipResult := ip.To4() - if ipResult == nil { - ipResult = ip.To16() - if ipResult == nil { - return nil, false - } - - return ipResult, false - } - - return ipResult, true -} diff --git a/component/trie/trie_test.go b/component/trie/trie_test.go index dca77c05..e1b20103 100644 --- a/component/trie/trie_test.go +++ b/component/trie/trie_test.go @@ -3,8 +3,9 @@ package trie import ( "net" "testing" + + "github.com/stretchr/testify/assert" ) -import "github.com/stretchr/testify/assert" func TestIpv4AddSuccess(t *testing.T) { trie := NewIpCidrTrie() @@ -96,5 +97,11 @@ func TestIpv6Search(t *testing.T) { assert.Equal(t, true, trie.IsContainForString("2001:67c:4e8:9666::1213")) assert.Equal(t, false, trie.IsContain(net.ParseIP("22233:22"))) - +} + +func TestIpv4InIpv6(t *testing.T) { + trie := NewIpCidrTrie() + + // Boundary testing + assert.NoError(t, trie.AddIpCidrForString("::ffff:198.18.5.138/128")) }