diff --git a/component/trie/sskv.go b/component/trie/domain_set.go similarity index 91% rename from component/trie/sskv.go rename to component/trie/domain_set.go index 6a661a85..ce416e16 100644 --- a/component/trie/sskv.go +++ b/component/trie/domain_set.go @@ -23,20 +23,16 @@ type DomainSet struct { ranks, selects []int32 } -// NewDomainSet creates a new *DomainSet struct, from a slice of sorted strings. -func NewDomainSet(keys []string) *DomainSet { - domainTrie := New[struct{}]() - for _, domain := range keys { - domainTrie.Insert(domain, struct{}{}) - } - reserveDomains := make([]string, 0, len(keys)) - domainTrie.Foreach(func(domain string, data struct{}) { +// NewDomainSet creates a new *DomainSet struct, from a DomainTrie. +func (t *DomainTrie[T]) NewDomainSet() *DomainSet { + reserveDomains := make([]string, 0) + t.Foreach(func(domain string, data T) { reserveDomains = append(reserveDomains, utils.Reverse(domain)) }) // ensure that the same prefix is continuous // and according to the ascending sequence of length sort.Strings(reserveDomains) - keys = reserveDomains + keys := reserveDomains if len(keys) == 0 { return nil } @@ -104,7 +100,7 @@ func (ss *DomainSet) Has(key string) bool { if j == len(key) { if getBit(ss.leaves, nextNodeId) != 0 { return true - }else { + } else { goto RESTART } } diff --git a/component/trie/set_test.go b/component/trie/domain_set_test.go similarity index 62% rename from component/trie/set_test.go rename to component/trie/domain_set_test.go index 346bb31a..090bd495 100644 --- a/component/trie/set_test.go +++ b/component/trie/domain_set_test.go @@ -7,7 +7,8 @@ import ( "github.com/stretchr/testify/assert" ) -func TestDomain(t *testing.T) { +func TestDomainSet(t *testing.T) { + tree := trie.New[struct{}]() domainSet := []string{ "baidu.com", "google.com", @@ -15,14 +16,19 @@ func TestDomain(t *testing.T) { "test.a.net", "test.a.oc", } - set := trie.NewDomainSet(domainSet) + + for _, domain := range domainSet { + assert.NoError(t, tree.Insert(domain, struct{}{})) + } + set := tree.NewDomainSet() assert.NotNil(t, set) assert.True(t, set.Has("test.a.net")) assert.True(t, set.Has("google.com")) assert.False(t, set.Has("www.baidu.com")) } -func TestDomainComplexWildcard(t *testing.T) { +func TestDomainSetComplexWildcard(t *testing.T) { + tree := trie.New[struct{}]() domainSet := []string{ "+.baidu.com", "+.a.baidu.com", @@ -32,14 +38,19 @@ func TestDomainComplexWildcard(t *testing.T) { "test.a.oc", "www.qq.com", } - set := trie.NewDomainSet(domainSet) + + for _, domain := range domainSet { + assert.NoError(t, tree.Insert(domain, struct{}{})) + } + set := tree.NewDomainSet() assert.NotNil(t, set) assert.False(t, set.Has("google.com")) assert.True(t, set.Has("www.baidu.com")) assert.True(t, set.Has("test.test.baidu.com")) } -func TestDomainWildcard(t *testing.T) { +func TestDomainSetWildcard(t *testing.T) { + tree := trie.New[struct{}]() domainSet := []string{ "*.*.*.baidu.com", "www.baidu.*", @@ -47,14 +58,18 @@ func TestDomainWildcard(t *testing.T) { "*.*.qq.com", "test.*.baidu.com", } - set := trie.NewDomainSet(domainSet) + + for _, domain := range domainSet { + assert.NoError(t, tree.Insert(domain, struct{}{})) + } + set := tree.NewDomainSet() assert.NotNil(t, set) assert.True(t, set.Has("www.baidu.com")) assert.True(t, set.Has("test.test.baidu.com")) assert.True(t, set.Has("test.test.qq.com")) - assert.True(t,set.Has("stun.ab.cd")) + assert.True(t, set.Has("stun.ab.cd")) assert.False(t, set.Has("test.baidu.com")) - assert.False(t,set.Has("www.google.com")) + assert.False(t, set.Has("www.google.com")) assert.False(t, set.Has("test.qq.com")) assert.False(t, set.Has("test.test.test.qq.com")) } diff --git a/component/trie/domain_test.go b/component/trie/domain_test.go index 2dfd1c34..976055a9 100644 --- a/component/trie/domain_test.go +++ b/component/trie/domain_test.go @@ -1,16 +1,17 @@ -package trie +package trie_test import ( "net/netip" "testing" + "github.com/Dreamacro/clash/component/trie" "github.com/stretchr/testify/assert" ) var localIP = netip.AddrFrom4([4]byte{127, 0, 0, 1}) func TestTrie_Basic(t *testing.T) { - tree := New[netip.Addr]() + tree := trie.New[netip.Addr]() domains := []string{ "example.com", "google.com", @@ -18,7 +19,7 @@ func TestTrie_Basic(t *testing.T) { } for _, domain := range domains { - tree.Insert(domain, localIP) + assert.NoError(t, tree.Insert(domain, localIP)) } node := tree.Search("example.com") @@ -31,7 +32,7 @@ func TestTrie_Basic(t *testing.T) { } func TestTrie_Wildcard(t *testing.T) { - tree := New[netip.Addr]() + tree := trie.New[netip.Addr]() domains := []string{ "*.example.com", "sub.*.example.com", @@ -47,7 +48,7 @@ func TestTrie_Wildcard(t *testing.T) { } for _, domain := range domains { - tree.Insert(domain, localIP) + assert.NoError(t, tree.Insert(domain, localIP)) } assert.NotNil(t, tree.Search("sub.example.com")) @@ -64,7 +65,7 @@ func TestTrie_Wildcard(t *testing.T) { } func TestTrie_Priority(t *testing.T) { - tree := New[int]() + tree := trie.New[int]() domains := []string{ ".dev", "example.dev", @@ -79,7 +80,7 @@ func TestTrie_Priority(t *testing.T) { } for idx, domain := range domains { - tree.Insert(domain, idx+1) + assert.NoError(t, tree.Insert(domain, idx+1)) } assertFn("test.dev", 1) @@ -90,8 +91,8 @@ func TestTrie_Priority(t *testing.T) { } func TestTrie_Boundary(t *testing.T) { - tree := New[netip.Addr]() - tree.Insert("*.dev", localIP) + tree := trie.New[netip.Addr]() + assert.NoError(t, tree.Insert("*.dev", localIP)) assert.NotNil(t, tree.Insert(".", localIP)) assert.NotNil(t, tree.Insert("..dev", localIP)) @@ -99,15 +100,15 @@ func TestTrie_Boundary(t *testing.T) { } func TestTrie_WildcardBoundary(t *testing.T) { - tree := New[netip.Addr]() - tree.Insert("+.*", localIP) - tree.Insert("stun.*.*.*", localIP) + tree := trie.New[netip.Addr]() + assert.NoError(t, tree.Insert("+.*", localIP)) + assert.NoError(t, tree.Insert("stun.*.*.*", localIP)) assert.NotNil(t, tree.Search("example.com")) } func TestTrie_Foreach(t *testing.T) { - tree := New[netip.Addr]() + tree := trie.New[netip.Addr]() domainList := []string{ "google.com", "stun.*.*.*", @@ -117,7 +118,7 @@ func TestTrie_Foreach(t *testing.T) { "*.*.baidu.com", } for _, domain := range domainList { - tree.Insert(domain, localIP) + assert.NoError(t, tree.Insert(domain, localIP)) } count := 0 tree.Foreach(func(domain string, data netip.Addr) { diff --git a/config/config.go b/config/config.go index d2378822..e7720c40 100644 --- a/config/config.go +++ b/config/config.go @@ -1340,8 +1340,25 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { } sniffer.Sniffers = loadSniffer - sniffer.ForceDomain = trie.NewDomainSet(snifferRaw.ForceDomain) - sniffer.SkipDomain = trie.NewDomainSet(snifferRaw.SkipDomain) + + forceDomainTrie := trie.New[struct{}]() + for _, domain := range snifferRaw.ForceDomain { + err := forceDomainTrie.Insert(domain, struct{}{}) + if err != nil { + return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) + } + } + sniffer.ForceDomain = forceDomainTrie.NewDomainSet() + + skipDomainTrie := trie.New[struct{}]() + for _, domain := range snifferRaw.SkipDomain { + err := skipDomainTrie.Insert(domain, struct{}{}) + if err != nil { + return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) + } + } + sniffer.SkipDomain = skipDomainTrie.NewDomainSet() + return sniffer, nil } diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index 0b2a5d3c..a2cb795d 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -3,6 +3,7 @@ package provider import ( "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" ) type domainStrategy struct { @@ -27,8 +28,14 @@ func (d *domainStrategy) ShouldResolveIP() bool { } func (d *domainStrategy) OnUpdate(rules []string) { - domainTrie := trie.NewDomainSet(rules) - d.domainRules = domainTrie + domainTrie := trie.New[struct{}]() + for _, rule := range rules { + err := domainTrie.Insert(rule, struct{}{}) + if err != nil { + log.Warnln("invalid domain:[%s]", rule) + } + } + d.domainRules = domainTrie.NewDomainSet() d.count = len(rules) }