diff --git a/component/trie/domain.go b/component/trie/domain.go index 86c467db..cb4cee94 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -73,11 +73,7 @@ func (t *DomainTrie[T]) insert(parts []string, data T) { // reverse storage domain part to save space for i := len(parts) - 1; i >= 0; i-- { part := parts[i] - if !node.hasChild(part) { - node.addChild(part, newNode[T]()) - } - - node = node.getChild(part) + node = node.getOrNewChild(part) } node.setData(data) @@ -123,6 +119,10 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] { return node.getChild(dotWildcard) } +func (t *DomainTrie[T]) FinishInsert() { + t.root.finishAdd() +} + // New returns a new, empty Trie. func New[T any]() *DomainTrie[T] { return &DomainTrie[T]{root: newNode[T]()} diff --git a/component/trie/node.go b/component/trie/node.go index 9d45bda8..37570351 100644 --- a/component/trie/node.go +++ b/component/trie/node.go @@ -1,5 +1,7 @@ package trie +import "strings" + // Node is the trie's node type Node[T any] struct { children map[string]*Node[T] @@ -8,6 +10,9 @@ type Node[T any] struct { } func (n *Node[T]) getChild(s string) *Node[T] { + if n.children == nil { + return nil + } return n.children[s] } @@ -16,9 +21,48 @@ func (n *Node[T]) hasChild(s string) bool { } func (n *Node[T]) addChild(s string, child *Node[T]) { + if n.children == nil { + n.children = map[string]*Node[T]{} + } n.children[s] = child } +func (n *Node[T]) getOrNewChild(s string) *Node[T] { + node := n.getChild(s) + if node == nil { + node = newNode[T]() + n.addChild(s, node) + } + return node +} + +func (n *Node[T]) finishAdd() { + if n.children == nil { + return + } + if len(n.children) == 0 { + n.children = nil + } + children := make(map[string]*Node[T], len(n.children)) // avoid map reallocate memory + for key := range n.children { + child := n.children[key] + if child == nil { + continue + } + switch key { // try to save string's memory + case wildcard: + key = wildcard + case dotWildcard: + key = dotWildcard + default: + key = strings.Clone(key) + } + children[key] = child + child.finishAdd() + } + n.children = children +} + func (n *Node[T]) isEmpty() bool { if n == nil || n.inited == false { return true @@ -37,13 +81,7 @@ func (n *Node[T]) Data() T { func newNode[T any]() *Node[T] { return &Node[T]{ - children: map[string]*Node[T]{}, + children: nil, inited: false, - data: getZero[T](), } } - -func getZero[T any]() T { - var result T - return result -} diff --git a/config/config.go b/config/config.go index 9033f162..9b6f60c0 100644 --- a/config/config.go +++ b/config/config.go @@ -972,6 +972,7 @@ func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) { _ = tree.Insert(domain, ip) } } + tree.FinishInsert() return tree, nil } @@ -1206,6 +1207,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R for _, domain := range cfg.FakeIPFilter { _ = host.Insert(domain, struct{}{}) } + host.FinishInsert() } if len(dnsCfg.Fallback) != 0 { @@ -1218,6 +1220,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R } _ = host.Insert(fb.Addr, struct{}{}) } + host.FinishInsert() } pool, err := fakeip.New(fakeip.Options{ @@ -1393,6 +1396,7 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) } } + sniffer.ForceDomain.FinishInsert() sniffer.SkipDomain = trie.New[struct{}]() for _, domain := range snifferRaw.SkipDomain { @@ -1401,6 +1405,7 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) } } + sniffer.SkipDomain.FinishInsert() return sniffer, nil } diff --git a/dns/filters.go b/dns/filters.go index 5b0141a2..4fcb8ee4 100644 --- a/dns/filters.go +++ b/dns/filters.go @@ -79,6 +79,7 @@ func NewDomainFilter(domains []string) *domainFilter { for _, domain := range domains { _ = df.tree.Insert(domain, struct{}{}) } + df.tree.FinishInsert() return &df } diff --git a/dns/resolver.go b/dns/resolver.go index 4461a563..501bfb0d 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -435,6 +435,7 @@ func NewResolver(config Config) *Resolver { for domain, nameserver := range config.Policy { _ = r.policy.Insert(domain, NewPolicy(transform([]NameServer{nameserver}, defaultResolver))) } + r.policy.FinishInsert() } fallbackIPFilters := []fallbackIPFilter{} diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index e8bebf6e..14ff17a4 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -36,6 +36,7 @@ func (d *domainStrategy) OnUpdate(rules []string) { count++ } } + domainTrie.FinishInsert() d.domainRules = domainTrie d.count = count