Feature: domain trie support dot dot wildcard

This commit is contained in:
Dreamacro 2020-04-08 15:45:59 +08:00
parent 5591e15452
commit 65dab4e34f
2 changed files with 82 additions and 57 deletions

View file

@ -7,6 +7,7 @@ import (
const ( const (
wildcard = "*" wildcard = "*"
dotWildcard = ""
domainStep = "." domainStep = "."
) )
@ -21,8 +22,23 @@ type Trie struct {
root *Node root *Node
} }
func isValidDomain(domain string) bool { func validAndSplitDomain(domain string) ([]string, bool) {
return domain != "" && domain[0] != '.' && domain[len(domain)-1] != '.' if domain != "" && domain[len(domain)-1] == '.' {
return nil, false
}
parts := strings.Split(domain, domainStep)
if len(parts) == 1 {
return nil, false
}
for _, part := range parts[1:] {
if part == "" {
return nil, false
}
}
return parts, true
} }
// Insert adds a node to the trie. // Insert adds a node to the trie.
@ -30,12 +46,13 @@ func isValidDomain(domain string) bool {
// 1. www.example.com // 1. www.example.com
// 2. *.example.com // 2. *.example.com
// 3. subdomain.*.example.com // 3. subdomain.*.example.com
// 4. .example.com
func (t *Trie) Insert(domain string, data interface{}) error { func (t *Trie) Insert(domain string, data interface{}) error {
if !isValidDomain(domain) { parts, valid := validAndSplitDomain(domain)
if !valid {
return ErrInvalidDomain return ErrInvalidDomain
} }
parts := strings.Split(domain, domainStep)
node := t.root node := t.root
// reverse storage domain part to save space // reverse storage domain part to save space
for i := len(parts) - 1; i >= 0; i-- { for i := len(parts) - 1; i >= 0; i-- {
@ -55,28 +72,38 @@ func (t *Trie) Insert(domain string, data interface{}) error {
// Priority as: // Priority as:
// 1. static part // 1. static part
// 2. wildcard domain // 2. wildcard domain
// 2. dot wildcard domain
func (t *Trie) Search(domain string) *Node { func (t *Trie) Search(domain string) *Node {
if !isValidDomain(domain) { parts, valid := validAndSplitDomain(domain)
if !valid || parts[0] == "" {
return nil return nil
} }
parts := strings.Split(domain, domainStep)
n := t.root n := t.root
var dotWildcardNode *Node
for i := len(parts) - 1; i >= 0; i-- { for i := len(parts) - 1; i >= 0; i-- {
part := parts[i] part := parts[i]
var child *Node if node := n.getChild(dotWildcard); node != nil {
if !n.hasChild(part) { dotWildcardNode = node
if !n.hasChild(wildcard) {
return nil
} }
child = n.getChild(wildcard) if n.hasChild(part) {
n = n.getChild(part)
} else { } else {
child = n.getChild(part) n = n.getChild(wildcard)
} }
n = child if n == nil {
break
}
}
if n == nil {
if dotWildcardNode != nil {
return dotWildcardNode
}
return nil
} }
if n.Data == nil { if n.Data == nil {

View file

@ -3,6 +3,8 @@ package trie
import ( import (
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
var localIP = net.IP{127, 0, 0, 1} var localIP = net.IP{127, 0, 0, 1}
@ -19,17 +21,9 @@ func TestTrie_Basic(t *testing.T) {
} }
node := tree.Search("example.com") node := tree.Search("example.com")
if node == nil { assert.NotNil(t, node)
t.Error("should not recv nil") assert.True(t, node.Data.(net.IP).Equal(localIP))
} assert.NotNil(t, tree.Insert("", localIP))
if !node.Data.(net.IP).Equal(localIP) {
t.Error("should equal 127.0.0.1")
}
if tree.Insert("", localIP) == nil {
t.Error("should return error")
}
} }
func TestTrie_Wildcard(t *testing.T) { func TestTrie_Wildcard(t *testing.T) {
@ -38,50 +32,54 @@ func TestTrie_Wildcard(t *testing.T) {
"*.example.com", "*.example.com",
"sub.*.example.com", "sub.*.example.com",
"*.dev", "*.dev",
".org",
".example.net",
} }
for _, domain := range domains { for _, domain := range domains {
tree.Insert(domain, localIP) tree.Insert(domain, localIP)
} }
if tree.Search("sub.example.com") == nil { assert.NotNil(t, tree.Search("sub.example.com"))
t.Error("should not recv nil") assert.NotNil(t, tree.Search("sub.foo.example.com"))
assert.NotNil(t, tree.Search("test.org"))
assert.NotNil(t, tree.Search("test.example.net"))
assert.Nil(t, tree.Search("foo.sub.example.com"))
assert.Nil(t, tree.Search("foo.example.dev"))
assert.Nil(t, tree.Search("example.com"))
}
func TestTrie_Priority(t *testing.T) {
tree := New()
domains := []string{
".dev",
"example.dev",
"*.example.dev",
"test.example.dev",
} }
if tree.Search("sub.foo.example.com") == nil { assertFn := func(domain string, data int) {
t.Error("should not recv nil") node := tree.Search(domain)
assert.NotNil(t, node)
assert.Equal(t, data, node.Data)
} }
if tree.Search("foo.sub.example.com") != nil { for idx, domain := range domains {
t.Error("should recv nil") tree.Insert(domain, idx)
} }
if tree.Search("foo.example.dev") != nil { assertFn("test.dev", 0)
t.Error("should recv nil") assertFn("foo.bar.dev", 0)
} assertFn("example.dev", 1)
assertFn("foo.example.dev", 2)
if tree.Search("example.com") != nil { assertFn("test.example.dev", 3)
t.Error("should recv nil")
}
} }
func TestTrie_Boundary(t *testing.T) { func TestTrie_Boundary(t *testing.T) {
tree := New() tree := New()
tree.Insert("*.dev", localIP) tree.Insert("*.dev", localIP)
if err := tree.Insert(".", localIP); err == nil { assert.NotNil(t, tree.Insert(".", localIP))
t.Error("should recv err") assert.NotNil(t, tree.Insert("..dev", localIP))
} assert.Nil(t, tree.Search("dev"))
if err := tree.Insert(".com", localIP); err == nil {
t.Error("should recv err")
}
if tree.Search("dev") != nil {
t.Error("should recv nil")
}
if tree.Search(".dev") != nil {
t.Error("should recv nil")
}
} }