Feature: domain trie support dot dot wildcard
This commit is contained in:
parent
5591e15452
commit
65dab4e34f
2 changed files with 82 additions and 57 deletions
|
@ -6,8 +6,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
wildcard = "*"
|
wildcard = "*"
|
||||||
domainStep = "."
|
dotWildcard = ""
|
||||||
|
domainStep = "."
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -21,8 +22,23 @@ type Trie struct {
|
||||||
root *Node
|
root *Node
|
||||||
}
|
}
|
||||||
|
|
||||||
func isValidDomain(domain string) bool {
|
func validAndSplitDomain(domain string) ([]string, bool) {
|
||||||
return domain != "" && domain[0] != '.' && domain[len(domain)-1] != '.'
|
if domain != "" && domain[len(domain)-1] == '.' {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(domain, domainStep)
|
||||||
|
if len(parts) == 1 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range parts[1:] {
|
||||||
|
if part == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert adds a node to the trie.
|
// Insert adds a node to the trie.
|
||||||
|
@ -30,12 +46,13 @@ func isValidDomain(domain string) bool {
|
||||||
// 1. www.example.com
|
// 1. www.example.com
|
||||||
// 2. *.example.com
|
// 2. *.example.com
|
||||||
// 3. subdomain.*.example.com
|
// 3. subdomain.*.example.com
|
||||||
|
// 4. .example.com
|
||||||
func (t *Trie) Insert(domain string, data interface{}) error {
|
func (t *Trie) Insert(domain string, data interface{}) error {
|
||||||
if !isValidDomain(domain) {
|
parts, valid := validAndSplitDomain(domain)
|
||||||
|
if !valid {
|
||||||
return ErrInvalidDomain
|
return ErrInvalidDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(domain, domainStep)
|
|
||||||
node := t.root
|
node := t.root
|
||||||
// reverse storage domain part to save space
|
// reverse storage domain part to save space
|
||||||
for i := len(parts) - 1; i >= 0; i-- {
|
for i := len(parts) - 1; i >= 0; i-- {
|
||||||
|
@ -55,28 +72,38 @@ func (t *Trie) Insert(domain string, data interface{}) error {
|
||||||
// Priority as:
|
// Priority as:
|
||||||
// 1. static part
|
// 1. static part
|
||||||
// 2. wildcard domain
|
// 2. wildcard domain
|
||||||
|
// 2. dot wildcard domain
|
||||||
func (t *Trie) Search(domain string) *Node {
|
func (t *Trie) Search(domain string) *Node {
|
||||||
if !isValidDomain(domain) {
|
parts, valid := validAndSplitDomain(domain)
|
||||||
|
if !valid || parts[0] == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
parts := strings.Split(domain, domainStep)
|
|
||||||
|
|
||||||
n := t.root
|
n := t.root
|
||||||
|
var dotWildcardNode *Node
|
||||||
for i := len(parts) - 1; i >= 0; i-- {
|
for i := len(parts) - 1; i >= 0; i-- {
|
||||||
part := parts[i]
|
part := parts[i]
|
||||||
|
|
||||||
var child *Node
|
if node := n.getChild(dotWildcard); node != nil {
|
||||||
if !n.hasChild(part) {
|
dotWildcardNode = node
|
||||||
if !n.hasChild(wildcard) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
child = n.getChild(wildcard)
|
|
||||||
} else {
|
|
||||||
child = n.getChild(part)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
n = child
|
if n.hasChild(part) {
|
||||||
|
n = n.getChild(part)
|
||||||
|
} else {
|
||||||
|
n = n.getChild(wildcard)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == nil {
|
||||||
|
if dotWildcardNode != nil {
|
||||||
|
return dotWildcardNode
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.Data == nil {
|
if n.Data == nil {
|
||||||
|
|
|
@ -3,6 +3,8 @@ package trie
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var localIP = net.IP{127, 0, 0, 1}
|
var localIP = net.IP{127, 0, 0, 1}
|
||||||
|
@ -19,17 +21,9 @@ func TestTrie_Basic(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
node := tree.Search("example.com")
|
node := tree.Search("example.com")
|
||||||
if node == nil {
|
assert.NotNil(t, node)
|
||||||
t.Error("should not recv nil")
|
assert.True(t, node.Data.(net.IP).Equal(localIP))
|
||||||
}
|
assert.NotNil(t, tree.Insert("", localIP))
|
||||||
|
|
||||||
if !node.Data.(net.IP).Equal(localIP) {
|
|
||||||
t.Error("should equal 127.0.0.1")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tree.Insert("", localIP) == nil {
|
|
||||||
t.Error("should return error")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrie_Wildcard(t *testing.T) {
|
func TestTrie_Wildcard(t *testing.T) {
|
||||||
|
@ -38,50 +32,54 @@ func TestTrie_Wildcard(t *testing.T) {
|
||||||
"*.example.com",
|
"*.example.com",
|
||||||
"sub.*.example.com",
|
"sub.*.example.com",
|
||||||
"*.dev",
|
"*.dev",
|
||||||
|
".org",
|
||||||
|
".example.net",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
tree.Insert(domain, localIP)
|
tree.Insert(domain, localIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tree.Search("sub.example.com") == nil {
|
assert.NotNil(t, tree.Search("sub.example.com"))
|
||||||
t.Error("should not recv nil")
|
assert.NotNil(t, tree.Search("sub.foo.example.com"))
|
||||||
|
assert.NotNil(t, tree.Search("test.org"))
|
||||||
|
assert.NotNil(t, tree.Search("test.example.net"))
|
||||||
|
assert.Nil(t, tree.Search("foo.sub.example.com"))
|
||||||
|
assert.Nil(t, tree.Search("foo.example.dev"))
|
||||||
|
assert.Nil(t, tree.Search("example.com"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrie_Priority(t *testing.T) {
|
||||||
|
tree := New()
|
||||||
|
domains := []string{
|
||||||
|
".dev",
|
||||||
|
"example.dev",
|
||||||
|
"*.example.dev",
|
||||||
|
"test.example.dev",
|
||||||
}
|
}
|
||||||
|
|
||||||
if tree.Search("sub.foo.example.com") == nil {
|
assertFn := func(domain string, data int) {
|
||||||
t.Error("should not recv nil")
|
node := tree.Search(domain)
|
||||||
|
assert.NotNil(t, node)
|
||||||
|
assert.Equal(t, data, node.Data)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tree.Search("foo.sub.example.com") != nil {
|
for idx, domain := range domains {
|
||||||
t.Error("should recv nil")
|
tree.Insert(domain, idx)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tree.Search("foo.example.dev") != nil {
|
assertFn("test.dev", 0)
|
||||||
t.Error("should recv nil")
|
assertFn("foo.bar.dev", 0)
|
||||||
}
|
assertFn("example.dev", 1)
|
||||||
|
assertFn("foo.example.dev", 2)
|
||||||
if tree.Search("example.com") != nil {
|
assertFn("test.example.dev", 3)
|
||||||
t.Error("should recv nil")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrie_Boundary(t *testing.T) {
|
func TestTrie_Boundary(t *testing.T) {
|
||||||
tree := New()
|
tree := New()
|
||||||
tree.Insert("*.dev", localIP)
|
tree.Insert("*.dev", localIP)
|
||||||
|
|
||||||
if err := tree.Insert(".", localIP); err == nil {
|
assert.NotNil(t, tree.Insert(".", localIP))
|
||||||
t.Error("should recv err")
|
assert.NotNil(t, tree.Insert("..dev", localIP))
|
||||||
}
|
assert.Nil(t, tree.Search("dev"))
|
||||||
|
|
||||||
if err := tree.Insert(".com", localIP); err == nil {
|
|
||||||
t.Error("should recv err")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tree.Search("dev") != nil {
|
|
||||||
t.Error("should recv nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tree.Search(".dev") != nil {
|
|
||||||
t.Error("should recv nil")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue