diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index 58a75a94..1ef50d95 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -34,7 +34,7 @@ type Pool struct { offset netip.Addr cycle bool mux sync.Mutex - host *trie.DomainTrie[bool] + host *trie.DomainTrie[struct{}] ipnet *netip.Prefix store store } @@ -150,7 +150,7 @@ func (p *Pool) restoreState() { type Options struct { IPNet *netip.Prefix - Host *trie.DomainTrie[bool] + Host *trie.DomainTrie[struct{}] // Size sets the maximum number of entries in memory // and does not work if Persistence is true diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index abb45564..af32cc94 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -127,7 +127,7 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) { } if node := DefaultHosts.Search(host); node != nil { - if ip := node.Data; ip.Is6() { + if ip := node.Data(); ip.Is6() { return []netip.Addr{ip}, nil } } @@ -161,8 +161,8 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) { func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) { if node := DefaultHosts.Search(host); node != nil { - if ip := node.Data; ip.Is4() { - return []netip.Addr{node.Data}, nil + if ip := node.Data(); ip.Is4() { + return []netip.Addr{node.Data()}, nil } } @@ -200,7 +200,7 @@ func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) { func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) { if node := DefaultHosts.Search(host); node != nil { - return []netip.Addr{node.Data}, nil + return []netip.Addr{node.Data()}, nil } ip, err := netip.ParseAddr(host) diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index 6a5e632a..3bd81ac8 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -31,8 +31,8 @@ type SnifferDispatcher struct { sniffers []sniffer.Sniffer - forceDomain *trie.DomainTrie[bool] - skipSNI *trie.DomainTrie[bool] + forceDomain *trie.DomainTrie[struct{}] + skipSNI *trie.DomainTrie[struct{}] portRanges *[]utils.Range[uint16] skipList *cache.LruCache[string, uint8] rwMux sync.RWMutex @@ -183,8 +183,8 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) { return &dispatcher, nil } -func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[bool], - skipSNI *trie.DomainTrie[bool], ports *[]utils.Range[uint16], +func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[struct{}], + skipSNI *trie.DomainTrie[struct{}], ports *[]utils.Range[uint16], forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) { dispatcher := SnifferDispatcher{ enable: true, diff --git a/component/trie/domain.go b/component/trie/domain.go index 16dd9ae9..86c467db 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -17,7 +17,7 @@ var ErrInvalidDomain = errors.New("invalid domain") // DomainTrie contains the main logic for adding and searching nodes for domain segments. // support wildcard domain (e.g *.google.com) -type DomainTrie[T comparable] struct { +type DomainTrie[T any] struct { root *Node[T] } @@ -74,13 +74,13 @@ func (t *DomainTrie[T]) insert(parts []string, data T) { for i := len(parts) - 1; i >= 0; i-- { part := parts[i] if !node.hasChild(part) { - node.addChild(part, newNode(getZero[T]())) + node.addChild(part, newNode[T]()) } node = node.getChild(part) } - node.Data = data + node.setData(data) } // Search is the most important part of the Trie. @@ -96,7 +96,7 @@ func (t *DomainTrie[T]) Search(domain string) *Node[T] { n := t.search(t.root, parts) - if n == nil || n.Data == getZero[T]() { + if n.isEmpty() { return nil } @@ -109,13 +109,13 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] { } if c := node.getChild(parts[len(parts)-1]); c != nil { - if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() { + if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() { return n } } if c := node.getChild(wildcard); c != nil { - if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() { + if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() { return n } } @@ -124,6 +124,6 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] { } // New returns a new, empty Trie. -func New[T comparable]() *DomainTrie[T] { - return &DomainTrie[T]{root: newNode[T](getZero[T]())} +func New[T any]() *DomainTrie[T] { + return &DomainTrie[T]{root: newNode[T]()} } diff --git a/component/trie/domain_test.go b/component/trie/domain_test.go index ced44d03..c54b3d3b 100644 --- a/component/trie/domain_test.go +++ b/component/trie/domain_test.go @@ -23,7 +23,7 @@ func TestTrie_Basic(t *testing.T) { node := tree.Search("example.com") assert.NotNil(t, node) - assert.True(t, node.Data == localIP) + assert.True(t, node.Data() == localIP) assert.NotNil(t, tree.Insert("", localIP)) assert.Nil(t, tree.Search("")) assert.NotNil(t, tree.Search("localhost")) @@ -75,7 +75,7 @@ func TestTrie_Priority(t *testing.T) { assertFn := func(domain string, data int) { node := tree.Search(domain) assert.NotNil(t, node) - assert.Equal(t, data, node.Data) + assert.Equal(t, data, node.Data()) } for idx, domain := range domains { diff --git a/component/trie/node.go b/component/trie/node.go index 1545d880..9d45bda8 100644 --- a/component/trie/node.go +++ b/component/trie/node.go @@ -1,9 +1,10 @@ package trie // Node is the trie's node -type Node[T comparable] struct { +type Node[T any] struct { children map[string]*Node[T] - Data T + inited bool + data T } func (n *Node[T]) getChild(s string) *Node[T] { @@ -18,14 +19,31 @@ func (n *Node[T]) addChild(s string, child *Node[T]) { n.children[s] = child } -func newNode[T comparable](data T) *Node[T] { +func (n *Node[T]) isEmpty() bool { + if n == nil || n.inited == false { + return true + } + return false +} + +func (n *Node[T]) setData(data T) { + n.data = data + n.inited = true +} + +func (n *Node[T]) Data() T { + return n.data +} + +func newNode[T any]() *Node[T] { return &Node[T]{ - Data: data, children: map[string]*Node[T]{}, + inited: false, + data: getZero[T](), } } -func getZero[T comparable]() T { +func getZero[T any]() T { var result T return result } diff --git a/config/config.go b/config/config.go index ee4fbb16..9cad0915 100644 --- a/config/config.go +++ b/config/config.go @@ -197,9 +197,9 @@ type IPTables struct { type Sniffer struct { Enable bool Sniffers []sniffer.Type - Reverses *trie.DomainTrie[bool] - ForceDomain *trie.DomainTrie[bool] - SkipDomain *trie.DomainTrie[bool] + Reverses *trie.DomainTrie[struct{}] + ForceDomain *trie.DomainTrie[struct{}] + SkipDomain *trie.DomainTrie[struct{}] Ports *[]utils.Range[uint16] ForceDnsMapping bool ParsePureIp bool @@ -1061,24 +1061,24 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R return nil, err } - var host *trie.DomainTrie[bool] + var host *trie.DomainTrie[struct{}] // fake ip skip host filter if len(cfg.FakeIPFilter) != 0 { - host = trie.New[bool]() + host = trie.New[struct{}]() for _, domain := range cfg.FakeIPFilter { - _ = host.Insert(domain, true) + _ = host.Insert(domain, struct{}{}) } } if len(dnsCfg.Fallback) != 0 { if host == nil { - host = trie.New[bool]() + host = trie.New[struct{}]() } for _, fb := range dnsCfg.Fallback { if net.ParseIP(fb.Addr) != nil { continue } - _ = host.Insert(fb.Addr, true) + _ = host.Insert(fb.Addr, struct{}{}) } } @@ -1232,17 +1232,17 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { for st := range loadSniffer { sniffer.Sniffers = append(sniffer.Sniffers, st) } - sniffer.ForceDomain = trie.New[bool]() + sniffer.ForceDomain = trie.New[struct{}]() for _, domain := range snifferRaw.ForceDomain { - err := sniffer.ForceDomain.Insert(domain, true) + err := sniffer.ForceDomain.Insert(domain, struct{}{}) if err != nil { return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) } } - sniffer.SkipDomain = trie.New[bool]() + sniffer.SkipDomain = trie.New[struct{}]() for _, domain := range snifferRaw.SkipDomain { - err := sniffer.SkipDomain.Insert(domain, true) + err := sniffer.SkipDomain.Insert(domain, struct{}{}) if err != nil { return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) } diff --git a/dns/filters.go b/dns/filters.go index 80b656c9..5b0141a2 100644 --- a/dns/filters.go +++ b/dns/filters.go @@ -71,13 +71,13 @@ type fallbackDomainFilter interface { } type domainFilter struct { - tree *trie.DomainTrie[bool] + tree *trie.DomainTrie[struct{}] } func NewDomainFilter(domains []string) *domainFilter { - df := domainFilter{tree: trie.New[bool]()} + df := domainFilter{tree: trie.New[struct{}]()} for _, domain := range domains { - _ = df.tree.Insert(domain, true) + _ = df.tree.Insert(domain, struct{}{}) } return &df } diff --git a/dns/middleware.go b/dns/middleware.go index 0bfc4977..0e1335f9 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -37,7 +37,7 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip return next(ctx, r) } - ip := record.Data + ip := record.Data() msg := r.Copy() if ip.Is4() && q.Qtype == D.TypeA { diff --git a/dns/resolver.go b/dns/resolver.go index aac22cc8..84a38034 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -245,7 +245,7 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { return nil } - p := record.Data + p := record.Data() return p.GetData() } diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index a6145383..e8bebf6e 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -9,7 +9,7 @@ import ( type domainStrategy struct { count int - domainRules *trie.DomainTrie[bool] + domainRules *trie.DomainTrie[struct{}] } func (d *domainStrategy) Match(metadata *C.Metadata) bool { @@ -25,11 +25,11 @@ func (d *domainStrategy) ShouldResolveIP() bool { } func (d *domainStrategy) OnUpdate(rules []string) { - domainTrie := trie.New[bool]() + domainTrie := trie.New[struct{}]() count := 0 for _, rule := range rules { actualDomain, _ := idna.ToASCII(rule) - err := domainTrie.Insert(actualDomain, true) + err := domainTrie.Insert(actualDomain, struct{}{}) if err != nil { log.Warnln("invalid domain:[%s]", rule) } else { diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index ddf771b9..d9513e01 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -178,7 +178,7 @@ func preHandleMetadata(metadata *C.Metadata) error { metadata.DNSMode = C.DNSFakeIP } else if node := resolver.DefaultHosts.Search(host); node != nil { // redir-host should lookup the hosts - metadata.DstIP = node.Data + metadata.DstIP = node.Data() } } else if resolver.IsFakeIP(metadata.DstIP) { return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) @@ -334,7 +334,7 @@ func handleTCPConn(connCtx C.ConnContext) { dialMetadata := metadata if len(metadata.Host) > 0 { if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { - dialMetadata.DstIP = node.Data + dialMetadata.DstIP = node.Data() dialMetadata.DNSMode = C.DNSHosts dialMetadata = dialMetadata.Pure() } @@ -388,7 +388,7 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { ) if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { - metadata.DstIP = node.Data + metadata.DstIP = node.Data() resolved = true }