Feature(dns): support custom hosts

This commit is contained in:
Dreamacro 2019-07-14 19:29:58 +08:00
parent f867f02546
commit 1a21c8ebfd
10 changed files with 359 additions and 79 deletions

View file

@ -120,6 +120,12 @@ experimental:
# listen: 0.0.0.0:53
# enhanced-mode: redir-host # or fake-ip
# # fake-ip-range: 198.18.0.1/16 # if you don't know what it is, don't change it
# # experimental hosts, support wildcard (e.g. *.clash.dev Even *.foo.*.example.com)
# # static domain has a higher priority than wildcard domain (foo.example.com > *.example.com)
# # NOTE: hosts don't work with `fake-ip`
# hosts:
# '*.clash.dev': 127.0.0.1
# 'alpha.clash.dev': '::1'
# nameserver:
# - 114.114.114.114
# - tls://dns.rubyfish.cn:853 # dns over tls

View file

@ -0,0 +1,26 @@
package trie
// Node is the trie's node
type Node struct {
Data interface{}
children map[string]*Node
}
func (n *Node) getChild(s string) *Node {
return n.children[s]
}
func (n *Node) hasChild(s string) bool {
return n.getChild(s) != nil
}
func (n *Node) addChild(s string, child *Node) {
n.children[s] = child
}
func newNode(data interface{}) *Node {
return &Node{
Data: data,
children: map[string]*Node{},
}
}

View file

@ -0,0 +1,84 @@
package trie
import (
"errors"
"strings"
)
const (
wildcard = "*"
domainStep = "."
)
var (
// ErrInvalidDomain means insert domain is invalid
ErrInvalidDomain = errors.New("invalid domain")
)
// Trie contains the main logic for adding and searching nodes for domain segments.
// support wildcard domain (e.g *.google.com)
type Trie struct {
root *Node
}
// Insert adds a node to the trie.
// Support
// 1. www.example.com
// 2. *.example.com
// 3. subdomain.*.example.com
func (t *Trie) Insert(domain string, data interface{}) error {
parts := strings.Split(domain, domainStep)
if len(parts) < 2 {
return ErrInvalidDomain
}
node := t.root
// reverse storage domain part to save space
for i := len(parts) - 1; i >= 0; i-- {
part := parts[i]
if !node.hasChild(part) {
node.addChild(part, newNode(nil))
}
node = node.getChild(part)
}
node.Data = data
return nil
}
// Search is the most important part of the Trie.
// Priority as:
// 1. static part
// 2. wildcard domain
func (t *Trie) Search(domain string) *Node {
parts := strings.Split(domain, domainStep)
if len(parts) < 2 {
return nil
}
n := t.root
for i := len(parts) - 1; i >= 0; i-- {
part := parts[i]
var child *Node
if !n.hasChild(part) {
if !n.hasChild(wildcard) {
return nil
}
child = n.getChild(wildcard)
} else {
child = n.getChild(part)
}
n = child
}
return n
}
// New returns a new, empty Trie.
func New() *Trie {
return &Trie{root: newNode(nil)}
}

View file

@ -0,0 +1,69 @@
package trie
import (
"net"
"testing"
)
func TestTrie_Basic(t *testing.T) {
tree := New()
domains := []string{
"example.com",
"google.com",
}
for _, domain := range domains {
tree.Insert(domain, net.ParseIP("127.0.0.1"))
}
node := tree.Search("example.com")
if node == nil {
t.Error("should not recv nil")
}
if !node.Data.(net.IP).Equal(net.IP{127, 0, 0, 1}) {
t.Error("should equal 127.0.0.1")
}
}
func TestTrie_Wildcard(t *testing.T) {
tree := New()
domains := []string{
"*.example.com",
"sub.*.example.com",
"*.dev",
}
for _, domain := range domains {
tree.Insert(domain, nil)
}
if tree.Search("sub.example.com") == nil {
t.Error("should not recv nil")
}
if tree.Search("sub.foo.example.com") == nil {
t.Error("should not recv nil")
}
if tree.Search("foo.sub.example.com") != nil {
t.Error("should recv nil")
}
if tree.Search("foo.example.dev") != nil {
t.Error("should recv nil")
}
}
func TestTrie_Boundary(t *testing.T) {
tree := New()
tree.Insert("*.dev", nil)
if err := tree.Insert("com", nil); err == nil {
t.Error("should recv err")
}
if tree.Search("dev") != nil {
t.Error("should recv nil")
}
}

View file

@ -12,6 +12,7 @@ import (
adapters "github.com/Dreamacro/clash/adapters/outbound"
"github.com/Dreamacro/clash/common/structure"
"github.com/Dreamacro/clash/component/auth"
trie "github.com/Dreamacro/clash/component/domain-trie"
"github.com/Dreamacro/clash/component/fakeip"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/dns"
@ -42,6 +43,7 @@ type DNS struct {
IPv6 bool `yaml:"ipv6"`
NameServer []dns.NameServer `yaml:"nameserver"`
Fallback []dns.NameServer `yaml:"fallback"`
Hosts *trie.Trie `yaml:"-"`
Listen string `yaml:"listen"`
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
FakeIPRange *fakeip.Pool
@ -63,13 +65,14 @@ type Config struct {
}
type rawDNS struct {
Enable bool `yaml:"enable"`
IPv6 bool `yaml:"ipv6"`
NameServer []string `yaml:"nameserver"`
Fallback []string `yaml:"fallback"`
Listen string `yaml:"listen"`
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
FakeIPRange string `yaml:"fake-ip-range"`
Enable bool `yaml:"enable"`
IPv6 bool `yaml:"ipv6"`
NameServer []string `yaml:"nameserver"`
Hosts map[string]string `yaml:"hosts"`
Fallback []string `yaml:"fallback"`
Listen string `yaml:"listen"`
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
FakeIPRange string `yaml:"fake-ip-range"`
}
type rawConfig struct {
@ -134,6 +137,7 @@ func readConfig(path string) (*rawConfig, error) {
DNS: rawDNS{
Enable: false,
FakeIPRange: "198.18.0.1/16",
Hosts: map[string]string{},
},
}
err = yaml.Unmarshal([]byte(data), &rawConfig)
@ -518,6 +522,18 @@ func parseDNS(cfg rawDNS) (*DNS, error) {
return nil, err
}
if len(cfg.Hosts) != 0 {
tree := trie.New()
for domain, ipStr := range cfg.Hosts {
ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("%s is not a valid IP", ipStr)
}
tree.Insert(domain, ip)
}
dnsCfg.Hosts = tree
}
if cfg.EnhancedMode == dns.FAKEIP {
_, ipnet, err := net.ParseCIDR(cfg.FakeIPRange)
if err != nil {

121
dns/middleware.go Normal file
View file

@ -0,0 +1,121 @@
package dns
import (
"fmt"
"net"
"strings"
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/log"
D "github.com/miekg/dns"
)
type handler func(w D.ResponseWriter, r *D.Msg)
func withFakeIP(cache *cache.Cache, pool *fakeip.Pool) handler {
return func(w D.ResponseWriter, r *D.Msg) {
q := r.Question[0]
cacheItem := cache.Get("fakeip:" + q.String())
if cache != nil {
msg := cacheItem.(*D.Msg).Copy()
setMsgTTL(msg, 1)
msg.SetReply(r)
w.WriteMsg(msg)
return
}
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
ip := pool.Get()
rr.A = ip
msg := r.Copy()
msg.Answer = []D.RR{rr}
putMsgToCache(cache, "fakeip:"+q.String(), msg)
putMsgToCache(cache, ip.String(), msg)
setMsgTTL(msg, 1)
return
}
}
func withResolver(resolver *Resolver) handler {
return func(w D.ResponseWriter, r *D.Msg) {
msg, err := resolver.Exchange(r)
if err != nil {
q := r.Question[0]
qString := fmt.Sprintf("%s %s %s", q.Name, D.Class(q.Qclass).String(), D.Type(q.Qtype).String())
log.Debugln("[DNS Server] Exchange %s failed: %v", qString, err)
D.HandleFailed(w, r)
return
}
msg.SetReply(r)
w.WriteMsg(msg)
return
}
}
func withHost(resolver *Resolver, next handler) handler {
hosts := resolver.hosts
if hosts == nil {
panic("dns/withHost: hosts should not be nil")
}
return func(w D.ResponseWriter, r *D.Msg) {
q := r.Question[0]
if q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA {
next(w, r)
return
}
domain := strings.TrimRight(q.Name, ".")
host := hosts.Search(domain)
if host == nil {
next(w, r)
return
}
ip := host.Data.(net.IP)
if q.Qtype == D.TypeAAAA && ip.To16() == nil {
next(w, r)
return
} else if q.Qtype == D.TypeA && ip.To4() == nil {
next(w, r)
return
}
var rr D.RR
if q.Qtype == D.TypeAAAA {
record := &D.AAAA{}
record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
record.AAAA = ip
rr = record
} else {
record := &D.A{}
record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
record.A = ip
rr = record
}
msg := r.Copy()
msg.Answer = []D.RR{rr}
msg.SetReply(r)
w.WriteMsg(msg)
return
}
}
func newHandler(resolver *Resolver) handler {
if resolver.IsFakeIP() {
return withFakeIP(resolver.cache, resolver.pool)
}
if resolver.hosts != nil {
return withHost(resolver, withResolver(resolver))
}
return withResolver(resolver)
}

View file

@ -11,11 +11,13 @@ import (
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/common/picker"
trie "github.com/Dreamacro/clash/component/domain-trie"
"github.com/Dreamacro/clash/component/fakeip"
C "github.com/Dreamacro/clash/constant"
D "github.com/miekg/dns"
geoip2 "github.com/oschwald/geoip2-golang"
"golang.org/x/sync/singleflight"
)
var (
@ -44,9 +46,11 @@ type Resolver struct {
ipv6 bool
mapping bool
fakeip bool
hosts *trie.Trie
pool *fakeip.Pool
fallback []resolver
main []resolver
group singleflight.Group
cache *cache.Cache
}
@ -134,13 +138,20 @@ func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) {
}
}()
isIPReq := isIPRequest(q)
if isIPReq {
msg, err = r.fallbackExchange(m)
return
ret, err, _ := r.group.Do(q.String(), func() (interface{}, error) {
isIPReq := isIPRequest(q)
if isIPReq {
msg, err := r.fallbackExchange(m)
return msg, err
}
return r.batchExchange(r.main, m)
})
if err == nil {
msg = ret.(*D.Msg)
}
msg, err = r.batchExchange(r.main, m)
return
}
@ -266,6 +277,7 @@ type Config struct {
Main, Fallback []NameServer
IPv6 bool
EnhancedMode EnhancedMode
Hosts *trie.Trie
Pool *fakeip.Pool
}
@ -280,6 +292,7 @@ func New(config Config) *Resolver {
cache: cache.New(time.Second * 60),
mapping: config.EnhancedMode == MAPPING,
fakeip: config.EnhancedMode == FAKEIP,
hosts: config.Hosts,
pool: config.Pool,
}
if len(config.Fallback) != 0 {

View file

@ -1,12 +1,8 @@
package dns
import (
"errors"
"fmt"
"net"
"github.com/Dreamacro/clash/log"
"github.com/miekg/dns"
D "github.com/miekg/dns"
)
@ -19,79 +15,26 @@ var (
type Server struct {
*D.Server
r *Resolver
handler handler
}
func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) {
if s.r.IsFakeIP() {
msg, err := s.handleFakeIP(r)
if err != nil {
D.HandleFailed(w, r)
return
}
msg.SetReply(r)
w.WriteMsg(msg)
return
}
msg, err := s.r.Exchange(r)
if err != nil {
if len(r.Question) > 0 {
q := r.Question[0]
qString := fmt.Sprintf("%s %s %s", q.Name, D.Class(q.Qclass).String(), D.Type(q.Qtype).String())
log.Debugln("[DNS Server] Exchange %s failed: %v", qString, err)
}
if len(r.Question) == 0 {
D.HandleFailed(w, r)
return
}
msg.SetReply(r)
w.WriteMsg(msg)
s.handler(w, r)
}
func (s *Server) handleFakeIP(r *D.Msg) (msg *D.Msg, err error) {
if len(r.Question) == 0 {
err = errors.New("should have one question at least")
return
}
q := r.Question[0]
cache := s.r.cache.Get("fakeip:" + q.String())
if cache != nil {
msg = cache.(*D.Msg).Copy()
setMsgTTL(msg, 1)
return
}
var ip net.IP
defer func() {
if msg == nil {
return
}
putMsgToCache(s.r.cache, "fakeip:"+q.String(), msg)
putMsgToCache(s.r.cache, ip.String(), msg)
setMsgTTL(msg, 1)
}()
rr := &D.A{}
rr.Hdr = dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: dnsDefaultTTL}
ip = s.r.pool.Get()
rr.A = ip
msg = r.Copy()
msg.Answer = []D.RR{rr}
return
}
func (s *Server) setReslover(r *Resolver) {
s.r = r
func (s *Server) setHandler(handler handler) {
s.handler = handler
}
func ReCreateServer(addr string, resolver *Resolver) error {
if addr == address {
server.setReslover(resolver)
handler := newHandler(resolver)
server.setHandler(handler)
return nil
}
@ -116,7 +59,8 @@ func ReCreateServer(addr string, resolver *Resolver) error {
}
address = addr
server = &Server{r: resolver}
handler := newHandler(resolver)
server = &Server{handler: handler}
server.Server = &D.Server{Addr: addr, PacketConn: p, Handler: server}
go func() {

2
go.mod
View file

@ -14,7 +14,7 @@ require (
github.com/sirupsen/logrus v1.4.2
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3
golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect
golang.org/x/sync v0.0.0-20190423024810-112230192c58
gopkg.in/eapache/channels.v1 v1.1.0
gopkg.in/yaml.v2 v2.2.2
)

View file

@ -67,6 +67,7 @@ func updateDNS(c *config.DNS) {
Main: c.NameServer,
Fallback: c.Fallback,
IPv6: c.IPv6,
Hosts: c.Hosts,
EnhancedMode: c.EnhancedMode,
Pool: c.FakeIPRange,
})