chore: trie.DomainTrie will not depend on zero value
This commit is contained in:
parent
c34c5ff1f9
commit
22fb219ad8
12 changed files with 66 additions and 48 deletions
|
@ -34,7 +34,7 @@ type Pool struct {
|
||||||
offset netip.Addr
|
offset netip.Addr
|
||||||
cycle bool
|
cycle bool
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
host *trie.DomainTrie[bool]
|
host *trie.DomainTrie[struct{}]
|
||||||
ipnet *netip.Prefix
|
ipnet *netip.Prefix
|
||||||
store store
|
store store
|
||||||
}
|
}
|
||||||
|
@ -150,7 +150,7 @@ func (p *Pool) restoreState() {
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
IPNet *netip.Prefix
|
IPNet *netip.Prefix
|
||||||
Host *trie.DomainTrie[bool]
|
Host *trie.DomainTrie[struct{}]
|
||||||
|
|
||||||
// Size sets the maximum number of entries in memory
|
// Size sets the maximum number of entries in memory
|
||||||
// and does not work if Persistence is true
|
// and does not work if Persistence is true
|
||||||
|
|
|
@ -127,7 +127,7 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if node := DefaultHosts.Search(host); node != nil {
|
if node := DefaultHosts.Search(host); node != nil {
|
||||||
if ip := node.Data; ip.Is6() {
|
if ip := node.Data(); ip.Is6() {
|
||||||
return []netip.Addr{ip}, nil
|
return []netip.Addr{ip}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -161,8 +161,8 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) {
|
||||||
|
|
||||||
func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
|
func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
|
||||||
if node := DefaultHosts.Search(host); node != nil {
|
if node := DefaultHosts.Search(host); node != nil {
|
||||||
if ip := node.Data; ip.Is4() {
|
if ip := node.Data(); ip.Is4() {
|
||||||
return []netip.Addr{node.Data}, nil
|
return []netip.Addr{node.Data()}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -200,7 +200,7 @@ func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
|
||||||
|
|
||||||
func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) {
|
func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) {
|
||||||
if node := DefaultHosts.Search(host); node != nil {
|
if node := DefaultHosts.Search(host); node != nil {
|
||||||
return []netip.Addr{node.Data}, nil
|
return []netip.Addr{node.Data()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, err := netip.ParseAddr(host)
|
ip, err := netip.ParseAddr(host)
|
||||||
|
|
|
@ -31,8 +31,8 @@ type SnifferDispatcher struct {
|
||||||
|
|
||||||
sniffers []sniffer.Sniffer
|
sniffers []sniffer.Sniffer
|
||||||
|
|
||||||
forceDomain *trie.DomainTrie[bool]
|
forceDomain *trie.DomainTrie[struct{}]
|
||||||
skipSNI *trie.DomainTrie[bool]
|
skipSNI *trie.DomainTrie[struct{}]
|
||||||
portRanges *[]utils.Range[uint16]
|
portRanges *[]utils.Range[uint16]
|
||||||
skipList *cache.LruCache[string, uint8]
|
skipList *cache.LruCache[string, uint8]
|
||||||
rwMux sync.RWMutex
|
rwMux sync.RWMutex
|
||||||
|
@ -183,8 +183,8 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
|
||||||
return &dispatcher, nil
|
return &dispatcher, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[bool],
|
func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[struct{}],
|
||||||
skipSNI *trie.DomainTrie[bool], ports *[]utils.Range[uint16],
|
skipSNI *trie.DomainTrie[struct{}], ports *[]utils.Range[uint16],
|
||||||
forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
|
forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
|
||||||
dispatcher := SnifferDispatcher{
|
dispatcher := SnifferDispatcher{
|
||||||
enable: true,
|
enable: true,
|
||||||
|
|
|
@ -17,7 +17,7 @@ var ErrInvalidDomain = errors.New("invalid domain")
|
||||||
|
|
||||||
// DomainTrie contains the main logic for adding and searching nodes for domain segments.
|
// DomainTrie contains the main logic for adding and searching nodes for domain segments.
|
||||||
// support wildcard domain (e.g *.google.com)
|
// support wildcard domain (e.g *.google.com)
|
||||||
type DomainTrie[T comparable] struct {
|
type DomainTrie[T any] struct {
|
||||||
root *Node[T]
|
root *Node[T]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,13 +74,13 @@ func (t *DomainTrie[T]) insert(parts []string, data T) {
|
||||||
for i := len(parts) - 1; i >= 0; i-- {
|
for i := len(parts) - 1; i >= 0; i-- {
|
||||||
part := parts[i]
|
part := parts[i]
|
||||||
if !node.hasChild(part) {
|
if !node.hasChild(part) {
|
||||||
node.addChild(part, newNode(getZero[T]()))
|
node.addChild(part, newNode[T]())
|
||||||
}
|
}
|
||||||
|
|
||||||
node = node.getChild(part)
|
node = node.getChild(part)
|
||||||
}
|
}
|
||||||
|
|
||||||
node.Data = data
|
node.setData(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search is the most important part of the Trie.
|
// Search is the most important part of the Trie.
|
||||||
|
@ -96,7 +96,7 @@ func (t *DomainTrie[T]) Search(domain string) *Node[T] {
|
||||||
|
|
||||||
n := t.search(t.root, parts)
|
n := t.search(t.root, parts)
|
||||||
|
|
||||||
if n == nil || n.Data == getZero[T]() {
|
if n.isEmpty() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,13 +109,13 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
|
||||||
}
|
}
|
||||||
|
|
||||||
if c := node.getChild(parts[len(parts)-1]); c != nil {
|
if c := node.getChild(parts[len(parts)-1]); c != nil {
|
||||||
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() {
|
if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c := node.getChild(wildcard); c != nil {
|
if c := node.getChild(wildcard); c != nil {
|
||||||
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() {
|
if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -124,6 +124,6 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new, empty Trie.
|
// New returns a new, empty Trie.
|
||||||
func New[T comparable]() *DomainTrie[T] {
|
func New[T any]() *DomainTrie[T] {
|
||||||
return &DomainTrie[T]{root: newNode[T](getZero[T]())}
|
return &DomainTrie[T]{root: newNode[T]()}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ func TestTrie_Basic(t *testing.T) {
|
||||||
|
|
||||||
node := tree.Search("example.com")
|
node := tree.Search("example.com")
|
||||||
assert.NotNil(t, node)
|
assert.NotNil(t, node)
|
||||||
assert.True(t, node.Data == localIP)
|
assert.True(t, node.Data() == localIP)
|
||||||
assert.NotNil(t, tree.Insert("", localIP))
|
assert.NotNil(t, tree.Insert("", localIP))
|
||||||
assert.Nil(t, tree.Search(""))
|
assert.Nil(t, tree.Search(""))
|
||||||
assert.NotNil(t, tree.Search("localhost"))
|
assert.NotNil(t, tree.Search("localhost"))
|
||||||
|
@ -75,7 +75,7 @@ func TestTrie_Priority(t *testing.T) {
|
||||||
assertFn := func(domain string, data int) {
|
assertFn := func(domain string, data int) {
|
||||||
node := tree.Search(domain)
|
node := tree.Search(domain)
|
||||||
assert.NotNil(t, node)
|
assert.NotNil(t, node)
|
||||||
assert.Equal(t, data, node.Data)
|
assert.Equal(t, data, node.Data())
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, domain := range domains {
|
for idx, domain := range domains {
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
package trie
|
package trie
|
||||||
|
|
||||||
// Node is the trie's node
|
// Node is the trie's node
|
||||||
type Node[T comparable] struct {
|
type Node[T any] struct {
|
||||||
children map[string]*Node[T]
|
children map[string]*Node[T]
|
||||||
Data T
|
inited bool
|
||||||
|
data T
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Node[T]) getChild(s string) *Node[T] {
|
func (n *Node[T]) getChild(s string) *Node[T] {
|
||||||
|
@ -18,14 +19,31 @@ func (n *Node[T]) addChild(s string, child *Node[T]) {
|
||||||
n.children[s] = child
|
n.children[s] = child
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNode[T comparable](data T) *Node[T] {
|
func (n *Node[T]) isEmpty() bool {
|
||||||
|
if n == nil || n.inited == false {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node[T]) setData(data T) {
|
||||||
|
n.data = data
|
||||||
|
n.inited = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node[T]) Data() T {
|
||||||
|
return n.data
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNode[T any]() *Node[T] {
|
||||||
return &Node[T]{
|
return &Node[T]{
|
||||||
Data: data,
|
|
||||||
children: map[string]*Node[T]{},
|
children: map[string]*Node[T]{},
|
||||||
|
inited: false,
|
||||||
|
data: getZero[T](),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getZero[T comparable]() T {
|
func getZero[T any]() T {
|
||||||
var result T
|
var result T
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
|
@ -197,9 +197,9 @@ type IPTables struct {
|
||||||
type Sniffer struct {
|
type Sniffer struct {
|
||||||
Enable bool
|
Enable bool
|
||||||
Sniffers []sniffer.Type
|
Sniffers []sniffer.Type
|
||||||
Reverses *trie.DomainTrie[bool]
|
Reverses *trie.DomainTrie[struct{}]
|
||||||
ForceDomain *trie.DomainTrie[bool]
|
ForceDomain *trie.DomainTrie[struct{}]
|
||||||
SkipDomain *trie.DomainTrie[bool]
|
SkipDomain *trie.DomainTrie[struct{}]
|
||||||
Ports *[]utils.Range[uint16]
|
Ports *[]utils.Range[uint16]
|
||||||
ForceDnsMapping bool
|
ForceDnsMapping bool
|
||||||
ParsePureIp bool
|
ParsePureIp bool
|
||||||
|
@ -1061,24 +1061,24 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var host *trie.DomainTrie[bool]
|
var host *trie.DomainTrie[struct{}]
|
||||||
// fake ip skip host filter
|
// fake ip skip host filter
|
||||||
if len(cfg.FakeIPFilter) != 0 {
|
if len(cfg.FakeIPFilter) != 0 {
|
||||||
host = trie.New[bool]()
|
host = trie.New[struct{}]()
|
||||||
for _, domain := range cfg.FakeIPFilter {
|
for _, domain := range cfg.FakeIPFilter {
|
||||||
_ = host.Insert(domain, true)
|
_ = host.Insert(domain, struct{}{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(dnsCfg.Fallback) != 0 {
|
if len(dnsCfg.Fallback) != 0 {
|
||||||
if host == nil {
|
if host == nil {
|
||||||
host = trie.New[bool]()
|
host = trie.New[struct{}]()
|
||||||
}
|
}
|
||||||
for _, fb := range dnsCfg.Fallback {
|
for _, fb := range dnsCfg.Fallback {
|
||||||
if net.ParseIP(fb.Addr) != nil {
|
if net.ParseIP(fb.Addr) != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
_ = host.Insert(fb.Addr, true)
|
_ = host.Insert(fb.Addr, struct{}{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1232,17 +1232,17 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
|
||||||
for st := range loadSniffer {
|
for st := range loadSniffer {
|
||||||
sniffer.Sniffers = append(sniffer.Sniffers, st)
|
sniffer.Sniffers = append(sniffer.Sniffers, st)
|
||||||
}
|
}
|
||||||
sniffer.ForceDomain = trie.New[bool]()
|
sniffer.ForceDomain = trie.New[struct{}]()
|
||||||
for _, domain := range snifferRaw.ForceDomain {
|
for _, domain := range snifferRaw.ForceDomain {
|
||||||
err := sniffer.ForceDomain.Insert(domain, true)
|
err := sniffer.ForceDomain.Insert(domain, struct{}{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
|
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sniffer.SkipDomain = trie.New[bool]()
|
sniffer.SkipDomain = trie.New[struct{}]()
|
||||||
for _, domain := range snifferRaw.SkipDomain {
|
for _, domain := range snifferRaw.SkipDomain {
|
||||||
err := sniffer.SkipDomain.Insert(domain, true)
|
err := sniffer.SkipDomain.Insert(domain, struct{}{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
|
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,13 +71,13 @@ type fallbackDomainFilter interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type domainFilter struct {
|
type domainFilter struct {
|
||||||
tree *trie.DomainTrie[bool]
|
tree *trie.DomainTrie[struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDomainFilter(domains []string) *domainFilter {
|
func NewDomainFilter(domains []string) *domainFilter {
|
||||||
df := domainFilter{tree: trie.New[bool]()}
|
df := domainFilter{tree: trie.New[struct{}]()}
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
_ = df.tree.Insert(domain, true)
|
_ = df.tree.Insert(domain, struct{}{})
|
||||||
}
|
}
|
||||||
return &df
|
return &df
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip
|
||||||
return next(ctx, r)
|
return next(ctx, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := record.Data
|
ip := record.Data()
|
||||||
msg := r.Copy()
|
msg := r.Copy()
|
||||||
|
|
||||||
if ip.Is4() && q.Qtype == D.TypeA {
|
if ip.Is4() && q.Qtype == D.TypeA {
|
||||||
|
|
|
@ -245,7 +245,7 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p := record.Data
|
p := record.Data()
|
||||||
return p.GetData()
|
return p.GetData()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
type domainStrategy struct {
|
type domainStrategy struct {
|
||||||
count int
|
count int
|
||||||
domainRules *trie.DomainTrie[bool]
|
domainRules *trie.DomainTrie[struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *domainStrategy) Match(metadata *C.Metadata) bool {
|
func (d *domainStrategy) Match(metadata *C.Metadata) bool {
|
||||||
|
@ -25,11 +25,11 @@ func (d *domainStrategy) ShouldResolveIP() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *domainStrategy) OnUpdate(rules []string) {
|
func (d *domainStrategy) OnUpdate(rules []string) {
|
||||||
domainTrie := trie.New[bool]()
|
domainTrie := trie.New[struct{}]()
|
||||||
count := 0
|
count := 0
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
actualDomain, _ := idna.ToASCII(rule)
|
actualDomain, _ := idna.ToASCII(rule)
|
||||||
err := domainTrie.Insert(actualDomain, true)
|
err := domainTrie.Insert(actualDomain, struct{}{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnln("invalid domain:[%s]", rule)
|
log.Warnln("invalid domain:[%s]", rule)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -178,7 +178,7 @@ func preHandleMetadata(metadata *C.Metadata) error {
|
||||||
metadata.DNSMode = C.DNSFakeIP
|
metadata.DNSMode = C.DNSFakeIP
|
||||||
} else if node := resolver.DefaultHosts.Search(host); node != nil {
|
} else if node := resolver.DefaultHosts.Search(host); node != nil {
|
||||||
// redir-host should lookup the hosts
|
// redir-host should lookup the hosts
|
||||||
metadata.DstIP = node.Data
|
metadata.DstIP = node.Data()
|
||||||
}
|
}
|
||||||
} else if resolver.IsFakeIP(metadata.DstIP) {
|
} else if resolver.IsFakeIP(metadata.DstIP) {
|
||||||
return fmt.Errorf("fake DNS record %s missing", metadata.DstIP)
|
return fmt.Errorf("fake DNS record %s missing", metadata.DstIP)
|
||||||
|
@ -334,7 +334,7 @@ func handleTCPConn(connCtx C.ConnContext) {
|
||||||
dialMetadata := metadata
|
dialMetadata := metadata
|
||||||
if len(metadata.Host) > 0 {
|
if len(metadata.Host) > 0 {
|
||||||
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil {
|
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil {
|
||||||
dialMetadata.DstIP = node.Data
|
dialMetadata.DstIP = node.Data()
|
||||||
dialMetadata.DNSMode = C.DNSHosts
|
dialMetadata.DNSMode = C.DNSHosts
|
||||||
dialMetadata = dialMetadata.Pure()
|
dialMetadata = dialMetadata.Pure()
|
||||||
}
|
}
|
||||||
|
@ -388,7 +388,7 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
|
||||||
)
|
)
|
||||||
|
|
||||||
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil {
|
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil {
|
||||||
metadata.DstIP = node.Data
|
metadata.DstIP = node.Data()
|
||||||
resolved = true
|
resolved = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue