Feature(dns): support custom hosts
This commit is contained in:
parent
f867f02546
commit
1a21c8ebfd
10 changed files with 359 additions and 79 deletions
|
@ -120,6 +120,12 @@ experimental:
|
|||
# listen: 0.0.0.0:53
|
||||
# enhanced-mode: redir-host # or fake-ip
|
||||
# # fake-ip-range: 198.18.0.1/16 # if you don't know what it is, don't change it
|
||||
# # experimental hosts, support wildcard (e.g. *.clash.dev Even *.foo.*.example.com)
|
||||
# # static domain has a higher priority than wildcard domain (foo.example.com > *.example.com)
|
||||
# # NOTE: hosts don't work with `fake-ip`
|
||||
# hosts:
|
||||
# '*.clash.dev': 127.0.0.1
|
||||
# 'alpha.clash.dev': '::1'
|
||||
# nameserver:
|
||||
# - 114.114.114.114
|
||||
# - tls://dns.rubyfish.cn:853 # dns over tls
|
||||
|
|
26
component/domain-trie/node.go
Normal file
26
component/domain-trie/node.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package trie
|
||||
|
||||
// Node is the trie's node
|
||||
type Node struct {
|
||||
Data interface{}
|
||||
children map[string]*Node
|
||||
}
|
||||
|
||||
func (n *Node) getChild(s string) *Node {
|
||||
return n.children[s]
|
||||
}
|
||||
|
||||
func (n *Node) hasChild(s string) bool {
|
||||
return n.getChild(s) != nil
|
||||
}
|
||||
|
||||
func (n *Node) addChild(s string, child *Node) {
|
||||
n.children[s] = child
|
||||
}
|
||||
|
||||
func newNode(data interface{}) *Node {
|
||||
return &Node{
|
||||
Data: data,
|
||||
children: map[string]*Node{},
|
||||
}
|
||||
}
|
84
component/domain-trie/tire.go
Normal file
84
component/domain-trie/tire.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package trie
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
wildcard = "*"
|
||||
domainStep = "."
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidDomain means insert domain is invalid
|
||||
ErrInvalidDomain = errors.New("invalid domain")
|
||||
)
|
||||
|
||||
// Trie contains the main logic for adding and searching nodes for domain segments.
|
||||
// support wildcard domain (e.g *.google.com)
|
||||
type Trie struct {
|
||||
root *Node
|
||||
}
|
||||
|
||||
// Insert adds a node to the trie.
|
||||
// Support
|
||||
// 1. www.example.com
|
||||
// 2. *.example.com
|
||||
// 3. subdomain.*.example.com
|
||||
func (t *Trie) Insert(domain string, data interface{}) error {
|
||||
parts := strings.Split(domain, domainStep)
|
||||
if len(parts) < 2 {
|
||||
return ErrInvalidDomain
|
||||
}
|
||||
|
||||
node := t.root
|
||||
// reverse storage domain part to save space
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
part := parts[i]
|
||||
if !node.hasChild(part) {
|
||||
node.addChild(part, newNode(nil))
|
||||
}
|
||||
|
||||
node = node.getChild(part)
|
||||
}
|
||||
|
||||
node.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
// Search is the most important part of the Trie.
|
||||
// Priority as:
|
||||
// 1. static part
|
||||
// 2. wildcard domain
|
||||
func (t *Trie) Search(domain string) *Node {
|
||||
parts := strings.Split(domain, domainStep)
|
||||
if len(parts) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
n := t.root
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
part := parts[i]
|
||||
|
||||
var child *Node
|
||||
if !n.hasChild(part) {
|
||||
if !n.hasChild(wildcard) {
|
||||
return nil
|
||||
}
|
||||
|
||||
child = n.getChild(wildcard)
|
||||
} else {
|
||||
child = n.getChild(part)
|
||||
}
|
||||
|
||||
n = child
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// New returns a new, empty Trie.
|
||||
func New() *Trie {
|
||||
return &Trie{root: newNode(nil)}
|
||||
}
|
69
component/domain-trie/trie_test.go
Normal file
69
component/domain-trie/trie_test.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package trie
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTrie_Basic(t *testing.T) {
|
||||
tree := New()
|
||||
domains := []string{
|
||||
"example.com",
|
||||
"google.com",
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
tree.Insert(domain, net.ParseIP("127.0.0.1"))
|
||||
}
|
||||
|
||||
node := tree.Search("example.com")
|
||||
if node == nil {
|
||||
t.Error("should not recv nil")
|
||||
}
|
||||
|
||||
if !node.Data.(net.IP).Equal(net.IP{127, 0, 0, 1}) {
|
||||
t.Error("should equal 127.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrie_Wildcard(t *testing.T) {
|
||||
tree := New()
|
||||
domains := []string{
|
||||
"*.example.com",
|
||||
"sub.*.example.com",
|
||||
"*.dev",
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
tree.Insert(domain, nil)
|
||||
}
|
||||
|
||||
if tree.Search("sub.example.com") == nil {
|
||||
t.Error("should not recv nil")
|
||||
}
|
||||
|
||||
if tree.Search("sub.foo.example.com") == nil {
|
||||
t.Error("should not recv nil")
|
||||
}
|
||||
|
||||
if tree.Search("foo.sub.example.com") != nil {
|
||||
t.Error("should recv nil")
|
||||
}
|
||||
|
||||
if tree.Search("foo.example.dev") != nil {
|
||||
t.Error("should recv nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrie_Boundary(t *testing.T) {
|
||||
tree := New()
|
||||
tree.Insert("*.dev", nil)
|
||||
|
||||
if err := tree.Insert("com", nil); err == nil {
|
||||
t.Error("should recv err")
|
||||
}
|
||||
|
||||
if tree.Search("dev") != nil {
|
||||
t.Error("should recv nil")
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ import (
|
|||
adapters "github.com/Dreamacro/clash/adapters/outbound"
|
||||
"github.com/Dreamacro/clash/common/structure"
|
||||
"github.com/Dreamacro/clash/component/auth"
|
||||
trie "github.com/Dreamacro/clash/component/domain-trie"
|
||||
"github.com/Dreamacro/clash/component/fakeip"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/dns"
|
||||
|
@ -42,6 +43,7 @@ type DNS struct {
|
|||
IPv6 bool `yaml:"ipv6"`
|
||||
NameServer []dns.NameServer `yaml:"nameserver"`
|
||||
Fallback []dns.NameServer `yaml:"fallback"`
|
||||
Hosts *trie.Trie `yaml:"-"`
|
||||
Listen string `yaml:"listen"`
|
||||
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
|
||||
FakeIPRange *fakeip.Pool
|
||||
|
@ -63,13 +65,14 @@ type Config struct {
|
|||
}
|
||||
|
||||
type rawDNS struct {
|
||||
Enable bool `yaml:"enable"`
|
||||
IPv6 bool `yaml:"ipv6"`
|
||||
NameServer []string `yaml:"nameserver"`
|
||||
Fallback []string `yaml:"fallback"`
|
||||
Listen string `yaml:"listen"`
|
||||
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
|
||||
FakeIPRange string `yaml:"fake-ip-range"`
|
||||
Enable bool `yaml:"enable"`
|
||||
IPv6 bool `yaml:"ipv6"`
|
||||
NameServer []string `yaml:"nameserver"`
|
||||
Hosts map[string]string `yaml:"hosts"`
|
||||
Fallback []string `yaml:"fallback"`
|
||||
Listen string `yaml:"listen"`
|
||||
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
|
||||
FakeIPRange string `yaml:"fake-ip-range"`
|
||||
}
|
||||
|
||||
type rawConfig struct {
|
||||
|
@ -134,6 +137,7 @@ func readConfig(path string) (*rawConfig, error) {
|
|||
DNS: rawDNS{
|
||||
Enable: false,
|
||||
FakeIPRange: "198.18.0.1/16",
|
||||
Hosts: map[string]string{},
|
||||
},
|
||||
}
|
||||
err = yaml.Unmarshal([]byte(data), &rawConfig)
|
||||
|
@ -518,6 +522,18 @@ func parseDNS(cfg rawDNS) (*DNS, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if len(cfg.Hosts) != 0 {
|
||||
tree := trie.New()
|
||||
for domain, ipStr := range cfg.Hosts {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("%s is not a valid IP", ipStr)
|
||||
}
|
||||
tree.Insert(domain, ip)
|
||||
}
|
||||
dnsCfg.Hosts = tree
|
||||
}
|
||||
|
||||
if cfg.EnhancedMode == dns.FAKEIP {
|
||||
_, ipnet, err := net.ParseCIDR(cfg.FakeIPRange)
|
||||
if err != nil {
|
||||
|
|
121
dns/middleware.go
Normal file
121
dns/middleware.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/Dreamacro/clash/common/cache"
|
||||
"github.com/Dreamacro/clash/component/fakeip"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
|
||||
D "github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type handler func(w D.ResponseWriter, r *D.Msg)
|
||||
|
||||
func withFakeIP(cache *cache.Cache, pool *fakeip.Pool) handler {
|
||||
return func(w D.ResponseWriter, r *D.Msg) {
|
||||
q := r.Question[0]
|
||||
|
||||
cacheItem := cache.Get("fakeip:" + q.String())
|
||||
if cache != nil {
|
||||
msg := cacheItem.(*D.Msg).Copy()
|
||||
setMsgTTL(msg, 1)
|
||||
msg.SetReply(r)
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
|
||||
rr := &D.A{}
|
||||
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
|
||||
ip := pool.Get()
|
||||
rr.A = ip
|
||||
msg := r.Copy()
|
||||
msg.Answer = []D.RR{rr}
|
||||
putMsgToCache(cache, "fakeip:"+q.String(), msg)
|
||||
putMsgToCache(cache, ip.String(), msg)
|
||||
|
||||
setMsgTTL(msg, 1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func withResolver(resolver *Resolver) handler {
|
||||
return func(w D.ResponseWriter, r *D.Msg) {
|
||||
msg, err := resolver.Exchange(r)
|
||||
|
||||
if err != nil {
|
||||
q := r.Question[0]
|
||||
qString := fmt.Sprintf("%s %s %s", q.Name, D.Class(q.Qclass).String(), D.Type(q.Qtype).String())
|
||||
log.Debugln("[DNS Server] Exchange %s failed: %v", qString, err)
|
||||
D.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
msg.SetReply(r)
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func withHost(resolver *Resolver, next handler) handler {
|
||||
hosts := resolver.hosts
|
||||
if hosts == nil {
|
||||
panic("dns/withHost: hosts should not be nil")
|
||||
}
|
||||
|
||||
return func(w D.ResponseWriter, r *D.Msg) {
|
||||
q := r.Question[0]
|
||||
if q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
domain := strings.TrimRight(q.Name, ".")
|
||||
host := hosts.Search(domain)
|
||||
if host == nil {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ip := host.Data.(net.IP)
|
||||
if q.Qtype == D.TypeAAAA && ip.To16() == nil {
|
||||
next(w, r)
|
||||
return
|
||||
} else if q.Qtype == D.TypeA && ip.To4() == nil {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
var rr D.RR
|
||||
if q.Qtype == D.TypeAAAA {
|
||||
record := &D.AAAA{}
|
||||
record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
|
||||
record.AAAA = ip
|
||||
rr = record
|
||||
} else {
|
||||
record := &D.A{}
|
||||
record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
|
||||
record.A = ip
|
||||
rr = record
|
||||
}
|
||||
|
||||
msg := r.Copy()
|
||||
msg.Answer = []D.RR{rr}
|
||||
msg.SetReply(r)
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func newHandler(resolver *Resolver) handler {
|
||||
if resolver.IsFakeIP() {
|
||||
return withFakeIP(resolver.cache, resolver.pool)
|
||||
}
|
||||
|
||||
if resolver.hosts != nil {
|
||||
return withHost(resolver, withResolver(resolver))
|
||||
}
|
||||
|
||||
return withResolver(resolver)
|
||||
}
|
|
@ -11,11 +11,13 @@ import (
|
|||
|
||||
"github.com/Dreamacro/clash/common/cache"
|
||||
"github.com/Dreamacro/clash/common/picker"
|
||||
trie "github.com/Dreamacro/clash/component/domain-trie"
|
||||
"github.com/Dreamacro/clash/component/fakeip"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
|
||||
D "github.com/miekg/dns"
|
||||
geoip2 "github.com/oschwald/geoip2-golang"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -44,9 +46,11 @@ type Resolver struct {
|
|||
ipv6 bool
|
||||
mapping bool
|
||||
fakeip bool
|
||||
hosts *trie.Trie
|
||||
pool *fakeip.Pool
|
||||
fallback []resolver
|
||||
main []resolver
|
||||
group singleflight.Group
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
|
@ -134,13 +138,20 @@ func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) {
|
|||
}
|
||||
}()
|
||||
|
||||
isIPReq := isIPRequest(q)
|
||||
if isIPReq {
|
||||
msg, err = r.fallbackExchange(m)
|
||||
return
|
||||
ret, err, _ := r.group.Do(q.String(), func() (interface{}, error) {
|
||||
isIPReq := isIPRequest(q)
|
||||
if isIPReq {
|
||||
msg, err := r.fallbackExchange(m)
|
||||
return msg, err
|
||||
}
|
||||
|
||||
return r.batchExchange(r.main, m)
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
msg = ret.(*D.Msg)
|
||||
}
|
||||
|
||||
msg, err = r.batchExchange(r.main, m)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -266,6 +277,7 @@ type Config struct {
|
|||
Main, Fallback []NameServer
|
||||
IPv6 bool
|
||||
EnhancedMode EnhancedMode
|
||||
Hosts *trie.Trie
|
||||
Pool *fakeip.Pool
|
||||
}
|
||||
|
||||
|
@ -280,6 +292,7 @@ func New(config Config) *Resolver {
|
|||
cache: cache.New(time.Second * 60),
|
||||
mapping: config.EnhancedMode == MAPPING,
|
||||
fakeip: config.EnhancedMode == FAKEIP,
|
||||
hosts: config.Hosts,
|
||||
pool: config.Pool,
|
||||
}
|
||||
if len(config.Fallback) != 0 {
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/Dreamacro/clash/log"
|
||||
"github.com/miekg/dns"
|
||||
D "github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -19,79 +15,26 @@ var (
|
|||
|
||||
type Server struct {
|
||||
*D.Server
|
||||
r *Resolver
|
||||
handler handler
|
||||
}
|
||||
|
||||
func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) {
|
||||
if s.r.IsFakeIP() {
|
||||
msg, err := s.handleFakeIP(r)
|
||||
if err != nil {
|
||||
D.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
msg.SetReply(r)
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := s.r.Exchange(r)
|
||||
|
||||
if err != nil {
|
||||
if len(r.Question) > 0 {
|
||||
q := r.Question[0]
|
||||
qString := fmt.Sprintf("%s %s %s", q.Name, D.Class(q.Qclass).String(), D.Type(q.Qtype).String())
|
||||
log.Debugln("[DNS Server] Exchange %s failed: %v", qString, err)
|
||||
}
|
||||
if len(r.Question) == 0 {
|
||||
D.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
msg.SetReply(r)
|
||||
w.WriteMsg(msg)
|
||||
|
||||
s.handler(w, r)
|
||||
}
|
||||
|
||||
func (s *Server) handleFakeIP(r *D.Msg) (msg *D.Msg, err error) {
|
||||
if len(r.Question) == 0 {
|
||||
err = errors.New("should have one question at least")
|
||||
return
|
||||
}
|
||||
|
||||
q := r.Question[0]
|
||||
|
||||
cache := s.r.cache.Get("fakeip:" + q.String())
|
||||
if cache != nil {
|
||||
msg = cache.(*D.Msg).Copy()
|
||||
setMsgTTL(msg, 1)
|
||||
return
|
||||
}
|
||||
|
||||
var ip net.IP
|
||||
defer func() {
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
putMsgToCache(s.r.cache, "fakeip:"+q.String(), msg)
|
||||
putMsgToCache(s.r.cache, ip.String(), msg)
|
||||
|
||||
setMsgTTL(msg, 1)
|
||||
}()
|
||||
|
||||
rr := &D.A{}
|
||||
rr.Hdr = dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: dnsDefaultTTL}
|
||||
ip = s.r.pool.Get()
|
||||
rr.A = ip
|
||||
msg = r.Copy()
|
||||
msg.Answer = []D.RR{rr}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) setReslover(r *Resolver) {
|
||||
s.r = r
|
||||
func (s *Server) setHandler(handler handler) {
|
||||
s.handler = handler
|
||||
}
|
||||
|
||||
func ReCreateServer(addr string, resolver *Resolver) error {
|
||||
if addr == address {
|
||||
server.setReslover(resolver)
|
||||
handler := newHandler(resolver)
|
||||
server.setHandler(handler)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -116,7 +59,8 @@ func ReCreateServer(addr string, resolver *Resolver) error {
|
|||
}
|
||||
|
||||
address = addr
|
||||
server = &Server{r: resolver}
|
||||
handler := newHandler(resolver)
|
||||
server = &Server{handler: handler}
|
||||
server.Server = &D.Server{Addr: addr, PacketConn: p, Handler: server}
|
||||
|
||||
go func() {
|
||||
|
|
2
go.mod
2
go.mod
|
@ -14,7 +14,7 @@ require (
|
|||
github.com/sirupsen/logrus v1.4.2
|
||||
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58
|
||||
gopkg.in/eapache/channels.v1 v1.1.0
|
||||
gopkg.in/yaml.v2 v2.2.2
|
||||
)
|
||||
|
|
|
@ -67,6 +67,7 @@ func updateDNS(c *config.DNS) {
|
|||
Main: c.NameServer,
|
||||
Fallback: c.Fallback,
|
||||
IPv6: c.IPv6,
|
||||
Hosts: c.Hosts,
|
||||
EnhancedMode: c.EnhancedMode,
|
||||
Pool: c.FakeIPRange,
|
||||
})
|
||||
|
|
Loading…
Reference in a new issue