[Refactor] gvisor support hijack dns list

dns-hijack:
 - 1.1.1.1
 - 8.8.8.8:53
 - tcp://1.1.1.1:53
 - udp://223.5.5.5
 - 10.0.0.1:5353
This commit is contained in:
gVisor bot 2022-01-09 00:35:45 +08:00
parent 1c1eb6bdfb
commit 0a96994452
3 changed files with 308 additions and 133 deletions

46
common/net/tcpip.go Normal file
View file

@ -0,0 +1,46 @@
package net
import (
"fmt"
"net"
"strings"
)
func SplitNetworkType(s string) (string, string, error) {
var (
shecme string
hostPort string
)
result := strings.Split(s, "://")
if len(result) == 2 {
shecme = result[0]
hostPort = result[1]
} else if len(result) == 1 {
hostPort = result[0]
} else {
return "", "", fmt.Errorf("tcp/udp style error")
}
if len(shecme) == 0 {
shecme = "udp"
}
if shecme != "tcp" && shecme != "udp" {
return "", "", fmt.Errorf("scheme should be tcp:// or udp://")
} else {
return shecme, hostPort, nil
}
}
func SplitHostPort(s string) (host, port string, hasPort bool, err error) {
temp := s
hasPort = true
if !strings.Contains(s, ":") && !strings.Contains(s, "]:") {
temp += ":0"
hasPort = false
}
host, port, err = net.SplitHostPort(temp)
return
}

View file

@ -36,7 +36,7 @@ const nicID tcpip.NICID = 1
type gvisorAdapter struct { type gvisorAdapter struct {
device dev.TunDevice device dev.TunDevice
ipstack *stack.Stack ipstack *stack.Stack
dnsServers []*DNSServer dnsServer *DNSServer
udpIn chan<- *inbound.PacketAdapter udpIn chan<- *inbound.PacketAdapter
stackName string stackName string
@ -47,7 +47,7 @@ type gvisorAdapter struct {
writeHandle *channel.NotificationHandle writeHandle *channel.NotificationHandle
} }
// GvisorAdapter create GvisorAdapter // NewAdapter GvisorAdapter create GvisorAdapter
func NewAdapter(device dev.TunDevice, conf config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) { func NewAdapter(device dev.TunDevice, conf config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) {
ipstack := stack.New(stack.Options{ ipstack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
@ -132,7 +132,7 @@ func (t *gvisorAdapter) AutoRoute() bool {
// Close close the TunAdapter // Close close the TunAdapter
func (t *gvisorAdapter) Close() { func (t *gvisorAdapter) Close() {
t.StopAllDNSServer() t.StopDNSServer()
if t.ipstack != nil { if t.ipstack != nil {
t.ipstack.Close() t.ipstack.Close()
} }

View file

@ -2,13 +2,14 @@ package gvisor
import ( import (
"fmt" "fmt"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"net" "net"
Common "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/dns"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
D "github.com/miekg/dns" D "github.com/miekg/dns"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@ -23,15 +24,33 @@ var (
ipv6Zero = tcpip.Address(net.IPv6zero.To16()) ipv6Zero = tcpip.Address(net.IPv6zero.To16())
) )
type ListenerWrap struct {
net.Listener
listener net.Listener
}
func (l *ListenerWrap) Accept() (conn net.Conn, err error) {
conn, err = l.listener.Accept()
log.Debugln("[DNS] hijack tcp:%s", l.Addr())
return
}
func (l *ListenerWrap) Close() error {
return l.listener.Close()
}
func (l *ListenerWrap) Addr() net.Addr {
return l.listener.Addr()
}
// DNSServer is DNS Server listening on tun devcice // DNSServer is DNS Server listening on tun devcice
type DNSServer struct { type DNSServer struct {
*dns.Server dnsServers []*dns.Server
tcpListeners []net.Listener
resolver *dns.Resolver resolver *dns.Resolver
stack *stack.Stack stack *stack.Stack
tcpListener net.Listener udpEndpoints []*dnsEndpoint
udpEndpoint *dnsEndpoint udpEndpointIDs []*stack.TransportEndpointID
udpEndpointID *stack.TransportEndpointID
tcpip.NICID tcpip.NICID
} }
@ -66,6 +85,7 @@ func (e *dnsEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pack
var msg D.Msg var msg D.Msg
msg.Unpack(pkt.Data().AsRange().ToOwnedView()) msg.Unpack(pkt.Data().AsRange().ToOwnedView())
writer := dnsResponseWriter{s: e.stack, pkt: pkt, id: id} writer := dnsResponseWriter{s: e.stack, pkt: pkt, id: id}
log.Debugln("[DNS] hijack udp:%s:%d", id.LocalAddress.String(), id.LocalPort)
go e.server.ServeDNS(&writer, &msg) go e.server.ServeDNS(&writer, &msg)
} }
@ -129,26 +149,66 @@ func (w *dnsResponseWriter) Close() error {
} }
// CreateDNSServer create a dns server on given netstack // CreateDNSServer create a dns server on given netstack
func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.ResolverEnhancer, ip net.IP, port int, nicID tcpip.NICID) (*DNSServer, error) { func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.ResolverEnhancer, dnsHijack []net.Addr, nicID tcpip.NICID) (*DNSServer, error) {
var v4 bool
var err error var err error
handler := dns.NewHandler(resolver, mapper)
serverIn := &dns.Server{}
serverIn.SetHandler(handler)
tcpDnsArr := make([]net.TCPAddr, 0, len(dnsHijack))
udpDnsArr := make([]net.UDPAddr, 0, len(dnsHijack))
for _, d := range dnsHijack {
switch d.(type) {
case *net.TCPAddr:
{
tcpDnsArr = append(tcpDnsArr, *d.(*net.TCPAddr))
break
}
case *net.UDPAddr:
{
udpDnsArr = append(udpDnsArr, *d.(*net.UDPAddr))
break
}
}
}
endpoints, ids := hijackUdpDns(udpDnsArr, s, serverIn)
tcpListeners, dnsServers := hijackTcpDns(tcpDnsArr, s, serverIn)
server := &DNSServer{
resolver: resolver,
stack: s,
udpEndpoints: endpoints,
udpEndpointIDs: ids,
NICID: nicID,
tcpListeners: tcpListeners,
}
server.dnsServers = dnsServers
return server, err
}
func hijackUdpDns(dnsArr []net.UDPAddr, s *stack.Stack, serverIn *dns.Server) ([]*dnsEndpoint, []*stack.TransportEndpointID) {
endpoints := make([]*dnsEndpoint, len(dnsArr))
ids := make([]*stack.TransportEndpointID, len(dnsArr))
for i, dns := range dnsArr {
port := dns.Port
ip := dns.IP
address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)} address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)}
var protocol tcpip.NetworkProtocolNumber var protocol tcpip.NetworkProtocolNumber
if ip.To4() != nil { if ip.To4() != nil {
v4 = true
address.Addr = tcpip.Address(ip.To4()) address.Addr = tcpip.Address(ip.To4())
protocol = ipv4.ProtocolNumber protocol = ipv4.ProtocolNumber
} else { } else {
v4 = false
address.Addr = tcpip.Address(ip.To16()) address.Addr = tcpip.Address(ip.To16())
protocol = ipv6.ProtocolNumber protocol = ipv6.ProtocolNumber
} }
protocolAddr := tcpip.ProtocolAddress{ protocolAddr := tcpip.ProtocolAddress{
Protocol: protocol, Protocol: protocol,
AddressWithPrefix: address.Addr.WithPrefix(), AddressWithPrefix: address.Addr.WithPrefix(),
} }
// netstack will only reassemble IP fragments when its' dest ip address is registered in NIC.endpoints // netstack will only reassemble IP fragments when its' dest ip address is registered in NIC.endpoints
if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
log.Errorln("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) log.Errorln("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
@ -158,10 +218,6 @@ func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.Resolve
address.Addr = "" address.Addr = ""
} }
handler := dns.NewHandler(resolver, mapper)
serverIn := &dns.Server{}
serverIn.SetHandler(handler)
// UDP DNS // UDP DNS
id := &stack.TransportEndpointID{ id := &stack.TransportEndpointID{
LocalAddress: address.Addr, LocalAddress: address.Addr,
@ -190,44 +246,72 @@ func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.Resolve
log.Errorln("Unable to start UDP DNS on tun: %v", tcpiperr.String()) log.Errorln("Unable to start UDP DNS on tun: %v", tcpiperr.String())
} }
// TCP DNS ids[i] = id
endpoints[i] = endpoint
}
return endpoints, ids
}
func hijackTcpDns(dnsArr []net.TCPAddr, s *stack.Stack, serverIn *dns.Server) ([]net.Listener, []*dns.Server) {
tcpListeners := make([]net.Listener, len(dnsArr))
dnsServers := make([]*dns.Server, len(dnsArr))
for i, dnsAddr := range dnsArr {
var tcpListener net.Listener var tcpListener net.Listener
var v4 bool
var err error
port := dnsAddr.Port
ip := dnsAddr.IP
address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)}
if ip.To4() != nil {
address.Addr = tcpip.Address(ip.To4())
v4 = true
} else {
address.Addr = tcpip.Address(ip.To16())
v4 = false
}
if v4 { if v4 {
tcpListener, err = gonet.ListenTCP(s, address, ipv4.ProtocolNumber) tcpListener, err = gonet.ListenTCP(s, address, ipv4.ProtocolNumber)
} else { } else {
tcpListener, err = gonet.ListenTCP(s, address, ipv6.ProtocolNumber) tcpListener, err = gonet.ListenTCP(s, address, ipv6.ProtocolNumber)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("can not listen on tun: %v", err) log.Errorln("can not listen on tun: %v, hijack tcp[%s] failed", err, dnsAddr)
} else {
tcpListeners[i] = tcpListener
server := &D.Server{Listener: &ListenerWrap{
listener: tcpListener,
}, Handler: serverIn}
dnsServer := dns.Server{}
dnsServer.Server = server
go dnsServer.ActivateAndServe()
dnsServers[i] = &dnsServer
} }
server := &DNSServer{
Server: serverIn,
resolver: resolver,
stack: s,
tcpListener: tcpListener,
udpEndpoint: endpoint,
udpEndpointID: id,
NICID: nicID,
} }
server.SetHandler(handler) //
server.Server.Server = &D.Server{Listener: tcpListener, Handler: server} //for _, listener := range tcpListeners {
// server := &D.Server{Listener: listener, Handler: serverIn}
//
// dnsServers = append(dnsServers, &dnsServer)
// go dnsServer.ActivateAndServe()
//}
go func() { return tcpListeners, dnsServers
server.ActivateAndServe()
}()
return server, err
} }
// Stop stop the DNS Server on tun // Stop stop the DNS Server on tun
func (s *DNSServer) Stop() { func (s *DNSServer) Stop() {
// shutdown TCP DNS Server if s == nil {
s.Server.Shutdown() return
// remove TCP endpoint from stack
if s.Listener != nil {
s.Listener.Close()
} }
for i := 0; i < len(s.udpEndpointIDs); i++ {
ep := s.udpEndpoints[i]
id := s.udpEndpointIDs[i]
// remove udp endpoint from stack // remove udp endpoint from stack
s.stack.UnregisterTransportEndpoint( s.stack.UnregisterTransportEndpoint(
[]tcpip.NetworkProtocolNumber{ []tcpip.NetworkProtocolNumber{
@ -235,61 +319,106 @@ func (s *DNSServer) Stop() {
ipv6.ProtocolNumber, ipv6.ProtocolNumber,
}, },
udp.ProtocolNumber, udp.ProtocolNumber,
*s.udpEndpointID, *id,
s.udpEndpoint, ep,
ports.Flags{LoadBalanced: true}, // should match the RegisterTransportEndpoint ports.Flags{LoadBalanced: true}, // should match the RegisterTransportEndpoint
s.NICID) s.NICID)
} }
for _, server := range s.dnsServers {
server.Shutdown()
}
for _, listener := range s.tcpListeners {
listener.Close()
}
}
// DnsHijack return the listening address of DNS Server // DnsHijack return the listening address of DNS Server
func (t *gvisorAdapter) DnsHijack() []string { func (t *gvisorAdapter) DnsHijack() []string {
results := make([]string, len(t.dnsServers)) dnsHijackArr := make([]string, len(t.dnsServer.udpEndpoints))
for i, dnsServer := range t.dnsServers { for _, id := range t.dnsServer.udpEndpointIDs {
id := dnsServer.udpEndpointID dnsHijackArr = append(dnsHijackArr, fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort))
results[i] = fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
} }
return results return dnsHijackArr
} }
func (t *gvisorAdapter) StopAllDNSServer() { func (t *gvisorAdapter) StopDNSServer() {
for _, dnsServer := range t.dnsServers { t.dnsServer.Stop()
dnsServer.Stop()
}
log.Debugln("tun DNS server stoped") log.Debugln("tun DNS server stoped")
t.dnsServers = nil t.dnsServer = nil
} }
// ReCreateDNSServer recreate the DNS Server on tun // ReCreateDNSServer recreate the DNS Server on tun
func (t *gvisorAdapter) ReCreateDNSServer(resolver *dns.Resolver, mapper *dns.ResolverEnhancer, addrs []string) error { func (t *gvisorAdapter) ReCreateDNSServer(resolver *dns.Resolver, mapper *dns.ResolverEnhancer, dnsHijackArr []string) error {
t.StopAllDNSServer() t.StopDNSServer()
if resolver == nil { if resolver == nil {
return fmt.Errorf("failed to create DNS server on tun: resolver not provided") return fmt.Errorf("failed to create DNS server on tun: resolver not provided")
} }
if len(addrs) == 0 { if len(dnsHijackArr) == 0 {
return fmt.Errorf("failed to create DNS server on tun: len(addrs) == 0") return fmt.Errorf("failed to create DNS server on tun: len(addrs) == 0")
} }
for _, addr := range addrs {
var err error var err error
_, port, err := net.SplitHostPort(addr) var addrs []net.Addr
if port == "0" || port == "" || err != nil { for _, addr := range dnsHijackArr {
return nil var (
} addrType string
hostPort string
)
udpAddr, err := net.ResolveUDPAddr("udp", addr) addrType, hostPort, err = Common.SplitNetworkType(addr)
if err != nil { if err != nil {
return err return err
} }
server, err := CreateDNSServer(t.ipstack, resolver, mapper, udpAddr.IP, udpAddr.Port, nicID) var (
host, port string
hasPort bool
)
host, port, hasPort, err = Common.SplitHostPort(hostPort)
if !hasPort {
port = "53"
}
switch addrType {
case "udp", "":
{
var udpDNS *net.UDPAddr
udpDNS, err = net.ResolveUDPAddr("udp", net.JoinHostPort(host, port))
if err != nil { if err != nil {
return err return err
} }
t.dnsServers = append(t.dnsServers, server)
log.Infoln("Tun DNS server listening at: %s, fake ip enabled: %v", addr, mapper.FakeIPEnabled()) addrs = append(addrs, udpDNS)
break
} }
case "tcp":
{
var tcpDNS *net.TCPAddr
tcpDNS, err = net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port))
if err != nil {
return err
}
addrs = append(addrs, tcpDNS)
break
}
default:
err = fmt.Errorf("unspported dns scheme:%s", addrType)
}
}
server, err := CreateDNSServer(t.ipstack, resolver, mapper, addrs, nicID)
if err != nil {
return err
}
t.dnsServer = server
return nil return nil
} }