chore: trie.DomainTrie will not depend on zero value

This commit is contained in:
wwqgtxx 2022-11-02 22:28:18 +08:00
parent c34c5ff1f9
commit 22fb219ad8
12 changed files with 66 additions and 48 deletions

View file

@ -34,7 +34,7 @@ type Pool struct {
offset netip.Addr offset netip.Addr
cycle bool cycle bool
mux sync.Mutex mux sync.Mutex
host *trie.DomainTrie[bool] host *trie.DomainTrie[struct{}]
ipnet *netip.Prefix ipnet *netip.Prefix
store store store store
} }
@ -150,7 +150,7 @@ func (p *Pool) restoreState() {
type Options struct { type Options struct {
IPNet *netip.Prefix IPNet *netip.Prefix
Host *trie.DomainTrie[bool] Host *trie.DomainTrie[struct{}]
// Size sets the maximum number of entries in memory // Size sets the maximum number of entries in memory
// and does not work if Persistence is true // and does not work if Persistence is true

View file

@ -127,7 +127,7 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) {
} }
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data; ip.Is6() { if ip := node.Data(); ip.Is6() {
return []netip.Addr{ip}, nil return []netip.Addr{ip}, nil
} }
} }
@ -161,8 +161,8 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) {
func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) { func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data; ip.Is4() { if ip := node.Data(); ip.Is4() {
return []netip.Addr{node.Data}, nil return []netip.Addr{node.Data()}, nil
} }
} }
@ -200,7 +200,7 @@ func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) { func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
return []netip.Addr{node.Data}, nil return []netip.Addr{node.Data()}, nil
} }
ip, err := netip.ParseAddr(host) ip, err := netip.ParseAddr(host)

View file

@ -31,8 +31,8 @@ type SnifferDispatcher struct {
sniffers []sniffer.Sniffer sniffers []sniffer.Sniffer
forceDomain *trie.DomainTrie[bool] forceDomain *trie.DomainTrie[struct{}]
skipSNI *trie.DomainTrie[bool] skipSNI *trie.DomainTrie[struct{}]
portRanges *[]utils.Range[uint16] portRanges *[]utils.Range[uint16]
skipList *cache.LruCache[string, uint8] skipList *cache.LruCache[string, uint8]
rwMux sync.RWMutex rwMux sync.RWMutex
@ -183,8 +183,8 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
return &dispatcher, nil return &dispatcher, nil
} }
func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[bool], func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[struct{}],
skipSNI *trie.DomainTrie[bool], ports *[]utils.Range[uint16], skipSNI *trie.DomainTrie[struct{}], ports *[]utils.Range[uint16],
forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) { forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{ dispatcher := SnifferDispatcher{
enable: true, enable: true,

View file

@ -17,7 +17,7 @@ var ErrInvalidDomain = errors.New("invalid domain")
// DomainTrie contains the main logic for adding and searching nodes for domain segments. // DomainTrie contains the main logic for adding and searching nodes for domain segments.
// support wildcard domain (e.g *.google.com) // support wildcard domain (e.g *.google.com)
type DomainTrie[T comparable] struct { type DomainTrie[T any] struct {
root *Node[T] root *Node[T]
} }
@ -74,13 +74,13 @@ func (t *DomainTrie[T]) insert(parts []string, data T) {
for i := len(parts) - 1; i >= 0; i-- { for i := len(parts) - 1; i >= 0; i-- {
part := parts[i] part := parts[i]
if !node.hasChild(part) { if !node.hasChild(part) {
node.addChild(part, newNode(getZero[T]())) node.addChild(part, newNode[T]())
} }
node = node.getChild(part) node = node.getChild(part)
} }
node.Data = data node.setData(data)
} }
// Search is the most important part of the Trie. // Search is the most important part of the Trie.
@ -96,7 +96,7 @@ func (t *DomainTrie[T]) Search(domain string) *Node[T] {
n := t.search(t.root, parts) n := t.search(t.root, parts)
if n == nil || n.Data == getZero[T]() { if n.isEmpty() {
return nil return nil
} }
@ -109,13 +109,13 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
} }
if c := node.getChild(parts[len(parts)-1]); c != nil { if c := node.getChild(parts[len(parts)-1]); c != nil {
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() { if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
return n return n
} }
} }
if c := node.getChild(wildcard); c != nil { if c := node.getChild(wildcard); c != nil {
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() { if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
return n return n
} }
} }
@ -124,6 +124,6 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
} }
// New returns a new, empty Trie. // New returns a new, empty Trie.
func New[T comparable]() *DomainTrie[T] { func New[T any]() *DomainTrie[T] {
return &DomainTrie[T]{root: newNode[T](getZero[T]())} return &DomainTrie[T]{root: newNode[T]()}
} }

View file

@ -23,7 +23,7 @@ func TestTrie_Basic(t *testing.T) {
node := tree.Search("example.com") node := tree.Search("example.com")
assert.NotNil(t, node) assert.NotNil(t, node)
assert.True(t, node.Data == localIP) assert.True(t, node.Data() == localIP)
assert.NotNil(t, tree.Insert("", localIP)) assert.NotNil(t, tree.Insert("", localIP))
assert.Nil(t, tree.Search("")) assert.Nil(t, tree.Search(""))
assert.NotNil(t, tree.Search("localhost")) assert.NotNil(t, tree.Search("localhost"))
@ -75,7 +75,7 @@ func TestTrie_Priority(t *testing.T) {
assertFn := func(domain string, data int) { assertFn := func(domain string, data int) {
node := tree.Search(domain) node := tree.Search(domain)
assert.NotNil(t, node) assert.NotNil(t, node)
assert.Equal(t, data, node.Data) assert.Equal(t, data, node.Data())
} }
for idx, domain := range domains { for idx, domain := range domains {

View file

@ -1,9 +1,10 @@
package trie package trie
// Node is the trie's node // Node is the trie's node
type Node[T comparable] struct { type Node[T any] struct {
children map[string]*Node[T] children map[string]*Node[T]
Data T inited bool
data T
} }
func (n *Node[T]) getChild(s string) *Node[T] { func (n *Node[T]) getChild(s string) *Node[T] {
@ -18,14 +19,31 @@ func (n *Node[T]) addChild(s string, child *Node[T]) {
n.children[s] = child n.children[s] = child
} }
func newNode[T comparable](data T) *Node[T] { func (n *Node[T]) isEmpty() bool {
if n == nil || n.inited == false {
return true
}
return false
}
func (n *Node[T]) setData(data T) {
n.data = data
n.inited = true
}
func (n *Node[T]) Data() T {
return n.data
}
func newNode[T any]() *Node[T] {
return &Node[T]{ return &Node[T]{
Data: data,
children: map[string]*Node[T]{}, children: map[string]*Node[T]{},
inited: false,
data: getZero[T](),
} }
} }
func getZero[T comparable]() T { func getZero[T any]() T {
var result T var result T
return result return result
} }

View file

@ -197,9 +197,9 @@ type IPTables struct {
type Sniffer struct { type Sniffer struct {
Enable bool Enable bool
Sniffers []sniffer.Type Sniffers []sniffer.Type
Reverses *trie.DomainTrie[bool] Reverses *trie.DomainTrie[struct{}]
ForceDomain *trie.DomainTrie[bool] ForceDomain *trie.DomainTrie[struct{}]
SkipDomain *trie.DomainTrie[bool] SkipDomain *trie.DomainTrie[struct{}]
Ports *[]utils.Range[uint16] Ports *[]utils.Range[uint16]
ForceDnsMapping bool ForceDnsMapping bool
ParsePureIp bool ParsePureIp bool
@ -1061,24 +1061,24 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
return nil, err return nil, err
} }
var host *trie.DomainTrie[bool] var host *trie.DomainTrie[struct{}]
// fake ip skip host filter // fake ip skip host filter
if len(cfg.FakeIPFilter) != 0 { if len(cfg.FakeIPFilter) != 0 {
host = trie.New[bool]() host = trie.New[struct{}]()
for _, domain := range cfg.FakeIPFilter { for _, domain := range cfg.FakeIPFilter {
_ = host.Insert(domain, true) _ = host.Insert(domain, struct{}{})
} }
} }
if len(dnsCfg.Fallback) != 0 { if len(dnsCfg.Fallback) != 0 {
if host == nil { if host == nil {
host = trie.New[bool]() host = trie.New[struct{}]()
} }
for _, fb := range dnsCfg.Fallback { for _, fb := range dnsCfg.Fallback {
if net.ParseIP(fb.Addr) != nil { if net.ParseIP(fb.Addr) != nil {
continue continue
} }
_ = host.Insert(fb.Addr, true) _ = host.Insert(fb.Addr, struct{}{})
} }
} }
@ -1232,17 +1232,17 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
for st := range loadSniffer { for st := range loadSniffer {
sniffer.Sniffers = append(sniffer.Sniffers, st) sniffer.Sniffers = append(sniffer.Sniffers, st)
} }
sniffer.ForceDomain = trie.New[bool]() sniffer.ForceDomain = trie.New[struct{}]()
for _, domain := range snifferRaw.ForceDomain { for _, domain := range snifferRaw.ForceDomain {
err := sniffer.ForceDomain.Insert(domain, true) err := sniffer.ForceDomain.Insert(domain, struct{}{})
if err != nil { if err != nil {
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
} }
} }
sniffer.SkipDomain = trie.New[bool]() sniffer.SkipDomain = trie.New[struct{}]()
for _, domain := range snifferRaw.SkipDomain { for _, domain := range snifferRaw.SkipDomain {
err := sniffer.SkipDomain.Insert(domain, true) err := sniffer.SkipDomain.Insert(domain, struct{}{})
if err != nil { if err != nil {
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
} }

View file

@ -71,13 +71,13 @@ type fallbackDomainFilter interface {
} }
type domainFilter struct { type domainFilter struct {
tree *trie.DomainTrie[bool] tree *trie.DomainTrie[struct{}]
} }
func NewDomainFilter(domains []string) *domainFilter { func NewDomainFilter(domains []string) *domainFilter {
df := domainFilter{tree: trie.New[bool]()} df := domainFilter{tree: trie.New[struct{}]()}
for _, domain := range domains { for _, domain := range domains {
_ = df.tree.Insert(domain, true) _ = df.tree.Insert(domain, struct{}{})
} }
return &df return &df
} }

View file

@ -37,7 +37,7 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip
return next(ctx, r) return next(ctx, r)
} }
ip := record.Data ip := record.Data()
msg := r.Copy() msg := r.Copy()
if ip.Is4() && q.Qtype == D.TypeA { if ip.Is4() && q.Qtype == D.TypeA {

View file

@ -245,7 +245,7 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
return nil return nil
} }
p := record.Data p := record.Data()
return p.GetData() return p.GetData()
} }

View file

@ -9,7 +9,7 @@ import (
type domainStrategy struct { type domainStrategy struct {
count int count int
domainRules *trie.DomainTrie[bool] domainRules *trie.DomainTrie[struct{}]
} }
func (d *domainStrategy) Match(metadata *C.Metadata) bool { func (d *domainStrategy) Match(metadata *C.Metadata) bool {
@ -25,11 +25,11 @@ func (d *domainStrategy) ShouldResolveIP() bool {
} }
func (d *domainStrategy) OnUpdate(rules []string) { func (d *domainStrategy) OnUpdate(rules []string) {
domainTrie := trie.New[bool]() domainTrie := trie.New[struct{}]()
count := 0 count := 0
for _, rule := range rules { for _, rule := range rules {
actualDomain, _ := idna.ToASCII(rule) actualDomain, _ := idna.ToASCII(rule)
err := domainTrie.Insert(actualDomain, true) err := domainTrie.Insert(actualDomain, struct{}{})
if err != nil { if err != nil {
log.Warnln("invalid domain:[%s]", rule) log.Warnln("invalid domain:[%s]", rule)
} else { } else {

View file

@ -178,7 +178,7 @@ func preHandleMetadata(metadata *C.Metadata) error {
metadata.DNSMode = C.DNSFakeIP metadata.DNSMode = C.DNSFakeIP
} else if node := resolver.DefaultHosts.Search(host); node != nil { } else if node := resolver.DefaultHosts.Search(host); node != nil {
// redir-host should lookup the hosts // redir-host should lookup the hosts
metadata.DstIP = node.Data metadata.DstIP = node.Data()
} }
} else if resolver.IsFakeIP(metadata.DstIP) { } else if resolver.IsFakeIP(metadata.DstIP) {
return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) return fmt.Errorf("fake DNS record %s missing", metadata.DstIP)
@ -334,7 +334,7 @@ func handleTCPConn(connCtx C.ConnContext) {
dialMetadata := metadata dialMetadata := metadata
if len(metadata.Host) > 0 { if len(metadata.Host) > 0 {
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { if node := resolver.DefaultHosts.Search(metadata.Host); node != nil {
dialMetadata.DstIP = node.Data dialMetadata.DstIP = node.Data()
dialMetadata.DNSMode = C.DNSHosts dialMetadata.DNSMode = C.DNSHosts
dialMetadata = dialMetadata.Pure() dialMetadata = dialMetadata.Pure()
} }
@ -388,7 +388,7 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
) )
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { if node := resolver.DefaultHosts.Search(metadata.Host); node != nil {
metadata.DstIP = node.Data metadata.DstIP = node.Data()
resolved = true resolved = true
} }