From 4ab986cccb739f48085acebcad50c1855f2009ff Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sun, 9 Jan 2022 00:35:45 +0800 Subject: [PATCH] [Refactor] gvisor support hijack dns list dns-hijack: - 1.1.1.1 - 8.8.8.8:53 - tcp://1.1.1.1:53 - udp://223.5.5.5 - 10.0.0.1:5353 --- common/net/tcpip.go | 46 ++++ listener/tun/ipstack/gvisor/tun.go | 12 +- listener/tun/ipstack/gvisor/tundns.go | 383 +++++++++++++++++--------- 3 files changed, 308 insertions(+), 133 deletions(-) create mode 100644 common/net/tcpip.go diff --git a/common/net/tcpip.go b/common/net/tcpip.go new file mode 100644 index 00000000..a84e7e4c --- /dev/null +++ b/common/net/tcpip.go @@ -0,0 +1,46 @@ +package net + +import ( + "fmt" + "net" + "strings" +) + +func SplitNetworkType(s string) (string, string, error) { + var ( + shecme string + hostPort string + ) + result := strings.Split(s, "://") + if len(result) == 2 { + shecme = result[0] + hostPort = result[1] + } else if len(result) == 1 { + hostPort = result[0] + } else { + return "", "", fmt.Errorf("tcp/udp style error") + } + + if len(shecme) == 0 { + shecme = "udp" + } + + if shecme != "tcp" && shecme != "udp" { + return "", "", fmt.Errorf("scheme should be tcp:// or udp://") + } else { + return shecme, hostPort, nil + } +} + +func SplitHostPort(s string) (host, port string, hasPort bool, err error) { + temp := s + hasPort = true + + if !strings.Contains(s, ":") && !strings.Contains(s, "]:") { + temp += ":0" + hasPort = false + } + + host, port, err = net.SplitHostPort(temp) + return +} diff --git a/listener/tun/ipstack/gvisor/tun.go b/listener/tun/ipstack/gvisor/tun.go index e550739d..b19e6fa7 100644 --- a/listener/tun/ipstack/gvisor/tun.go +++ b/listener/tun/ipstack/gvisor/tun.go @@ -34,10 +34,10 @@ import ( const nicID tcpip.NICID = 1 type gvisorAdapter struct { - device dev.TunDevice - ipstack *stack.Stack - dnsServers []*DNSServer - udpIn chan<- *inbound.PacketAdapter + device dev.TunDevice + ipstack *stack.Stack + dnsServer *DNSServer + udpIn chan<- *inbound.PacketAdapter stackName string autoRoute bool @@ -47,7 +47,7 @@ type gvisorAdapter struct { writeHandle *channel.NotificationHandle } -// GvisorAdapter create GvisorAdapter +// NewAdapter GvisorAdapter create GvisorAdapter func NewAdapter(device dev.TunDevice, conf config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) { ipstack := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, @@ -132,7 +132,7 @@ func (t *gvisorAdapter) AutoRoute() bool { // Close close the TunAdapter func (t *gvisorAdapter) Close() { - t.StopAllDNSServer() + t.StopDNSServer() if t.ipstack != nil { t.ipstack.Close() } diff --git a/listener/tun/ipstack/gvisor/tundns.go b/listener/tun/ipstack/gvisor/tundns.go index 5724aa70..7fba6bbe 100644 --- a/listener/tun/ipstack/gvisor/tundns.go +++ b/listener/tun/ipstack/gvisor/tundns.go @@ -2,13 +2,14 @@ package gvisor import ( "fmt" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "net" + Common "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/log" D "github.com/miekg/dns" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -23,15 +24,33 @@ var ( ipv6Zero = tcpip.Address(net.IPv6zero.To16()) ) +type ListenerWrap struct { + net.Listener + listener net.Listener +} + +func (l *ListenerWrap) Accept() (conn net.Conn, err error) { + conn, err = l.listener.Accept() + log.Debugln("[DNS] hijack tcp:%s", l.Addr()) + return +} + +func (l *ListenerWrap) Close() error { + return l.listener.Close() +} + +func (l *ListenerWrap) Addr() net.Addr { + return l.listener.Addr() +} + // DNSServer is DNS Server listening on tun devcice type DNSServer struct { - *dns.Server - resolver *dns.Resolver - - stack *stack.Stack - tcpListener net.Listener - udpEndpoint *dnsEndpoint - udpEndpointID *stack.TransportEndpointID + dnsServers []*dns.Server + tcpListeners []net.Listener + resolver *dns.Resolver + stack *stack.Stack + udpEndpoints []*dnsEndpoint + udpEndpointIDs []*stack.TransportEndpointID tcpip.NICID } @@ -66,6 +85,7 @@ func (e *dnsEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pack var msg D.Msg msg.Unpack(pkt.Data().AsRange().ToOwnedView()) writer := dnsResponseWriter{s: e.stack, pkt: pkt, id: id} + log.Debugln("[DNS] hijack udp:%s:%d", id.LocalAddress.String(), id.LocalPort) go e.server.ServeDNS(&writer, &msg) } @@ -129,167 +149,276 @@ func (w *dnsResponseWriter) Close() error { } // CreateDNSServer create a dns server on given netstack -func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.ResolverEnhancer, ip net.IP, port int, nicID tcpip.NICID) (*DNSServer, error) { - var v4 bool +func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.ResolverEnhancer, dnsHijack []net.Addr, nicID tcpip.NICID) (*DNSServer, error) { var err error - - address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)} - var protocol tcpip.NetworkProtocolNumber - if ip.To4() != nil { - v4 = true - address.Addr = tcpip.Address(ip.To4()) - protocol = ipv4.ProtocolNumber - - } else { - v4 = false - address.Addr = tcpip.Address(ip.To16()) - protocol = ipv6.ProtocolNumber - } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: address.Addr.WithPrefix(), - } - // netstack will only reassemble IP fragments when its' dest ip address is registered in NIC.endpoints - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - log.Errorln("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - - if address.Addr == ipv4Zero || address.Addr == ipv6Zero { - address.Addr = "" - } - handler := dns.NewHandler(resolver, mapper) serverIn := &dns.Server{} serverIn.SetHandler(handler) - - // UDP DNS - id := &stack.TransportEndpointID{ - LocalAddress: address.Addr, - LocalPort: uint16(port), - RemotePort: 0, - RemoteAddress: "", - } - - // TransportEndpoint for DNS - endpoint := &dnsEndpoint{ - stack: s, - uniqueID: s.UniqueID(), - server: serverIn, - } - - if tcpiperr := s.RegisterTransportEndpoint( - []tcpip.NetworkProtocolNumber{ - ipv4.ProtocolNumber, - ipv6.ProtocolNumber, - }, - udp.ProtocolNumber, - *id, - endpoint, - ports.Flags{LoadBalanced: true}, // it's actually the SO_REUSEPORT. Not sure it take effect. - nicID); tcpiperr != nil { - log.Errorln("Unable to start UDP DNS on tun: %v", tcpiperr.String()) - } - - // TCP DNS - var tcpListener net.Listener - if v4 { - tcpListener, err = gonet.ListenTCP(s, address, ipv4.ProtocolNumber) - } else { - tcpListener, err = gonet.ListenTCP(s, address, ipv6.ProtocolNumber) - } - if err != nil { - return nil, fmt.Errorf("can not listen on tun: %v", err) + tcpDnsArr := make([]net.TCPAddr, 0, len(dnsHijack)) + udpDnsArr := make([]net.UDPAddr, 0, len(dnsHijack)) + for _, d := range dnsHijack { + switch d.(type) { + case *net.TCPAddr: + { + tcpDnsArr = append(tcpDnsArr, *d.(*net.TCPAddr)) + break + } + case *net.UDPAddr: + { + udpDnsArr = append(udpDnsArr, *d.(*net.UDPAddr)) + break + } + } } + endpoints, ids := hijackUdpDns(udpDnsArr, s, serverIn) + tcpListeners, dnsServers := hijackTcpDns(tcpDnsArr, s, serverIn) server := &DNSServer{ - Server: serverIn, - resolver: resolver, - stack: s, - tcpListener: tcpListener, - udpEndpoint: endpoint, - udpEndpointID: id, - NICID: nicID, + resolver: resolver, + stack: s, + udpEndpoints: endpoints, + udpEndpointIDs: ids, + NICID: nicID, + tcpListeners: tcpListeners, } - server.SetHandler(handler) - server.Server.Server = &D.Server{Listener: tcpListener, Handler: server} - go func() { - server.ActivateAndServe() - }() + server.dnsServers = dnsServers return server, err } +func hijackUdpDns(dnsArr []net.UDPAddr, s *stack.Stack, serverIn *dns.Server) ([]*dnsEndpoint, []*stack.TransportEndpointID) { + endpoints := make([]*dnsEndpoint, len(dnsArr)) + ids := make([]*stack.TransportEndpointID, len(dnsArr)) + for i, dns := range dnsArr { + port := dns.Port + ip := dns.IP + address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)} + var protocol tcpip.NetworkProtocolNumber + if ip.To4() != nil { + address.Addr = tcpip.Address(ip.To4()) + protocol = ipv4.ProtocolNumber + + } else { + address.Addr = tcpip.Address(ip.To16()) + protocol = ipv6.ProtocolNumber + } + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: address.Addr.WithPrefix(), + } + + // netstack will only reassemble IP fragments when its' dest ip address is registered in NIC.endpoints + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + log.Errorln("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) + } + + if address.Addr == ipv4Zero || address.Addr == ipv6Zero { + address.Addr = "" + } + + // UDP DNS + id := &stack.TransportEndpointID{ + LocalAddress: address.Addr, + LocalPort: uint16(port), + RemotePort: 0, + RemoteAddress: "", + } + + // TransportEndpoint for DNS + endpoint := &dnsEndpoint{ + stack: s, + uniqueID: s.UniqueID(), + server: serverIn, + } + + if tcpiperr := s.RegisterTransportEndpoint( + []tcpip.NetworkProtocolNumber{ + ipv4.ProtocolNumber, + ipv6.ProtocolNumber, + }, + udp.ProtocolNumber, + *id, + endpoint, + ports.Flags{LoadBalanced: true}, // it's actually the SO_REUSEPORT. Not sure it take effect. + nicID); tcpiperr != nil { + log.Errorln("Unable to start UDP DNS on tun: %v", tcpiperr.String()) + } + + ids[i] = id + endpoints[i] = endpoint + } + + return endpoints, ids +} + +func hijackTcpDns(dnsArr []net.TCPAddr, s *stack.Stack, serverIn *dns.Server) ([]net.Listener, []*dns.Server) { + tcpListeners := make([]net.Listener, len(dnsArr)) + dnsServers := make([]*dns.Server, len(dnsArr)) + + for i, dnsAddr := range dnsArr { + var tcpListener net.Listener + var v4 bool + var err error + port := dnsAddr.Port + ip := dnsAddr.IP + address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)} + if ip.To4() != nil { + address.Addr = tcpip.Address(ip.To4()) + v4 = true + } else { + address.Addr = tcpip.Address(ip.To16()) + v4 = false + } + + if v4 { + tcpListener, err = gonet.ListenTCP(s, address, ipv4.ProtocolNumber) + } else { + tcpListener, err = gonet.ListenTCP(s, address, ipv6.ProtocolNumber) + } + + if err != nil { + log.Errorln("can not listen on tun: %v, hijack tcp[%s] failed", err, dnsAddr) + } else { + tcpListeners[i] = tcpListener + server := &D.Server{Listener: &ListenerWrap{ + listener: tcpListener, + }, Handler: serverIn} + dnsServer := dns.Server{} + dnsServer.Server = server + go dnsServer.ActivateAndServe() + dnsServers[i] = &dnsServer + } + + } + // + //for _, listener := range tcpListeners { + // server := &D.Server{Listener: listener, Handler: serverIn} + // + // dnsServers = append(dnsServers, &dnsServer) + // go dnsServer.ActivateAndServe() + //} + + return tcpListeners, dnsServers +} + // Stop stop the DNS Server on tun func (s *DNSServer) Stop() { - // shutdown TCP DNS Server - s.Server.Shutdown() - // remove TCP endpoint from stack - if s.Listener != nil { - s.Listener.Close() + if s == nil { + return + } + + for i := 0; i < len(s.udpEndpointIDs); i++ { + ep := s.udpEndpoints[i] + id := s.udpEndpointIDs[i] + // remove udp endpoint from stack + s.stack.UnregisterTransportEndpoint( + []tcpip.NetworkProtocolNumber{ + ipv4.ProtocolNumber, + ipv6.ProtocolNumber, + }, + udp.ProtocolNumber, + *id, + ep, + ports.Flags{LoadBalanced: true}, // should match the RegisterTransportEndpoint + s.NICID) + } + + for _, server := range s.dnsServers { + server.Shutdown() + } + + for _, listener := range s.tcpListeners { + listener.Close() } - // remove udp endpoint from stack - s.stack.UnregisterTransportEndpoint( - []tcpip.NetworkProtocolNumber{ - ipv4.ProtocolNumber, - ipv6.ProtocolNumber, - }, - udp.ProtocolNumber, - *s.udpEndpointID, - s.udpEndpoint, - ports.Flags{LoadBalanced: true}, // should match the RegisterTransportEndpoint - s.NICID) } // DnsHijack return the listening address of DNS Server func (t *gvisorAdapter) DnsHijack() []string { - results := make([]string, len(t.dnsServers)) - for i, dnsServer := range t.dnsServers { - id := dnsServer.udpEndpointID - results[i] = fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort) + dnsHijackArr := make([]string, len(t.dnsServer.udpEndpoints)) + for _, id := range t.dnsServer.udpEndpointIDs { + dnsHijackArr = append(dnsHijackArr, fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)) } - return results + return dnsHijackArr } -func (t *gvisorAdapter) StopAllDNSServer() { - for _, dnsServer := range t.dnsServers { - dnsServer.Stop() - } +func (t *gvisorAdapter) StopDNSServer() { + t.dnsServer.Stop() log.Debugln("tun DNS server stoped") - t.dnsServers = nil + t.dnsServer = nil } // ReCreateDNSServer recreate the DNS Server on tun -func (t *gvisorAdapter) ReCreateDNSServer(resolver *dns.Resolver, mapper *dns.ResolverEnhancer, addrs []string) error { - t.StopAllDNSServer() +func (t *gvisorAdapter) ReCreateDNSServer(resolver *dns.Resolver, mapper *dns.ResolverEnhancer, dnsHijackArr []string) error { + t.StopDNSServer() if resolver == nil { return fmt.Errorf("failed to create DNS server on tun: resolver not provided") } - if len(addrs) == 0 { + if len(dnsHijackArr) == 0 { return fmt.Errorf("failed to create DNS server on tun: len(addrs) == 0") } - for _, addr := range addrs { - var err error - _, port, err := net.SplitHostPort(addr) - if port == "0" || port == "" || err != nil { - return nil - } + var err error + var addrs []net.Addr + for _, addr := range dnsHijackArr { + var ( + addrType string + hostPort string + ) - udpAddr, err := net.ResolveUDPAddr("udp", addr) + addrType, hostPort, err = Common.SplitNetworkType(addr) if err != nil { return err } - server, err := CreateDNSServer(t.ipstack, resolver, mapper, udpAddr.IP, udpAddr.Port, nicID) - if err != nil { - return err + var ( + host, port string + hasPort bool + ) + + host, port, hasPort, err = Common.SplitHostPort(hostPort) + if !hasPort { + port = "53" } - t.dnsServers = append(t.dnsServers, server) - log.Infoln("Tun DNS server listening at: %s, fake ip enabled: %v", addr, mapper.FakeIPEnabled()) + + switch addrType { + case "udp", "": + { + var udpDNS *net.UDPAddr + udpDNS, err = net.ResolveUDPAddr("udp", net.JoinHostPort(host, port)) + if err != nil { + return err + } + + addrs = append(addrs, udpDNS) + break + } + case "tcp": + { + var tcpDNS *net.TCPAddr + tcpDNS, err = net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port)) + if err != nil { + return err + } + + addrs = append(addrs, tcpDNS) + break + } + default: + err = fmt.Errorf("unspported dns scheme:%s", addrType) + } + } + server, err := CreateDNSServer(t.ipstack, resolver, mapper, addrs, nicID) + if err != nil { + return err + } + + t.dnsServer = server + return nil }