diff --git a/listener/tun/ipstack/system/mars/mars.go b/listener/tun/ipstack/system/mars/mars.go index e150437e..a553c2d6 100644 --- a/listener/tun/ipstack/system/mars/mars.go +++ b/listener/tun/ipstack/system/mars/mars.go @@ -27,10 +27,8 @@ func StartListener(device io.ReadWriteCloser, gateway, portal, broadcast netip.A } func (t *StackListener) Close() error { - _ = t.tcp.Close() _ = t.udp.Close() - - return t.device.Close() + return t.tcp.Close() } func (t *StackListener) TCP() *nat.TCP { diff --git a/listener/tun/ipstack/system/stack.go b/listener/tun/ipstack/system/stack.go index 92751d36..803e5db0 100644 --- a/listener/tun/ipstack/system/stack.go +++ b/listener/tun/ipstack/system/stack.go @@ -7,6 +7,7 @@ import ( "net/netip" "runtime" "strconv" + "sync" "time" "github.com/Dreamacro/clash/adapter/inbound" @@ -28,6 +29,8 @@ type sysStack struct { device device.Device closed bool + once sync.Once + wg sync.WaitGroup } func (s *sysStack) Close() error { @@ -38,10 +41,12 @@ func (s *sysStack) Close() error { }() s.closed = true - if s.stack != nil { - return s.stack.Close() - } - return nil + + err := s.stack.Close() + + s.wg.Wait() + + return err } func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) { @@ -67,16 +72,10 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref _ = tcp.Close() }(stack.TCP()) - defer log.Debugln("TCP: closed") - for !ipStack.closed { - if err = stack.TCP().SetDeadline(time.Time{}); err != nil { - break - } - conn, err := stack.TCP().Accept() if err != nil { - log.Debugln("Accept connection: %v", err) + log.Debugln("[STACK] accept connection error: %v", err) continue } @@ -146,6 +145,8 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref tcpIn <- context.NewConnContext(conn, metadata) } + + ipStack.wg.Done() } udp := func() { @@ -153,14 +154,13 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref _ = udp.Close() }(stack.UDP()) - defer log.Debugln("UDP: closed") - for !ipStack.closed { buf := pool.Get(pool.UDPBufferSize) n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf) if err != nil { - return + _ = pool.Put(buf) + break } raw := buf[:n] @@ -209,17 +209,23 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref default: } } + + ipStack.wg.Done() } - go tcp() + ipStack.once.Do(func() { + ipStack.wg.Add(1) + go tcp() - numUDPWorkers := 4 - if num := runtime.GOMAXPROCS(0); num > numUDPWorkers { - numUDPWorkers = num - } - for i := 0; i < numUDPWorkers; i++ { - go udp() - } + numUDPWorkers := 4 + if num := runtime.GOMAXPROCS(0); num > numUDPWorkers { + numUDPWorkers = num + } + for i := 0; i < numUDPWorkers; i++ { + ipStack.wg.Add(1) + go udp() + } + }) return ipStack, nil }