diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 29096951..887c473a 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -330,7 +330,9 @@ func updateIPTables(cfg *config.Config) { log.Infoln("[IPTABLES] Setting iptables completed") } -func Cleanup() { +func Shutdown() { P.Cleanup() tproxy.CleanupTProxyIPTables() + + log.Warnln("Clash shutting down") } diff --git a/hub/hub.go b/hub/hub.go index cde0bb57..471fdb5e 100644 --- a/hub/hub.go +++ b/hub/hub.go @@ -48,7 +48,3 @@ func Parse(options ...Option) error { executor.ApplyConfig(cfg, true) return nil } - -func Cleanup() { - executor.Cleanup() -} diff --git a/listener/tun/device/device.go b/listener/tun/device/device.go index 70115cbd..73b03dee 100644 --- a/listener/tun/device/device.go +++ b/listener/tun/device/device.go @@ -29,4 +29,7 @@ type Device interface { // UseIOBased work for other ip stack UseIOBased() error + + // Wait waits for the device to close. + Wait() } diff --git a/listener/tun/device/iobased/endpoint.go b/listener/tun/device/iobased/endpoint.go index a187491e..c0942d10 100644 --- a/listener/tun/device/iobased/endpoint.go +++ b/listener/tun/device/iobased/endpoint.go @@ -103,7 +103,7 @@ func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) { case header.IPv6Version: e.InjectInbound(header.IPv6ProtocolNumber, pkt) } - pkt.DecRef() /* release */ + pkt.DecRef() } } diff --git a/listener/tun/device/tun/tun_wireguard.go b/listener/tun/device/tun/tun_wireguard.go index 50db6511..30398b55 100644 --- a/listener/tun/device/tun/tun_wireguard.go +++ b/listener/tun/device/tun/tun_wireguard.go @@ -106,6 +106,9 @@ func (t *TUN) Write(packet []byte) (int, error) { } func (t *TUN) Close() error { + if t.Endpoint != nil { + t.Endpoint.Close() + } return t.nt.Close() } diff --git a/listener/tun/ipstack/gvisor/adapter/handler.go b/listener/tun/ipstack/gvisor/adapter/handler.go index 2878b713..715f6636 100644 --- a/listener/tun/ipstack/gvisor/adapter/handler.go +++ b/listener/tun/ipstack/gvisor/adapter/handler.go @@ -3,6 +3,12 @@ package adapter // Handler is a TCP/UDP connection handler that implements // HandleTCPConn and HandleUDPConn methods. type Handler interface { - HandleTCPConn(TCPConn) - HandleUDPConn(UDPConn) + HandleTCP(TCPConn) + HandleUDP(UDPConn) } + +// TCPHandleFunc handles incoming TCP connection. +type TCPHandleFunc func(TCPConn) + +// UDPHandleFunc handles incoming UDP connection. +type UDPHandleFunc func(UDPConn) diff --git a/listener/tun/ipstack/gvisor/handler.go b/listener/tun/ipstack/gvisor/handler.go index 76aab2f5..6365234e 100644 --- a/listener/tun/ipstack/gvisor/handler.go +++ b/listener/tun/ipstack/gvisor/handler.go @@ -24,7 +24,7 @@ type GVHandler struct { UDPIn chan<- *inbound.PacketAdapter } -func (gh *GVHandler) HandleTCPConn(tunConn adapter.TCPConn) { +func (gh *GVHandler) HandleTCP(tunConn adapter.TCPConn) { id := tunConn.ID() rAddr := &net.UDPAddr{ @@ -77,7 +77,7 @@ func (gh *GVHandler) HandleTCPConn(tunConn adapter.TCPConn) { gh.TCPIn <- inbound.NewSocket(socks5.ParseAddrToSocksAddr(rAddr), tunConn, C.TUN) } -func (gh *GVHandler) HandleUDPConn(tunConn adapter.UDPConn) { +func (gh *GVHandler) HandleUDP(tunConn adapter.UDPConn) { id := tunConn.ID() rAddr := &net.UDPAddr{ diff --git a/listener/tun/ipstack/gvisor/nic.go b/listener/tun/ipstack/gvisor/nic.go index fb8ac1a2..0ca96778 100644 --- a/listener/tun/ipstack/gvisor/nic.go +++ b/listener/tun/ipstack/gvisor/nic.go @@ -3,14 +3,13 @@ package gvisor import ( "fmt" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( - // defaultNICID is the ID of default NIC used by DefaultStack. - defaultNICID tcpip.NICID = 0x01 - // nicPromiscuousModeEnabled is the value used by stack to enable // or disable NIC's promiscuous mode. nicPromiscuousModeEnabled = true @@ -21,9 +20,9 @@ const ( ) // withCreatingNIC creates NIC for stack. -func withCreatingNIC(ep stack.LinkEndpoint) Option { - return func(s *gvStack) error { - if err := s.CreateNICWithOptions(s.nicID, ep, +func withCreatingNIC(nicID tcpip.NICID, ep stack.LinkEndpoint) option.Option { + return func(s *stack.Stack) error { + if err := s.CreateNICWithOptions(nicID, ep, stack.NICOptions{ Disabled: false, // If no queueing discipline was specified @@ -37,21 +36,21 @@ func withCreatingNIC(ep stack.LinkEndpoint) Option { } } -// withPromiscuousMode sets promiscuous mode in the given NIC. -func withPromiscuousMode(v bool) Option { - return func(s *gvStack) error { - if err := s.SetPromiscuousMode(s.nicID, v); err != nil { +// withPromiscuousMode sets promiscuous mode in the given NICs. +func withPromiscuousMode(nicID tcpip.NICID, v bool) option.Option { + return func(s *stack.Stack) error { + if err := s.SetPromiscuousMode(nicID, v); err != nil { return fmt.Errorf("set promiscuous mode: %s", err) } return nil } } -// withSpoofing sets address spoofing in the given NIC, allowing +// withSpoofing sets address spoofing in the given NICs, allowing // endpoints to bind to any address in the NIC. -func withSpoofing(v bool) Option { - return func(s *gvStack) error { - if err := s.SetSpoofing(s.nicID, v); err != nil { +func withSpoofing(nicID tcpip.NICID, v bool) option.Option { + return func(s *stack.Stack) error { + if err := s.SetSpoofing(nicID, v); err != nil { return fmt.Errorf("set spoofing: %s", err) } return nil diff --git a/listener/tun/ipstack/gvisor/opts.go b/listener/tun/ipstack/gvisor/option/option.go similarity index 93% rename from listener/tun/ipstack/gvisor/opts.go rename to listener/tun/ipstack/gvisor/option/option.go index 7fd5a65b..34508e0d 100644 --- a/listener/tun/ipstack/gvisor/opts.go +++ b/listener/tun/ipstack/gvisor/option/option.go @@ -1,4 +1,4 @@ -package gvisor +package option import ( "fmt" @@ -7,6 +7,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) @@ -56,11 +57,11 @@ const ( tcpDefaultBufferSize = 212 << 10 // 212 KiB ) -type Option func(*gvStack) error +type Option func(*stack.Stack) error // WithDefault sets all default values for stack. func WithDefault() Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opts := []Option{ WithDefaultTTL(defaultTimeToLive), WithForwarding(ipForwardingEnabled), @@ -110,7 +111,7 @@ func WithDefault() Option { // WithDefaultTTL sets the default TTL used by stack. func WithDefaultTTL(ttl uint8) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.DefaultTTLOption(ttl) if err := s.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set ipv4 default TTL: %s", err) @@ -124,7 +125,7 @@ func WithDefaultTTL(ttl uint8) Option { // WithForwarding sets packet forwarding between NICs for IPv4 & IPv6. func WithForwarding(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, v); err != nil { return fmt.Errorf("set ipv4 forwarding: %s", err) } @@ -138,7 +139,7 @@ func WithForwarding(v bool) Option { // WithICMPBurst sets the number of ICMP messages that can be sent // in a single burst. func WithICMPBurst(burst int) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { s.SetICMPBurst(burst) return nil } @@ -147,7 +148,7 @@ func WithICMPBurst(burst int) Option { // WithICMPLimit sets the maximum number of ICMP messages permitted // by rate limiter. func WithICMPLimit(limit rate.Limit) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { s.SetICMPLimit(limit) return nil } @@ -155,7 +156,7 @@ func WithICMPLimit(limit rate.Limit) Option { // WithTCPBufferSizeRange sets the receive and send buffer size range for TCP. func WithTCPBufferSizeRange(a, b, c int) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c} if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil { return fmt.Errorf("set TCP receive buffer size range: %s", err) @@ -170,7 +171,7 @@ func WithTCPBufferSizeRange(a, b, c int) Option { // WithTCPCongestionControl sets the current congestion control algorithm. func WithTCPCongestionControl(cc string) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.CongestionControlOption(cc) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP congestion control algorithm: %s", err) @@ -181,7 +182,7 @@ func WithTCPCongestionControl(cc string) Option { // WithTCPDelay enables or disables Nagle's algorithm in TCP. func WithTCPDelay(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.TCPDelayEnabled(v) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP delay: %s", err) @@ -192,7 +193,7 @@ func WithTCPDelay(v bool) Option { // WithTCPModerateReceiveBuffer sets receive buffer moderation for TCP. func WithTCPModerateReceiveBuffer(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.TCPModerateReceiveBufferOption(v) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP moderate receive buffer: %s", err) @@ -203,7 +204,7 @@ func WithTCPModerateReceiveBuffer(v bool) Option { // WithTCPSACKEnabled sets the SACK option for TCP. func WithTCPSACKEnabled(v bool) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { opt := tcpip.TCPSACKEnabled(v) if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { return fmt.Errorf("set TCP SACK: %s", err) @@ -214,7 +215,7 @@ func WithTCPSACKEnabled(v bool) Option { // WithTCPRecovery sets the recovery option for TCP. func WithTCPRecovery(v tcpip.TCPRecovery) Option { - return func(s *gvStack) error { + return func(s *stack.Stack) error { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { return fmt.Errorf("set TCP Recovery: %s", err) } diff --git a/listener/tun/ipstack/gvisor/icmp.go b/listener/tun/ipstack/gvisor/route.go similarity index 51% rename from listener/tun/ipstack/gvisor/icmp.go rename to listener/tun/ipstack/gvisor/route.go index 8b56d397..5a3d3bf4 100644 --- a/listener/tun/ipstack/gvisor/icmp.go +++ b/listener/tun/ipstack/gvisor/route.go @@ -1,22 +1,23 @@ package gvisor import ( + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -func withICMPHandler() Option { - return func(s *gvStack) error { - // Add default route table for IPv4 and IPv6. - // This will handle all incoming ICMP packets. +func withRouteTable(nicID tcpip.NICID) option.Option { + return func(s *stack.Stack) error { s.SetRouteTable([]tcpip.Route{ { Destination: header.IPv4EmptySubnet, - NIC: s.nicID, + NIC: nicID, }, { Destination: header.IPv6EmptySubnet, - NIC: s.nicID, + NIC: nicID, }, }) return nil diff --git a/listener/tun/ipstack/gvisor/stack.go b/listener/tun/ipstack/gvisor/stack.go index 104c0300..9061995d 100644 --- a/listener/tun/ipstack/gvisor/stack.go +++ b/listener/tun/ipstack/gvisor/stack.go @@ -5,6 +5,7 @@ import ( "github.com/Dreamacro/clash/listener/tun/device" "github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -18,23 +19,23 @@ import ( type gvStack struct { *stack.Stack device device.Device - - handler adapter.Handler - nicID tcpip.NICID } func (s *gvStack) Close() error { + var err error + if s.device != nil { + err = s.device.Close() + s.device.Wait() + } if s.Stack != nil { s.Stack.Close() + s.Stack.Wait() } - if s.device != nil { - _ = s.device.Close() - } - return nil + return err } // New allocates a new *gvStack with given options. -func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack.Stack, error) { +func New(device device.Device, handler adapter.Handler, opts ...option.Option) (ipstack.Stack, error) { s := &gvStack{ Stack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -49,19 +50,15 @@ func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack }, }), - device: device, - handler: handler, - nicID: defaultNICID, + device: device, } - opts = append(opts, - // Important: We must initiate transport protocol handlers - // before creating NIC, otherwise NIC would dispatch packets - // to stack and cause race condition. - withICMPHandler(), withTCPHandler(), withUDPHandler(), + // Generate unique NIC id. + nicID := tcpip.NICID(s.Stack.UniqueID()) - // Create stack NIC and then bind link endpoint. - withCreatingNIC(device.(stack.LinkEndpoint)), + opts = append(opts, + // Create stack NIC and then bind link endpoint to it. + withCreatingNIC(nicID, device), // In the past we did s.AddAddressRange to assign 0.0.0.0/0 // onto the interface. We need that to be able to terminate @@ -70,27 +67,34 @@ func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack // Promiscuous mode. https://github.com/google/gvisor/issues/3876 // // Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go - withPromiscuousMode(nicPromiscuousModeEnabled), + withPromiscuousMode(nicID, nicPromiscuousModeEnabled), - // Enable spoofing if a stack may send packets from unowned addresses. - // This change required changes to some netgophers since previously, - // promiscuous mode was enough to let the netstack respond to all - // incoming packets regardless of the packet's destination address. Now - // that a stack.Route is not held for each incoming packet, finding a route - // may fail with local addresses we don't own but accepted packets for - // while in promiscuous mode. Since we also want to be able to send from - // any address (in response the received promiscuous mode packets), we need - // to enable spoofing. + // Enable spoofing if a stack may send packets from unowned + // addresses. This change required changes to some netgophers + // since previously, promiscuous mode was enough to let the + // netstack respond to all incoming packets regardless of the + // packet's destination address. Now that a stack.Route is not + // held for each incoming packet, finding a route may fail with + // local addresses we don't own but accepted packets for while + // in promiscuous mode. Since we also want to be able to send + // from any address (in response the received promiscuous mode + // packets), we need to enable spoofing. // // Ref: https://github.com/google/gvisor/commit/8c0701462a84ff77e602f1626aec49479c308127 - withSpoofing(nicSpoofingEnabled), + withSpoofing(nicID, nicSpoofingEnabled), + + // Add default route table for IPv4 and IPv6. This will handle + // all incoming ICMP packets. + withRouteTable(nicID), + + // Initiate transport protocol (TCP/UDP) with given handler. + withTCPHandler(handler.HandleTCP), withUDPHandler(handler.HandleUDP), ) for _, opt := range opts { - if err := opt(s); err != nil { + if err := opt(s.Stack); err != nil { return nil, err } } - return s, nil } diff --git a/listener/tun/ipstack/gvisor/tcp.go b/listener/tun/ipstack/gvisor/tcp.go index 8b8277e0..8bffb932 100644 --- a/listener/tun/ipstack/gvisor/tcp.go +++ b/listener/tun/ipstack/gvisor/tcp.go @@ -4,6 +4,9 @@ import ( "fmt" "time" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -37,9 +40,9 @@ const ( tcpKeepaliveInterval = 30 * time.Second ) -func withTCPHandler() Option { - return func(s *gvStack) error { - tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { +func withTCPHandler(handle adapter.TCPHandleFunc) option.Option { + return func(s *stack.Stack) error { + tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { @@ -55,7 +58,7 @@ func withTCPHandler() Option { TCPConn: gonet.NewTCPConn(&wq, ep), id: r.ID(), } - s.handler.HandleTCPConn(conn) + handle(conn) }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) return nil diff --git a/listener/tun/ipstack/gvisor/udp.go b/listener/tun/ipstack/gvisor/udp.go index 6efbd204..688583a0 100644 --- a/listener/tun/ipstack/gvisor/udp.go +++ b/listener/tun/ipstack/gvisor/udp.go @@ -5,6 +5,7 @@ import ( "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -12,9 +13,9 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -func withUDPHandler() Option { - return func(s *gvStack) error { - udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) { +func withUDPHandler(handle adapter.UDPHandleFunc) option.Option { + return func(s *stack.Stack) error { + udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { @@ -23,10 +24,10 @@ func withUDPHandler() Option { } conn := &udpConn{ - UDPConn: gonet.NewUDPConn(s.Stack, &wq, ep), + UDPConn: gonet.NewUDPConn(s, &wq, ep), id: r.ID(), } - s.handler.HandleUDPConn(conn) + handle(conn) }) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) return nil diff --git a/listener/tun/ipstack/system/stack.go b/listener/tun/ipstack/system/stack.go index ad3f465d..d8b250ba 100644 --- a/listener/tun/ipstack/system/stack.go +++ b/listener/tun/ipstack/system/stack.go @@ -36,8 +36,6 @@ func (s sysStack) Close() error { return nil } -var ipv4LoopBack = netip.MustParsePrefix("127.0.0.0/8") - func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) { var ( gateway = tunAddress.Masked().Addr().Next() @@ -71,12 +69,6 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref rAddrIp, _ := netip.AddrFromSlice(rAddr.IP) rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port)) - if ipv4LoopBack.Contains(rAddrIp) { - conn.Close() - - continue - } - if D.ShouldHijackDns(dnsAddr, rAddrPort) { go func() { log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String()) @@ -149,12 +141,6 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref rAddrIp, _ := netip.AddrFromSlice(rAddr.IP) rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port)) - if ipv4LoopBack.Contains(rAddrIp) { - pool.Put(buf) - - continue - } - if D.ShouldHijackDns(dnsAddr, rAddrPort) { go func() { defer pool.Put(buf) diff --git a/listener/tun/tun_adapter.go b/listener/tun/tun_adapter.go index 333a85bd..75d568e1 100644 --- a/listener/tun/tun_adapter.go +++ b/listener/tun/tun_adapter.go @@ -18,6 +18,7 @@ import ( "github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/listener/tun/ipstack/commons" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor" + "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "github.com/Dreamacro/clash/listener/tun/ipstack/system" "github.com/Dreamacro/clash/log" ) @@ -67,7 +68,7 @@ func New(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.ConnContex DNSAdds: tunConf.DNSHijack, TCPIn: tcpIn, UDPIn: udpIn, }, - gvisor.WithDefault(), + option.WithDefault(), ) if err != nil { diff --git a/main.go b/main.go index bc8c7a55..104a4964 100644 --- a/main.go +++ b/main.go @@ -100,13 +100,9 @@ func main() { log.Fatalln("Parse config error: %s", err.Error()) } + defer executor.Shutdown() + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh - - // cleanup - log.Warnln("Clash cleanup") - hub.Cleanup() - - log.Warnln("Clash shutting down") }