From 3497fdaf45e06e75b558646735b7e335c58b6ffd Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Mon, 15 Jul 2019 18:00:51 +0800 Subject: [PATCH] Fix(domain-trie): Incorrect result --- component/domain-trie/tire.go | 4 ++++ component/domain-trie/trie_test.go | 16 +++++++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/component/domain-trie/tire.go b/component/domain-trie/tire.go index 8d570214..13567a98 100644 --- a/component/domain-trie/tire.go +++ b/component/domain-trie/tire.go @@ -75,6 +75,10 @@ func (t *Trie) Search(domain string) *Node { n = child } + if n.Data == nil { + return nil + } + return n } diff --git a/component/domain-trie/trie_test.go b/component/domain-trie/trie_test.go index cd80ce3d..bd594e97 100644 --- a/component/domain-trie/trie_test.go +++ b/component/domain-trie/trie_test.go @@ -5,6 +5,8 @@ import ( "testing" ) +var localIP = net.IP{127, 0, 0, 1} + func TestTrie_Basic(t *testing.T) { tree := New() domains := []string{ @@ -13,7 +15,7 @@ func TestTrie_Basic(t *testing.T) { } for _, domain := range domains { - tree.Insert(domain, net.ParseIP("127.0.0.1")) + tree.Insert(domain, localIP) } node := tree.Search("example.com") @@ -21,7 +23,7 @@ func TestTrie_Basic(t *testing.T) { t.Error("should not recv nil") } - if !node.Data.(net.IP).Equal(net.IP{127, 0, 0, 1}) { + if !node.Data.(net.IP).Equal(localIP) { t.Error("should equal 127.0.0.1") } } @@ -35,7 +37,7 @@ func TestTrie_Wildcard(t *testing.T) { } for _, domain := range domains { - tree.Insert(domain, nil) + tree.Insert(domain, localIP) } if tree.Search("sub.example.com") == nil { @@ -53,13 +55,17 @@ func TestTrie_Wildcard(t *testing.T) { if tree.Search("foo.example.dev") != nil { t.Error("should recv nil") } + + if tree.Search("example.com") != nil { + t.Error("should recv nil") + } } func TestTrie_Boundary(t *testing.T) { tree := New() - tree.Insert("*.dev", nil) + tree.Insert("*.dev", localIP) - if err := tree.Insert("com", nil); err == nil { + if err := tree.Insert("com", localIP); err == nil { t.Error("should recv err") }