chore: optimize DomainTrie for only one child

This commit is contained in:
wwqgtxx 2022-11-30 19:42:05 +08:00
parent 84caee94af
commit df8e129fc6
6 changed files with 49 additions and 20 deletions

View file

@ -119,8 +119,8 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
return node.getChild(dotWildcard) return node.getChild(dotWildcard)
} }
func (t *DomainTrie[T]) FinishInsert() { func (t *DomainTrie[T]) Optimize() {
t.root.finishAdd() t.root.optimize()
} }
// New returns a new, empty Trie. // New returns a new, empty Trie.

View file

@ -4,6 +4,8 @@ import "strings"
// Node is the trie's node // Node is the trie's node
type Node[T any] struct { type Node[T any] struct {
childNode *Node[T] // optimize for only one child
childStr string
children map[string]*Node[T] children map[string]*Node[T]
inited bool inited bool
data T data T
@ -11,6 +13,9 @@ type Node[T any] struct {
func (n *Node[T]) getChild(s string) *Node[T] { func (n *Node[T]) getChild(s string) *Node[T] {
if n.children == nil { if n.children == nil {
if n.childNode != nil && n.childStr == s {
return n.childNode
}
return nil return nil
} }
return n.children[s] return n.children[s]
@ -22,8 +27,19 @@ func (n *Node[T]) hasChild(s string) bool {
func (n *Node[T]) addChild(s string, child *Node[T]) { func (n *Node[T]) addChild(s string, child *Node[T]) {
if n.children == nil { if n.children == nil {
n.children = map[string]*Node[T]{} if n.childNode == nil {
n.childStr = s
n.childNode = child
return
} }
n.children = map[string]*Node[T]{}
if n.childNode != nil {
n.children[n.childStr] = n.childNode
}
n.childStr = ""
n.childNode = nil
}
n.children[s] = child n.children[s] = child
} }
@ -36,12 +52,28 @@ func (n *Node[T]) getOrNewChild(s string) *Node[T] {
return node return node
} }
func (n *Node[T]) finishAdd() { func (n *Node[T]) optimize() {
if len(n.childStr) > 0 {
n.childStr = strings.Clone(n.childStr)
}
if n.childNode != nil {
n.childNode.optimize()
}
if n.children == nil { if n.children == nil {
return return
} }
if len(n.children) == 0 { switch len(n.children) {
case 0:
n.children = nil n.children = nil
return
case 1:
for key := range n.children {
n.childStr = key
n.childNode = n.children[key]
}
n.children = nil
n.optimize()
return
} }
children := make(map[string]*Node[T], len(n.children)) // avoid map reallocate memory children := make(map[string]*Node[T], len(n.children)) // avoid map reallocate memory
for key := range n.children { for key := range n.children {
@ -58,7 +90,7 @@ func (n *Node[T]) finishAdd() {
key = strings.Clone(key) key = strings.Clone(key)
} }
children[key] = child children[key] = child
child.finishAdd() child.optimize()
} }
n.children = children n.children = children
} }
@ -80,8 +112,5 @@ func (n *Node[T]) Data() T {
} }
func newNode[T any]() *Node[T] { func newNode[T any]() *Node[T] {
return &Node[T]{ return &Node[T]{}
children: nil,
inited: false,
}
} }

View file

@ -972,7 +972,7 @@ func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) {
_ = tree.Insert(domain, ip) _ = tree.Insert(domain, ip)
} }
} }
tree.FinishInsert() tree.Optimize()
return tree, nil return tree, nil
} }
@ -1207,7 +1207,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
for _, domain := range cfg.FakeIPFilter { for _, domain := range cfg.FakeIPFilter {
_ = host.Insert(domain, struct{}{}) _ = host.Insert(domain, struct{}{})
} }
host.FinishInsert() host.Optimize()
} }
if len(dnsCfg.Fallback) != 0 { if len(dnsCfg.Fallback) != 0 {
@ -1220,7 +1220,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
} }
_ = host.Insert(fb.Addr, struct{}{}) _ = host.Insert(fb.Addr, struct{}{})
} }
host.FinishInsert() host.Optimize()
} }
pool, err := fakeip.New(fakeip.Options{ pool, err := fakeip.New(fakeip.Options{
@ -1396,7 +1396,7 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
} }
} }
sniffer.ForceDomain.FinishInsert() sniffer.ForceDomain.Optimize()
sniffer.SkipDomain = trie.New[struct{}]() sniffer.SkipDomain = trie.New[struct{}]()
for _, domain := range snifferRaw.SkipDomain { for _, domain := range snifferRaw.SkipDomain {
@ -1405,7 +1405,7 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
} }
} }
sniffer.SkipDomain.FinishInsert() sniffer.SkipDomain.Optimize()
return sniffer, nil return sniffer, nil
} }

View file

@ -79,7 +79,7 @@ func NewDomainFilter(domains []string) *domainFilter {
for _, domain := range domains { for _, domain := range domains {
_ = df.tree.Insert(domain, struct{}{}) _ = df.tree.Insert(domain, struct{}{})
} }
df.tree.FinishInsert() df.tree.Optimize()
return &df return &df
} }

View file

@ -435,7 +435,7 @@ func NewResolver(config Config) *Resolver {
for domain, nameserver := range config.Policy { for domain, nameserver := range config.Policy {
_ = r.policy.Insert(domain, NewPolicy(transform([]NameServer{nameserver}, defaultResolver))) _ = r.policy.Insert(domain, NewPolicy(transform([]NameServer{nameserver}, defaultResolver)))
} }
r.policy.FinishInsert() r.policy.Optimize()
} }
fallbackIPFilters := []fallbackIPFilter{} fallbackIPFilters := []fallbackIPFilter{}

View file

@ -36,7 +36,7 @@ func (d *domainStrategy) OnUpdate(rules []string) {
count++ count++
} }
} }
domainTrie.FinishInsert() domainTrie.Optimize()
d.domainRules = domainTrie d.domainRules = domainTrie
d.count = count d.count = count