chore: DomainSet is now built from a DomainTrie

This commit is contained in:
wwqgtxx 2023-04-01 12:15:03 +08:00
parent cfd03a99c2
commit 54cad53f5f
5 changed files with 72 additions and 36 deletions

View file

@ -23,20 +23,16 @@ type DomainSet struct {
ranks, selects []int32 ranks, selects []int32
} }
// NewDomainSet creates a new *DomainSet struct, from a slice of sorted strings. // NewDomainSet creates a new *DomainSet struct, from a DomainTrie.
func NewDomainSet(keys []string) *DomainSet { func (t *DomainTrie[T]) NewDomainSet() *DomainSet {
domainTrie := New[struct{}]() reserveDomains := make([]string, 0)
for _, domain := range keys { t.Foreach(func(domain string, data T) {
domainTrie.Insert(domain, struct{}{})
}
reserveDomains := make([]string, 0, len(keys))
domainTrie.Foreach(func(domain string, data struct{}) {
reserveDomains = append(reserveDomains, utils.Reverse(domain)) reserveDomains = append(reserveDomains, utils.Reverse(domain))
}) })
// ensure that the same prefix is continuous // ensure that the same prefix is continuous
// and according to the ascending sequence of length // and according to the ascending sequence of length
sort.Strings(reserveDomains) sort.Strings(reserveDomains)
keys = reserveDomains keys := reserveDomains
if len(keys) == 0 { if len(keys) == 0 {
return nil return nil
} }
@ -104,7 +100,7 @@ func (ss *DomainSet) Has(key string) bool {
if j == len(key) { if j == len(key) {
if getBit(ss.leaves, nextNodeId) != 0 { if getBit(ss.leaves, nextNodeId) != 0 {
return true return true
}else { } else {
goto RESTART goto RESTART
} }
} }

View file

@ -7,7 +7,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestDomain(t *testing.T) { func TestDomainSet(t *testing.T) {
tree := trie.New[struct{}]()
domainSet := []string{ domainSet := []string{
"baidu.com", "baidu.com",
"google.com", "google.com",
@ -15,14 +16,19 @@ func TestDomain(t *testing.T) {
"test.a.net", "test.a.net",
"test.a.oc", "test.a.oc",
} }
set := trie.NewDomainSet(domainSet)
for _, domain := range domainSet {
assert.NoError(t, tree.Insert(domain, struct{}{}))
}
set := tree.NewDomainSet()
assert.NotNil(t, set) assert.NotNil(t, set)
assert.True(t, set.Has("test.a.net")) assert.True(t, set.Has("test.a.net"))
assert.True(t, set.Has("google.com")) assert.True(t, set.Has("google.com"))
assert.False(t, set.Has("www.baidu.com")) assert.False(t, set.Has("www.baidu.com"))
} }
func TestDomainComplexWildcard(t *testing.T) { func TestDomainSetComplexWildcard(t *testing.T) {
tree := trie.New[struct{}]()
domainSet := []string{ domainSet := []string{
"+.baidu.com", "+.baidu.com",
"+.a.baidu.com", "+.a.baidu.com",
@ -32,14 +38,19 @@ func TestDomainComplexWildcard(t *testing.T) {
"test.a.oc", "test.a.oc",
"www.qq.com", "www.qq.com",
} }
set := trie.NewDomainSet(domainSet)
for _, domain := range domainSet {
assert.NoError(t, tree.Insert(domain, struct{}{}))
}
set := tree.NewDomainSet()
assert.NotNil(t, set) assert.NotNil(t, set)
assert.False(t, set.Has("google.com")) assert.False(t, set.Has("google.com"))
assert.True(t, set.Has("www.baidu.com")) assert.True(t, set.Has("www.baidu.com"))
assert.True(t, set.Has("test.test.baidu.com")) assert.True(t, set.Has("test.test.baidu.com"))
} }
func TestDomainWildcard(t *testing.T) { func TestDomainSetWildcard(t *testing.T) {
tree := trie.New[struct{}]()
domainSet := []string{ domainSet := []string{
"*.*.*.baidu.com", "*.*.*.baidu.com",
"www.baidu.*", "www.baidu.*",
@ -47,14 +58,18 @@ func TestDomainWildcard(t *testing.T) {
"*.*.qq.com", "*.*.qq.com",
"test.*.baidu.com", "test.*.baidu.com",
} }
set := trie.NewDomainSet(domainSet)
for _, domain := range domainSet {
assert.NoError(t, tree.Insert(domain, struct{}{}))
}
set := tree.NewDomainSet()
assert.NotNil(t, set) assert.NotNil(t, set)
assert.True(t, set.Has("www.baidu.com")) assert.True(t, set.Has("www.baidu.com"))
assert.True(t, set.Has("test.test.baidu.com")) assert.True(t, set.Has("test.test.baidu.com"))
assert.True(t, set.Has("test.test.qq.com")) assert.True(t, set.Has("test.test.qq.com"))
assert.True(t,set.Has("stun.ab.cd")) assert.True(t, set.Has("stun.ab.cd"))
assert.False(t, set.Has("test.baidu.com")) assert.False(t, set.Has("test.baidu.com"))
assert.False(t,set.Has("www.google.com")) assert.False(t, set.Has("www.google.com"))
assert.False(t, set.Has("test.qq.com")) assert.False(t, set.Has("test.qq.com"))
assert.False(t, set.Has("test.test.test.qq.com")) assert.False(t, set.Has("test.test.test.qq.com"))
} }

View file

@ -1,16 +1,17 @@
package trie package trie_test
import ( import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/Dreamacro/clash/component/trie"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var localIP = netip.AddrFrom4([4]byte{127, 0, 0, 1}) var localIP = netip.AddrFrom4([4]byte{127, 0, 0, 1})
func TestTrie_Basic(t *testing.T) { func TestTrie_Basic(t *testing.T) {
tree := New[netip.Addr]() tree := trie.New[netip.Addr]()
domains := []string{ domains := []string{
"example.com", "example.com",
"google.com", "google.com",
@ -18,7 +19,7 @@ func TestTrie_Basic(t *testing.T) {
} }
for _, domain := range domains { for _, domain := range domains {
tree.Insert(domain, localIP) assert.NoError(t, tree.Insert(domain, localIP))
} }
node := tree.Search("example.com") node := tree.Search("example.com")
@ -31,7 +32,7 @@ func TestTrie_Basic(t *testing.T) {
} }
func TestTrie_Wildcard(t *testing.T) { func TestTrie_Wildcard(t *testing.T) {
tree := New[netip.Addr]() tree := trie.New[netip.Addr]()
domains := []string{ domains := []string{
"*.example.com", "*.example.com",
"sub.*.example.com", "sub.*.example.com",
@ -47,7 +48,7 @@ func TestTrie_Wildcard(t *testing.T) {
} }
for _, domain := range domains { for _, domain := range domains {
tree.Insert(domain, localIP) assert.NoError(t, tree.Insert(domain, localIP))
} }
assert.NotNil(t, tree.Search("sub.example.com")) assert.NotNil(t, tree.Search("sub.example.com"))
@ -64,7 +65,7 @@ func TestTrie_Wildcard(t *testing.T) {
} }
func TestTrie_Priority(t *testing.T) { func TestTrie_Priority(t *testing.T) {
tree := New[int]() tree := trie.New[int]()
domains := []string{ domains := []string{
".dev", ".dev",
"example.dev", "example.dev",
@ -79,7 +80,7 @@ func TestTrie_Priority(t *testing.T) {
} }
for idx, domain := range domains { for idx, domain := range domains {
tree.Insert(domain, idx+1) assert.NoError(t, tree.Insert(domain, idx+1))
} }
assertFn("test.dev", 1) assertFn("test.dev", 1)
@ -90,8 +91,8 @@ func TestTrie_Priority(t *testing.T) {
} }
func TestTrie_Boundary(t *testing.T) { func TestTrie_Boundary(t *testing.T) {
tree := New[netip.Addr]() tree := trie.New[netip.Addr]()
tree.Insert("*.dev", localIP) assert.NoError(t, tree.Insert("*.dev", localIP))
assert.NotNil(t, tree.Insert(".", localIP)) assert.NotNil(t, tree.Insert(".", localIP))
assert.NotNil(t, tree.Insert("..dev", localIP)) assert.NotNil(t, tree.Insert("..dev", localIP))
@ -99,15 +100,15 @@ func TestTrie_Boundary(t *testing.T) {
} }
func TestTrie_WildcardBoundary(t *testing.T) { func TestTrie_WildcardBoundary(t *testing.T) {
tree := New[netip.Addr]() tree := trie.New[netip.Addr]()
tree.Insert("+.*", localIP) assert.NoError(t, tree.Insert("+.*", localIP))
tree.Insert("stun.*.*.*", localIP) assert.NoError(t, tree.Insert("stun.*.*.*", localIP))
assert.NotNil(t, tree.Search("example.com")) assert.NotNil(t, tree.Search("example.com"))
} }
func TestTrie_Foreach(t *testing.T) { func TestTrie_Foreach(t *testing.T) {
tree := New[netip.Addr]() tree := trie.New[netip.Addr]()
domainList := []string{ domainList := []string{
"google.com", "google.com",
"stun.*.*.*", "stun.*.*.*",
@ -117,7 +118,7 @@ func TestTrie_Foreach(t *testing.T) {
"*.*.baidu.com", "*.*.baidu.com",
} }
for _, domain := range domainList { for _, domain := range domainList {
tree.Insert(domain, localIP) assert.NoError(t, tree.Insert(domain, localIP))
} }
count := 0 count := 0
tree.Foreach(func(domain string, data netip.Addr) { tree.Foreach(func(domain string, data netip.Addr) {

View file

@ -1340,8 +1340,25 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
} }
sniffer.Sniffers = loadSniffer sniffer.Sniffers = loadSniffer
sniffer.ForceDomain = trie.NewDomainSet(snifferRaw.ForceDomain)
sniffer.SkipDomain = trie.NewDomainSet(snifferRaw.SkipDomain) forceDomainTrie := trie.New[struct{}]()
for _, domain := range snifferRaw.ForceDomain {
err := forceDomainTrie.Insert(domain, struct{}{})
if err != nil {
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
}
}
sniffer.ForceDomain = forceDomainTrie.NewDomainSet()
skipDomainTrie := trie.New[struct{}]()
for _, domain := range snifferRaw.SkipDomain {
err := skipDomainTrie.Insert(domain, struct{}{})
if err != nil {
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
}
}
sniffer.SkipDomain = skipDomainTrie.NewDomainSet()
return sniffer, nil return sniffer, nil
} }

View file

@ -3,6 +3,7 @@ package provider
import ( import (
"github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
) )
type domainStrategy struct { type domainStrategy struct {
@ -27,8 +28,14 @@ func (d *domainStrategy) ShouldResolveIP() bool {
} }
func (d *domainStrategy) OnUpdate(rules []string) { func (d *domainStrategy) OnUpdate(rules []string) {
domainTrie := trie.NewDomainSet(rules) domainTrie := trie.New[struct{}]()
d.domainRules = domainTrie for _, rule := range rules {
err := domainTrie.Insert(rule, struct{}{})
if err != nil {
log.Warnln("invalid domain:[%s]", rule)
}
}
d.domainRules = domainTrie.NewDomainSet()
d.count = len(rules) d.count = len(rules)
} }