[Refactor] gvisor support hijack dns list

dns-hijack:
 - 1.1.1.1
 - 8.8.8.8:53
 - tcp://1.1.1.1:53
 - udp://223.5.5.5
 - 10.0.0.1:5353
This commit is contained in:
gVisor bot 2022-01-09 00:35:45 +08:00
parent 1c1eb6bdfb
commit 0a96994452
3 changed files with 308 additions and 133 deletions

46
common/net/tcpip.go Normal file
View file

@ -0,0 +1,46 @@
package net
import (
"fmt"
"net"
"strings"
)
func SplitNetworkType(s string) (string, string, error) {
var (
shecme string
hostPort string
)
result := strings.Split(s, "://")
if len(result) == 2 {
shecme = result[0]
hostPort = result[1]
} else if len(result) == 1 {
hostPort = result[0]
} else {
return "", "", fmt.Errorf("tcp/udp style error")
}
if len(shecme) == 0 {
shecme = "udp"
}
if shecme != "tcp" && shecme != "udp" {
return "", "", fmt.Errorf("scheme should be tcp:// or udp://")
} else {
return shecme, hostPort, nil
}
}
func SplitHostPort(s string) (host, port string, hasPort bool, err error) {
temp := s
hasPort = true
if !strings.Contains(s, ":") && !strings.Contains(s, "]:") {
temp += ":0"
hasPort = false
}
host, port, err = net.SplitHostPort(temp)
return
}

View file

@ -34,10 +34,10 @@ import (
const nicID tcpip.NICID = 1
type gvisorAdapter struct {
device dev.TunDevice
ipstack *stack.Stack
dnsServers []*DNSServer
udpIn chan<- *inbound.PacketAdapter
device dev.TunDevice
ipstack *stack.Stack
dnsServer *DNSServer
udpIn chan<- *inbound.PacketAdapter
stackName string
autoRoute bool
@ -47,7 +47,7 @@ type gvisorAdapter struct {
writeHandle *channel.NotificationHandle
}
// GvisorAdapter create GvisorAdapter
// NewAdapter GvisorAdapter create GvisorAdapter
func NewAdapter(device dev.TunDevice, conf config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) {
ipstack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
@ -132,7 +132,7 @@ func (t *gvisorAdapter) AutoRoute() bool {
// Close close the TunAdapter
func (t *gvisorAdapter) Close() {
t.StopAllDNSServer()
t.StopDNSServer()
if t.ipstack != nil {
t.ipstack.Close()
}

View file

@ -2,13 +2,14 @@ package gvisor
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"net"
Common "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/dns"
"github.com/Dreamacro/clash/log"
D "github.com/miekg/dns"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@ -23,15 +24,33 @@ var (
ipv6Zero = tcpip.Address(net.IPv6zero.To16())
)
type ListenerWrap struct {
net.Listener
listener net.Listener
}
func (l *ListenerWrap) Accept() (conn net.Conn, err error) {
conn, err = l.listener.Accept()
log.Debugln("[DNS] hijack tcp:%s", l.Addr())
return
}
func (l *ListenerWrap) Close() error {
return l.listener.Close()
}
func (l *ListenerWrap) Addr() net.Addr {
return l.listener.Addr()
}
// DNSServer is DNS Server listening on tun devcice
type DNSServer struct {
*dns.Server
resolver *dns.Resolver
stack *stack.Stack
tcpListener net.Listener
udpEndpoint *dnsEndpoint
udpEndpointID *stack.TransportEndpointID
dnsServers []*dns.Server
tcpListeners []net.Listener
resolver *dns.Resolver
stack *stack.Stack
udpEndpoints []*dnsEndpoint
udpEndpointIDs []*stack.TransportEndpointID
tcpip.NICID
}
@ -66,6 +85,7 @@ func (e *dnsEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pack
var msg D.Msg
msg.Unpack(pkt.Data().AsRange().ToOwnedView())
writer := dnsResponseWriter{s: e.stack, pkt: pkt, id: id}
log.Debugln("[DNS] hijack udp:%s:%d", id.LocalAddress.String(), id.LocalPort)
go e.server.ServeDNS(&writer, &msg)
}
@ -129,167 +149,276 @@ func (w *dnsResponseWriter) Close() error {
}
// CreateDNSServer create a dns server on given netstack
func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.ResolverEnhancer, ip net.IP, port int, nicID tcpip.NICID) (*DNSServer, error) {
var v4 bool
func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.ResolverEnhancer, dnsHijack []net.Addr, nicID tcpip.NICID) (*DNSServer, error) {
var err error
address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)}
var protocol tcpip.NetworkProtocolNumber
if ip.To4() != nil {
v4 = true
address.Addr = tcpip.Address(ip.To4())
protocol = ipv4.ProtocolNumber
} else {
v4 = false
address.Addr = tcpip.Address(ip.To16())
protocol = ipv6.ProtocolNumber
}
protocolAddr := tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: address.Addr.WithPrefix(),
}
// netstack will only reassemble IP fragments when its' dest ip address is registered in NIC.endpoints
if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
log.Errorln("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if address.Addr == ipv4Zero || address.Addr == ipv6Zero {
address.Addr = ""
}
handler := dns.NewHandler(resolver, mapper)
serverIn := &dns.Server{}
serverIn.SetHandler(handler)
// UDP DNS
id := &stack.TransportEndpointID{
LocalAddress: address.Addr,
LocalPort: uint16(port),
RemotePort: 0,
RemoteAddress: "",
}
// TransportEndpoint for DNS
endpoint := &dnsEndpoint{
stack: s,
uniqueID: s.UniqueID(),
server: serverIn,
}
if tcpiperr := s.RegisterTransportEndpoint(
[]tcpip.NetworkProtocolNumber{
ipv4.ProtocolNumber,
ipv6.ProtocolNumber,
},
udp.ProtocolNumber,
*id,
endpoint,
ports.Flags{LoadBalanced: true}, // it's actually the SO_REUSEPORT. Not sure it take effect.
nicID); tcpiperr != nil {
log.Errorln("Unable to start UDP DNS on tun: %v", tcpiperr.String())
}
// TCP DNS
var tcpListener net.Listener
if v4 {
tcpListener, err = gonet.ListenTCP(s, address, ipv4.ProtocolNumber)
} else {
tcpListener, err = gonet.ListenTCP(s, address, ipv6.ProtocolNumber)
}
if err != nil {
return nil, fmt.Errorf("can not listen on tun: %v", err)
tcpDnsArr := make([]net.TCPAddr, 0, len(dnsHijack))
udpDnsArr := make([]net.UDPAddr, 0, len(dnsHijack))
for _, d := range dnsHijack {
switch d.(type) {
case *net.TCPAddr:
{
tcpDnsArr = append(tcpDnsArr, *d.(*net.TCPAddr))
break
}
case *net.UDPAddr:
{
udpDnsArr = append(udpDnsArr, *d.(*net.UDPAddr))
break
}
}
}
endpoints, ids := hijackUdpDns(udpDnsArr, s, serverIn)
tcpListeners, dnsServers := hijackTcpDns(tcpDnsArr, s, serverIn)
server := &DNSServer{
Server: serverIn,
resolver: resolver,
stack: s,
tcpListener: tcpListener,
udpEndpoint: endpoint,
udpEndpointID: id,
NICID: nicID,
resolver: resolver,
stack: s,
udpEndpoints: endpoints,
udpEndpointIDs: ids,
NICID: nicID,
tcpListeners: tcpListeners,
}
server.SetHandler(handler)
server.Server.Server = &D.Server{Listener: tcpListener, Handler: server}
go func() {
server.ActivateAndServe()
}()
server.dnsServers = dnsServers
return server, err
}
func hijackUdpDns(dnsArr []net.UDPAddr, s *stack.Stack, serverIn *dns.Server) ([]*dnsEndpoint, []*stack.TransportEndpointID) {
endpoints := make([]*dnsEndpoint, len(dnsArr))
ids := make([]*stack.TransportEndpointID, len(dnsArr))
for i, dns := range dnsArr {
port := dns.Port
ip := dns.IP
address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)}
var protocol tcpip.NetworkProtocolNumber
if ip.To4() != nil {
address.Addr = tcpip.Address(ip.To4())
protocol = ipv4.ProtocolNumber
} else {
address.Addr = tcpip.Address(ip.To16())
protocol = ipv6.ProtocolNumber
}
protocolAddr := tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: address.Addr.WithPrefix(),
}
// netstack will only reassemble IP fragments when its' dest ip address is registered in NIC.endpoints
if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
log.Errorln("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if address.Addr == ipv4Zero || address.Addr == ipv6Zero {
address.Addr = ""
}
// UDP DNS
id := &stack.TransportEndpointID{
LocalAddress: address.Addr,
LocalPort: uint16(port),
RemotePort: 0,
RemoteAddress: "",
}
// TransportEndpoint for DNS
endpoint := &dnsEndpoint{
stack: s,
uniqueID: s.UniqueID(),
server: serverIn,
}
if tcpiperr := s.RegisterTransportEndpoint(
[]tcpip.NetworkProtocolNumber{
ipv4.ProtocolNumber,
ipv6.ProtocolNumber,
},
udp.ProtocolNumber,
*id,
endpoint,
ports.Flags{LoadBalanced: true}, // it's actually the SO_REUSEPORT. Not sure it take effect.
nicID); tcpiperr != nil {
log.Errorln("Unable to start UDP DNS on tun: %v", tcpiperr.String())
}
ids[i] = id
endpoints[i] = endpoint
}
return endpoints, ids
}
func hijackTcpDns(dnsArr []net.TCPAddr, s *stack.Stack, serverIn *dns.Server) ([]net.Listener, []*dns.Server) {
tcpListeners := make([]net.Listener, len(dnsArr))
dnsServers := make([]*dns.Server, len(dnsArr))
for i, dnsAddr := range dnsArr {
var tcpListener net.Listener
var v4 bool
var err error
port := dnsAddr.Port
ip := dnsAddr.IP
address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)}
if ip.To4() != nil {
address.Addr = tcpip.Address(ip.To4())
v4 = true
} else {
address.Addr = tcpip.Address(ip.To16())
v4 = false
}
if v4 {
tcpListener, err = gonet.ListenTCP(s, address, ipv4.ProtocolNumber)
} else {
tcpListener, err = gonet.ListenTCP(s, address, ipv6.ProtocolNumber)
}
if err != nil {
log.Errorln("can not listen on tun: %v, hijack tcp[%s] failed", err, dnsAddr)
} else {
tcpListeners[i] = tcpListener
server := &D.Server{Listener: &ListenerWrap{
listener: tcpListener,
}, Handler: serverIn}
dnsServer := dns.Server{}
dnsServer.Server = server
go dnsServer.ActivateAndServe()
dnsServers[i] = &dnsServer
}
}
//
//for _, listener := range tcpListeners {
// server := &D.Server{Listener: listener, Handler: serverIn}
//
// dnsServers = append(dnsServers, &dnsServer)
// go dnsServer.ActivateAndServe()
//}
return tcpListeners, dnsServers
}
// Stop stop the DNS Server on tun
func (s *DNSServer) Stop() {
// shutdown TCP DNS Server
s.Server.Shutdown()
// remove TCP endpoint from stack
if s.Listener != nil {
s.Listener.Close()
if s == nil {
return
}
for i := 0; i < len(s.udpEndpointIDs); i++ {
ep := s.udpEndpoints[i]
id := s.udpEndpointIDs[i]
// remove udp endpoint from stack
s.stack.UnregisterTransportEndpoint(
[]tcpip.NetworkProtocolNumber{
ipv4.ProtocolNumber,
ipv6.ProtocolNumber,
},
udp.ProtocolNumber,
*id,
ep,
ports.Flags{LoadBalanced: true}, // should match the RegisterTransportEndpoint
s.NICID)
}
for _, server := range s.dnsServers {
server.Shutdown()
}
for _, listener := range s.tcpListeners {
listener.Close()
}
// remove udp endpoint from stack
s.stack.UnregisterTransportEndpoint(
[]tcpip.NetworkProtocolNumber{
ipv4.ProtocolNumber,
ipv6.ProtocolNumber,
},
udp.ProtocolNumber,
*s.udpEndpointID,
s.udpEndpoint,
ports.Flags{LoadBalanced: true}, // should match the RegisterTransportEndpoint
s.NICID)
}
// DnsHijack return the listening address of DNS Server
func (t *gvisorAdapter) DnsHijack() []string {
results := make([]string, len(t.dnsServers))
for i, dnsServer := range t.dnsServers {
id := dnsServer.udpEndpointID
results[i] = fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
dnsHijackArr := make([]string, len(t.dnsServer.udpEndpoints))
for _, id := range t.dnsServer.udpEndpointIDs {
dnsHijackArr = append(dnsHijackArr, fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort))
}
return results
return dnsHijackArr
}
func (t *gvisorAdapter) StopAllDNSServer() {
for _, dnsServer := range t.dnsServers {
dnsServer.Stop()
}
func (t *gvisorAdapter) StopDNSServer() {
t.dnsServer.Stop()
log.Debugln("tun DNS server stoped")
t.dnsServers = nil
t.dnsServer = nil
}
// ReCreateDNSServer recreate the DNS Server on tun
func (t *gvisorAdapter) ReCreateDNSServer(resolver *dns.Resolver, mapper *dns.ResolverEnhancer, addrs []string) error {
t.StopAllDNSServer()
func (t *gvisorAdapter) ReCreateDNSServer(resolver *dns.Resolver, mapper *dns.ResolverEnhancer, dnsHijackArr []string) error {
t.StopDNSServer()
if resolver == nil {
return fmt.Errorf("failed to create DNS server on tun: resolver not provided")
}
if len(addrs) == 0 {
if len(dnsHijackArr) == 0 {
return fmt.Errorf("failed to create DNS server on tun: len(addrs) == 0")
}
for _, addr := range addrs {
var err error
_, port, err := net.SplitHostPort(addr)
if port == "0" || port == "" || err != nil {
return nil
}
var err error
var addrs []net.Addr
for _, addr := range dnsHijackArr {
var (
addrType string
hostPort string
)
udpAddr, err := net.ResolveUDPAddr("udp", addr)
addrType, hostPort, err = Common.SplitNetworkType(addr)
if err != nil {
return err
}
server, err := CreateDNSServer(t.ipstack, resolver, mapper, udpAddr.IP, udpAddr.Port, nicID)
if err != nil {
return err
var (
host, port string
hasPort bool
)
host, port, hasPort, err = Common.SplitHostPort(hostPort)
if !hasPort {
port = "53"
}
t.dnsServers = append(t.dnsServers, server)
log.Infoln("Tun DNS server listening at: %s, fake ip enabled: %v", addr, mapper.FakeIPEnabled())
switch addrType {
case "udp", "":
{
var udpDNS *net.UDPAddr
udpDNS, err = net.ResolveUDPAddr("udp", net.JoinHostPort(host, port))
if err != nil {
return err
}
addrs = append(addrs, udpDNS)
break
}
case "tcp":
{
var tcpDNS *net.TCPAddr
tcpDNS, err = net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port))
if err != nil {
return err
}
addrs = append(addrs, tcpDNS)
break
}
default:
err = fmt.Errorf("unspported dns scheme:%s", addrType)
}
}
server, err := CreateDNSServer(t.ipstack, resolver, mapper, addrs, nicID)
if err != nil {
return err
}
t.dnsServer = server
return nil
}