From b179d09efb8ed2c77a2b7a60ccc4e715fc274c0d Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Tue, 12 Apr 2022 22:33:10 +0800 Subject: [PATCH] Chore: adjust ipstack --- listener/tun/device/device.go | 3 - listener/tun/device/iobased/endpoint.go | 27 ++++++-- listener/tun/device/tun/tun_wireguard.go | 17 ++---- listener/tun/device/tun/tun_wireguard_unix.go | 8 +++ ...un_windows.go => tun_wireguard_windows.go} | 5 ++ listener/tun/ipstack/gvisor/option/option.go | 58 ++++++++++++++---- listener/tun/ipstack/gvisor/stack.go | 1 - listener/tun/ipstack/gvisor/tcp.go | 61 +++++++++++++------ listener/tun/ipstack/gvisor/udp.go | 14 +++-- listener/tun/ipstack/system/stack.go | 59 ++++++++++++------ 10 files changed, 180 insertions(+), 73 deletions(-) create mode 100644 listener/tun/device/tun/tun_wireguard_unix.go rename listener/tun/device/tun/{tun_windows.go => tun_wireguard_windows.go} (84%) diff --git a/listener/tun/device/device.go b/listener/tun/device/device.go index 73b03dee..70115cbd 100644 --- a/listener/tun/device/device.go +++ b/listener/tun/device/device.go @@ -29,7 +29,4 @@ type Device interface { // UseIOBased work for other ip stack UseIOBased() error - - // Wait waits for the device to close. - Wait() } diff --git a/listener/tun/device/iobased/endpoint.go b/listener/tun/device/iobased/endpoint.go index c0942d10..e0a4583a 100644 --- a/listener/tun/device/iobased/endpoint.go +++ b/listener/tun/device/iobased/endpoint.go @@ -36,6 +36,9 @@ type Endpoint struct { // once is used to perform the init action once when attaching. once sync.Once + + // wg keeps track of running goroutines. + wg sync.WaitGroup } // New returns stack.LinkEndpoint(.*Endpoint) and error. @@ -60,19 +63,26 @@ func New(rw io.ReadWriter, mtu uint32, offset int) (*Endpoint, error) { }, nil } -func (e *Endpoint) Close() { - e.Endpoint.Close() +func (e *Endpoint) Wait() { + e.wg.Wait() } // Attach launches the goroutine that reads packets from io.Reader and // dispatches them via the provided dispatcher. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.Endpoint.Attach(dispatcher) e.once.Do(func() { ctx, cancel := context.WithCancel(context.Background()) - go e.dispatchLoop(cancel) - go e.outboundLoop(ctx) + e.wg.Add(2) + go func() { + e.outboundLoop(ctx) + e.wg.Done() + }() + go func() { + e.dispatchLoop(cancel) + e.wg.Done() + }() }) - e.Endpoint.Attach(dispatcher) } // dispatchLoop dispatches packets to upper layer. @@ -81,14 +91,19 @@ func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) { // gracefully after (*Endpoint).dispatchLoop(context.CancelFunc) returns. defer cancel() + mtu := int(e.mtu) for { - data := make([]byte, int(e.mtu)) + data := make([]byte, mtu) n, err := e.rw.Read(data) if err != nil { break } + if n == 0 || n > mtu { + continue + } + if !e.IsAttached() { continue /* unattached, drop packet */ } diff --git a/listener/tun/device/tun/tun_wireguard.go b/listener/tun/device/tun/tun_wireguard.go index 35008425..529a1054 100644 --- a/listener/tun/device/tun/tun_wireguard.go +++ b/listener/tun/device/tun/tun_wireguard.go @@ -32,15 +32,6 @@ func Open(name string, mtu uint32) (_ device.Device, err error) { } }() - var ( - offset = 4 /* 4 bytes TUN_PI */ - defaultMTU = 1500 - ) - if runtime.GOOS == "windows" { - offset = 0 - defaultMTU = 0 /* auto */ - } - t := &TUN{ name: name, mtu: mtu, @@ -101,9 +92,11 @@ func (t *TUN) Write(packet []byte) (int, error) { } func (t *TUN) Close() error { - if t.Endpoint != nil { - t.Endpoint.Close() - } + defer func(ep *iobased.Endpoint) { + if ep != nil { + ep.Close() + } + }(t.Endpoint) return t.nt.Close() } diff --git a/listener/tun/device/tun/tun_wireguard_unix.go b/listener/tun/device/tun/tun_wireguard_unix.go new file mode 100644 index 00000000..d88787fb --- /dev/null +++ b/listener/tun/device/tun/tun_wireguard_unix.go @@ -0,0 +1,8 @@ +//go:build !linux && !windows + +package tun + +const ( + offset = 4 /* 4 bytes TUN_PI */ + defaultMTU = 1500 +) diff --git a/listener/tun/device/tun/tun_windows.go b/listener/tun/device/tun/tun_wireguard_windows.go similarity index 84% rename from listener/tun/device/tun/tun_windows.go rename to listener/tun/device/tun/tun_wireguard_windows.go index 29877440..f73286d8 100644 --- a/listener/tun/device/tun/tun_windows.go +++ b/listener/tun/device/tun/tun_wireguard_windows.go @@ -5,6 +5,11 @@ import ( "golang.zx2c4.com/wireguard/tun" ) +const ( + offset = 0 + defaultMTU = 0 /* auto */ +) + func init() { guid, _ := windows.GUIDFromString("{330EAEF8-7578-5DF2-D97B-8DADC0EA85CB}") diff --git a/listener/tun/ipstack/gvisor/option/option.go b/listener/tun/ipstack/gvisor/option/option.go index 34508e0d..2076fd58 100644 --- a/listener/tun/ipstack/gvisor/option/option.go +++ b/listener/tun/ipstack/gvisor/option/option.go @@ -37,7 +37,7 @@ const ( // tcpModerateReceiveBufferEnabled is the value used by stack to // enable or disable tcp receive buffer auto-tuning option. - tcpModerateReceiveBufferEnabled = true + tcpModerateReceiveBufferEnabled = false // tcpSACKEnabled is the value used by stack to enable or disable // tcp selective ACK. @@ -47,14 +47,18 @@ const ( tcpRecovery = tcpip.TCPRACKLossDetection // tcpMinBufferSize is the smallest size of a send/recv buffer. - tcpMinBufferSize = tcp.MinBufferSize // 4 KiB + tcpMinBufferSize = tcp.MinBufferSize // tcpMaxBufferSize is the maximum permitted size of a send/recv buffer. - tcpMaxBufferSize = tcp.MaxBufferSize // 4 MiB + tcpMaxBufferSize = tcp.MaxBufferSize - // tcpDefaultBufferSize is the default size of the send/recv buffer for + // tcpDefaultBufferSize is the default size of the send buffer for // a transport endpoint. - tcpDefaultBufferSize = 212 << 10 // 212 KiB + tcpDefaultSendBufferSize = tcp.DefaultSendBufferSize + + // tcpDefaultReceiveBufferSize is the default size of the receive buffer + // for a transport endpoint. + tcpDefaultReceiveBufferSize = tcp.DefaultReceiveBufferSize ) type Option func(*stack.Stack) error @@ -74,7 +78,8 @@ func WithDefault() Option { // in too large buffers. // // Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go - WithTCPBufferSizeRange(tcpMinBufferSize, tcpDefaultBufferSize, tcpMaxBufferSize), + WithTCPSendBufferSizeRange(tcpMinBufferSize, tcpDefaultSendBufferSize, tcpMaxBufferSize), + WithTCPReceiveBufferSizeRange(tcpMinBufferSize, tcpDefaultReceiveBufferSize, tcpMaxBufferSize), WithTCPCongestionControl(tcpCongestionControlAlgorithm), WithTCPDelay(tcpDelayEnabled), @@ -154,17 +159,46 @@ func WithICMPLimit(limit rate.Limit) Option { } } -// WithTCPBufferSizeRange sets the receive and send buffer size range for TCP. -func WithTCPBufferSizeRange(a, b, c int) Option { +// WithTCPSendBufferSize sets default the send buffer size for TCP. +func WithTCPSendBufferSize(size int) Option { + return func(s *stack.Stack) error { + sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: tcpMinBufferSize, Default: size, Max: tcpMaxBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil { + return fmt.Errorf("set TCP send buffer size range: %s", err) + } + return nil + } +} + +// WithTCPSendBufferSizeRange sets the send buffer size range for TCP. +func WithTCPSendBufferSizeRange(a, b, c int) Option { + return func(s *stack.Stack) error { + sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: a, Default: b, Max: c} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil { + return fmt.Errorf("set TCP send buffer size range: %s", err) + } + return nil + } +} + +// WithTCPReceiveBufferSize sets the default receive buffer size for TCP. +func WithTCPReceiveBufferSize(size int) Option { + return func(s *stack.Stack) error { + rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: tcpMinBufferSize, Default: size, Max: tcpMaxBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil { + return fmt.Errorf("set TCP receive buffer size range: %s", err) + } + return nil + } +} + +// WithTCPReceiveBufferSizeRange sets the receive buffer size range for TCP. +func WithTCPReceiveBufferSizeRange(a, b, c int) Option { return func(s *stack.Stack) error { rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c} if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil { return fmt.Errorf("set TCP receive buffer size range: %s", err) } - sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: a, Default: b, Max: c} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil { - return fmt.Errorf("set TCP send buffer size range: %s", err) - } return nil } } diff --git a/listener/tun/ipstack/gvisor/stack.go b/listener/tun/ipstack/gvisor/stack.go index 9061995d..c762e6a2 100644 --- a/listener/tun/ipstack/gvisor/stack.go +++ b/listener/tun/ipstack/gvisor/stack.go @@ -25,7 +25,6 @@ func (s *gvStack) Close() error { var err error if s.device != nil { err = s.device.Close() - s.device.Wait() } if s.Stack != nil { s.Stack.Close() diff --git a/listener/tun/ipstack/gvisor/tcp.go b/listener/tun/ipstack/gvisor/tcp.go index 8bffb932..61f5d90e 100644 --- a/listener/tun/ipstack/gvisor/tcp.go +++ b/listener/tun/ipstack/gvisor/tcp.go @@ -1,14 +1,15 @@ package gvisor import ( - "fmt" "time" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "github.com/Dreamacro/clash/log" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" @@ -43,8 +44,21 @@ const ( func withTCPHandler(handle adapter.TCPHandleFunc) option.Option { return func(s *stack.Stack) error { tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) + var ( + wq waiter.Queue + ep tcpip.Endpoint + err tcpip.Error + id = r.ID() + ) + + defer func() { + if err != nil { + log.Warnln("[STACK] forward tcp request %s:%d->%s:%d: %s", id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err) + } + }() + + // Perform a TCP three-way handshake. + ep, err = r.CreateEndpoint(&wq) if err != nil { // RST: prevent potential half-open TCP connection leak. r.Complete(true) @@ -52,11 +66,11 @@ func withTCPHandler(handle adapter.TCPHandleFunc) option.Option { } defer r.Complete(false) - setKeepalive(ep) + err = setSocketOptions(s, ep) conn := &tcpConn{ TCPConn: gonet.NewTCPConn(&wq, ep), - id: r.ID(), + id: id, } handle(conn) }) @@ -65,21 +79,34 @@ func withTCPHandler(handle adapter.TCPHandleFunc) option.Option { } } -func setKeepalive(ep tcpip.Endpoint) error { - ep.SocketOptions().SetKeepAlive(true) +func setSocketOptions(s *stack.Stack, ep tcpip.Endpoint) tcpip.Error { + { /* TCP keepalive options */ + ep.SocketOptions().SetKeepAlive(true) - idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle) - if err := ep.SetSockOpt(&idle); err != nil { - return fmt.Errorf("set keepalive idle: %s", err) + idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle) + if err := ep.SetSockOpt(&idle); err != nil { + return err + } + + interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval) + if err := ep.SetSockOpt(&interval); err != nil { + return err + } + + if err := ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount); err != nil { + return err + } } + { /* TCP recv/send buffer size */ + var ss tcpip.TCPSendBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err == nil { + ep.SocketOptions().SetReceiveBufferSize(int64(ss.Default), false) + } - interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval) - if err := ep.SetSockOpt(&interval); err != nil { - return fmt.Errorf("set keepalive interval: %s", err) - } - - if err := ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount); err != nil { - return fmt.Errorf("set keepalive count: %s", err) + var rs tcpip.TCPReceiveBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &rs); err == nil { + ep.SocketOptions().SetReceiveBufferSize(int64(rs.Default), false) + } } return nil } diff --git a/listener/tun/ipstack/gvisor/udp.go b/listener/tun/ipstack/gvisor/udp.go index 688583a0..502e3a9c 100644 --- a/listener/tun/ipstack/gvisor/udp.go +++ b/listener/tun/ipstack/gvisor/udp.go @@ -6,6 +6,7 @@ import ( "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" + "github.com/Dreamacro/clash/log" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -16,16 +17,19 @@ import ( func withUDPHandler(handle adapter.UDPHandleFunc) option.Option { return func(s *stack.Stack) error { udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { - var wq waiter.Queue + var ( + wq waiter.Queue + id = r.ID() + ) ep, err := r.CreateEndpoint(&wq) if err != nil { - // TODO: handler errors in the future. + log.Warnln("[STACK] udp forwarder request %s:%d->%s:%d: %s", id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err) return } conn := &udpConn{ UDPConn: gonet.NewUDPConn(s, &wq, ep), - id: r.ID(), + id: id, } handle(conn) }) @@ -54,7 +58,7 @@ func (c *packet) Data() []byte { } // WriteBack write UDP packet with source(ip, port) = `addr` -func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) { +func (c *packet) WriteBack(b []byte, _ net.Addr) (n int, err error) { return c.pc.WriteTo(b, c.rAddr) } @@ -64,5 +68,5 @@ func (c *packet) LocalAddr() net.Addr { } func (c *packet) Drop() { - pool.Put(c.payload) + _ = pool.Put(c.payload) } diff --git a/listener/tun/ipstack/system/stack.go b/listener/tun/ipstack/system/stack.go index d8b250ba..d0b24e3a 100644 --- a/listener/tun/ipstack/system/stack.go +++ b/listener/tun/ipstack/system/stack.go @@ -17,6 +17,7 @@ import ( "github.com/Dreamacro/clash/listener/tun/ipstack" D "github.com/Dreamacro/clash/listener/tun/ipstack/commons" "github.com/Dreamacro/clash/listener/tun/ipstack/system/mars" + "github.com/Dreamacro/clash/listener/tun/ipstack/system/mars/nat" "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/transport/socks5" ) @@ -24,14 +25,20 @@ import ( type sysStack struct { stack io.Closer device device.Device + + closed bool } -func (s sysStack) Close() error { +func (s *sysStack) Close() error { + defer func() { + if s.device != nil { + _ = s.device.Close() + } + }() + + s.closed = true if s.stack != nil { - _ = s.stack.Close() - } - if s.device != nil { - _ = s.device.Close() + return s.stack.Close() } return nil } @@ -49,17 +56,25 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref return nil, err } + ipStack := &sysStack{stack: stack, device: device} + dnsAddr := dnsHijack tcp := func() { - defer stack.TCP().Close() + defer func(tcp *nat.TCP) { + _ = tcp.Close() + }(stack.TCP()) + defer log.Debugln("TCP: closed") - for stack.TCP().SetDeadline(time.Time{}) == nil { + for !ipStack.closed { + if err = stack.TCP().SetDeadline(time.Time{}); err != nil { + break + } + conn, err := stack.TCP().Accept() if err != nil { log.Debugln("Accept connection: %v", err) - continue } @@ -73,13 +88,19 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref go func() { log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String()) - defer conn.Close() + defer func(conn net.Conn) { + _ = conn.Close() + }(conn) buf := pool.Get(pool.UDPBufferSize) - defer pool.Put(buf) + defer func(buf []byte) { + _ = pool.Put(buf) + }(buf) for { - conn.SetReadDeadline(time.Now().Add(C.DefaultTCPTimeout)) + if err = conn.SetReadDeadline(time.Now().Add(C.DefaultTCPTimeout)); err != nil { + break + } length := uint16(0) if err := binary.Read(conn, binary.BigEndian, &length); err != nil { @@ -123,10 +144,13 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref } udp := func() { - defer stack.UDP().Close() + defer func(udp *nat.UDP) { + _ = udp.Close() + }(stack.UDP()) + defer log.Debugln("UDP: closed") - for { + for !ipStack.closed { buf := pool.Get(pool.UDPBufferSize) n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf) @@ -143,15 +167,16 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref if D.ShouldHijackDns(dnsAddr, rAddrPort) { go func() { - defer pool.Put(buf) - msg, err := D.RelayDnsPacket(raw) if err != nil { + _ = pool.Put(buf) return } _, _ = stack.UDP().WriteTo(msg, rAddr, lAddr) + _ = pool.Put(buf) + log.Debugln("[TUN] hijack dns udp: %s", rAddrPort.String()) }() @@ -165,7 +190,7 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref return stack.UDP().WriteTo(b, rAddr, lAddr) }, drop: func() { - pool.Put(buf) + _ = pool.Put(buf) }, } @@ -186,5 +211,5 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref go udp() } - return &sysStack{stack: stack, device: device}, nil + return ipStack, nil }