diff --git a/component/trie/domain.go b/component/trie/domain.go index cb4cee94..d9463c6e 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -119,8 +119,8 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] { return node.getChild(dotWildcard) } -func (t *DomainTrie[T]) FinishInsert() { - t.root.finishAdd() +func (t *DomainTrie[T]) Optimize() { + t.root.optimize() } // New returns a new, empty Trie. diff --git a/component/trie/node.go b/component/trie/node.go index 37570351..e7baabb6 100644 --- a/component/trie/node.go +++ b/component/trie/node.go @@ -4,13 +4,18 @@ import "strings" // Node is the trie's node type Node[T any] struct { - children map[string]*Node[T] - inited bool - data T + childNode *Node[T] // optimize for only one child + childStr string + children map[string]*Node[T] + inited bool + data T } func (n *Node[T]) getChild(s string) *Node[T] { if n.children == nil { + if n.childNode != nil && n.childStr == s { + return n.childNode + } return nil } return n.children[s] @@ -22,8 +27,19 @@ func (n *Node[T]) hasChild(s string) bool { func (n *Node[T]) addChild(s string, child *Node[T]) { if n.children == nil { + if n.childNode == nil { + n.childStr = s + n.childNode = child + return + } n.children = map[string]*Node[T]{} + if n.childNode != nil { + n.children[n.childStr] = n.childNode + } + n.childStr = "" + n.childNode = nil } + n.children[s] = child } @@ -36,12 +52,28 @@ func (n *Node[T]) getOrNewChild(s string) *Node[T] { return node } -func (n *Node[T]) finishAdd() { +func (n *Node[T]) optimize() { + if len(n.childStr) > 0 { + n.childStr = strings.Clone(n.childStr) + } + if n.childNode != nil { + n.childNode.optimize() + } if n.children == nil { return } - if len(n.children) == 0 { + switch len(n.children) { + case 0: n.children = nil + return + case 1: + for key := range n.children { + n.childStr = key + n.childNode = n.children[key] + } + n.children = nil + n.optimize() + return } children := make(map[string]*Node[T], len(n.children)) // avoid map reallocate memory for key := range n.children { @@ -58,7 +90,7 @@ func (n *Node[T]) finishAdd() { key = strings.Clone(key) } children[key] = child - child.finishAdd() + child.optimize() } n.children = children } @@ -80,8 +112,5 @@ func (n *Node[T]) Data() T { } func newNode[T any]() *Node[T] { - return &Node[T]{ - children: nil, - inited: false, - } + return &Node[T]{} } diff --git a/config/config.go b/config/config.go index 9b6f60c0..bb8c65f4 100644 --- a/config/config.go +++ b/config/config.go @@ -972,7 +972,7 @@ func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) { _ = tree.Insert(domain, ip) } } - tree.FinishInsert() + tree.Optimize() return tree, nil } @@ -1207,7 +1207,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R for _, domain := range cfg.FakeIPFilter { _ = host.Insert(domain, struct{}{}) } - host.FinishInsert() + host.Optimize() } if len(dnsCfg.Fallback) != 0 { @@ -1220,7 +1220,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R } _ = host.Insert(fb.Addr, struct{}{}) } - host.FinishInsert() + host.Optimize() } pool, err := fakeip.New(fakeip.Options{ @@ -1396,7 +1396,7 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) } } - sniffer.ForceDomain.FinishInsert() + sniffer.ForceDomain.Optimize() sniffer.SkipDomain = trie.New[struct{}]() for _, domain := range snifferRaw.SkipDomain { @@ -1405,7 +1405,7 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) } } - sniffer.SkipDomain.FinishInsert() + sniffer.SkipDomain.Optimize() return sniffer, nil } diff --git a/dns/filters.go b/dns/filters.go index 4fcb8ee4..0dbfa317 100644 --- a/dns/filters.go +++ b/dns/filters.go @@ -79,7 +79,7 @@ func NewDomainFilter(domains []string) *domainFilter { for _, domain := range domains { _ = df.tree.Insert(domain, struct{}{}) } - df.tree.FinishInsert() + df.tree.Optimize() return &df } diff --git a/dns/resolver.go b/dns/resolver.go index 501bfb0d..d99a465d 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -435,7 +435,7 @@ func NewResolver(config Config) *Resolver { for domain, nameserver := range config.Policy { _ = r.policy.Insert(domain, NewPolicy(transform([]NameServer{nameserver}, defaultResolver))) } - r.policy.FinishInsert() + r.policy.Optimize() } fallbackIPFilters := []fallbackIPFilter{} diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index 14ff17a4..61fe93a6 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -36,7 +36,7 @@ func (d *domainStrategy) OnUpdate(rules []string) { count++ } } - domainTrie.FinishInsert() + domainTrie.Optimize() d.domainRules = domainTrie d.count = count