Chore: adjust ipstack

This commit is contained in:
yaling888 2022-04-12 22:33:10 +08:00 committed by Meta Gowork
parent 4be17653e0
commit b179d09efb
10 changed files with 180 additions and 73 deletions

View file

@ -29,7 +29,4 @@ type Device interface {
// UseIOBased work for other ip stack // UseIOBased work for other ip stack
UseIOBased() error UseIOBased() error
// Wait waits for the device to close.
Wait()
} }

View file

@ -36,6 +36,9 @@ type Endpoint struct {
// once is used to perform the init action once when attaching. // once is used to perform the init action once when attaching.
once sync.Once once sync.Once
// wg keeps track of running goroutines.
wg sync.WaitGroup
} }
// New returns stack.LinkEndpoint(.*Endpoint) and error. // New returns stack.LinkEndpoint(.*Endpoint) and error.
@ -60,19 +63,26 @@ func New(rw io.ReadWriter, mtu uint32, offset int) (*Endpoint, error) {
}, nil }, nil
} }
func (e *Endpoint) Close() { func (e *Endpoint) Wait() {
e.Endpoint.Close() e.wg.Wait()
} }
// Attach launches the goroutine that reads packets from io.Reader and // Attach launches the goroutine that reads packets from io.Reader and
// dispatches them via the provided dispatcher. // dispatches them via the provided dispatcher.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.Endpoint.Attach(dispatcher)
e.once.Do(func() { e.once.Do(func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go e.dispatchLoop(cancel) e.wg.Add(2)
go e.outboundLoop(ctx) go func() {
e.outboundLoop(ctx)
e.wg.Done()
}()
go func() {
e.dispatchLoop(cancel)
e.wg.Done()
}()
}) })
e.Endpoint.Attach(dispatcher)
} }
// dispatchLoop dispatches packets to upper layer. // dispatchLoop dispatches packets to upper layer.
@ -81,14 +91,19 @@ func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) {
// gracefully after (*Endpoint).dispatchLoop(context.CancelFunc) returns. // gracefully after (*Endpoint).dispatchLoop(context.CancelFunc) returns.
defer cancel() defer cancel()
mtu := int(e.mtu)
for { for {
data := make([]byte, int(e.mtu)) data := make([]byte, mtu)
n, err := e.rw.Read(data) n, err := e.rw.Read(data)
if err != nil { if err != nil {
break break
} }
if n == 0 || n > mtu {
continue
}
if !e.IsAttached() { if !e.IsAttached() {
continue /* unattached, drop packet */ continue /* unattached, drop packet */
} }

View file

@ -32,15 +32,6 @@ func Open(name string, mtu uint32) (_ device.Device, err error) {
} }
}() }()
var (
offset = 4 /* 4 bytes TUN_PI */
defaultMTU = 1500
)
if runtime.GOOS == "windows" {
offset = 0
defaultMTU = 0 /* auto */
}
t := &TUN{ t := &TUN{
name: name, name: name,
mtu: mtu, mtu: mtu,
@ -101,9 +92,11 @@ func (t *TUN) Write(packet []byte) (int, error) {
} }
func (t *TUN) Close() error { func (t *TUN) Close() error {
if t.Endpoint != nil { defer func(ep *iobased.Endpoint) {
t.Endpoint.Close() if ep != nil {
} ep.Close()
}
}(t.Endpoint)
return t.nt.Close() return t.nt.Close()
} }

View file

@ -0,0 +1,8 @@
//go:build !linux && !windows
package tun
const (
offset = 4 /* 4 bytes TUN_PI */
defaultMTU = 1500
)

View file

@ -5,6 +5,11 @@ import (
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
const (
offset = 0
defaultMTU = 0 /* auto */
)
func init() { func init() {
guid, _ := windows.GUIDFromString("{330EAEF8-7578-5DF2-D97B-8DADC0EA85CB}") guid, _ := windows.GUIDFromString("{330EAEF8-7578-5DF2-D97B-8DADC0EA85CB}")

View file

@ -37,7 +37,7 @@ const (
// tcpModerateReceiveBufferEnabled is the value used by stack to // tcpModerateReceiveBufferEnabled is the value used by stack to
// enable or disable tcp receive buffer auto-tuning option. // enable or disable tcp receive buffer auto-tuning option.
tcpModerateReceiveBufferEnabled = true tcpModerateReceiveBufferEnabled = false
// tcpSACKEnabled is the value used by stack to enable or disable // tcpSACKEnabled is the value used by stack to enable or disable
// tcp selective ACK. // tcp selective ACK.
@ -47,14 +47,18 @@ const (
tcpRecovery = tcpip.TCPRACKLossDetection tcpRecovery = tcpip.TCPRACKLossDetection
// tcpMinBufferSize is the smallest size of a send/recv buffer. // tcpMinBufferSize is the smallest size of a send/recv buffer.
tcpMinBufferSize = tcp.MinBufferSize // 4 KiB tcpMinBufferSize = tcp.MinBufferSize
// tcpMaxBufferSize is the maximum permitted size of a send/recv buffer. // tcpMaxBufferSize is the maximum permitted size of a send/recv buffer.
tcpMaxBufferSize = tcp.MaxBufferSize // 4 MiB tcpMaxBufferSize = tcp.MaxBufferSize
// tcpDefaultBufferSize is the default size of the send/recv buffer for // tcpDefaultBufferSize is the default size of the send buffer for
// a transport endpoint. // a transport endpoint.
tcpDefaultBufferSize = 212 << 10 // 212 KiB tcpDefaultSendBufferSize = tcp.DefaultSendBufferSize
// tcpDefaultReceiveBufferSize is the default size of the receive buffer
// for a transport endpoint.
tcpDefaultReceiveBufferSize = tcp.DefaultReceiveBufferSize
) )
type Option func(*stack.Stack) error type Option func(*stack.Stack) error
@ -74,7 +78,8 @@ func WithDefault() Option {
// in too large buffers. // in too large buffers.
// //
// Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go // Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go
WithTCPBufferSizeRange(tcpMinBufferSize, tcpDefaultBufferSize, tcpMaxBufferSize), WithTCPSendBufferSizeRange(tcpMinBufferSize, tcpDefaultSendBufferSize, tcpMaxBufferSize),
WithTCPReceiveBufferSizeRange(tcpMinBufferSize, tcpDefaultReceiveBufferSize, tcpMaxBufferSize),
WithTCPCongestionControl(tcpCongestionControlAlgorithm), WithTCPCongestionControl(tcpCongestionControlAlgorithm),
WithTCPDelay(tcpDelayEnabled), WithTCPDelay(tcpDelayEnabled),
@ -154,17 +159,46 @@ func WithICMPLimit(limit rate.Limit) Option {
} }
} }
// WithTCPBufferSizeRange sets the receive and send buffer size range for TCP. // WithTCPSendBufferSize sets default the send buffer size for TCP.
func WithTCPBufferSizeRange(a, b, c int) Option { func WithTCPSendBufferSize(size int) Option {
return func(s *stack.Stack) error {
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: tcpMinBufferSize, Default: size, Max: tcpMaxBufferSize}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
return fmt.Errorf("set TCP send buffer size range: %s", err)
}
return nil
}
}
// WithTCPSendBufferSizeRange sets the send buffer size range for TCP.
func WithTCPSendBufferSizeRange(a, b, c int) Option {
return func(s *stack.Stack) error {
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: a, Default: b, Max: c}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
return fmt.Errorf("set TCP send buffer size range: %s", err)
}
return nil
}
}
// WithTCPReceiveBufferSize sets the default receive buffer size for TCP.
func WithTCPReceiveBufferSize(size int) Option {
return func(s *stack.Stack) error {
rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: tcpMinBufferSize, Default: size, Max: tcpMaxBufferSize}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil {
return fmt.Errorf("set TCP receive buffer size range: %s", err)
}
return nil
}
}
// WithTCPReceiveBufferSizeRange sets the receive buffer size range for TCP.
func WithTCPReceiveBufferSizeRange(a, b, c int) Option {
return func(s *stack.Stack) error { return func(s *stack.Stack) error {
rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c} rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil {
return fmt.Errorf("set TCP receive buffer size range: %s", err) return fmt.Errorf("set TCP receive buffer size range: %s", err)
} }
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: a, Default: b, Max: c}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
return fmt.Errorf("set TCP send buffer size range: %s", err)
}
return nil return nil
} }
} }

View file

@ -25,7 +25,6 @@ func (s *gvStack) Close() error {
var err error var err error
if s.device != nil { if s.device != nil {
err = s.device.Close() err = s.device.Close()
s.device.Wait()
} }
if s.Stack != nil { if s.Stack != nil {
s.Stack.Close() s.Stack.Close()

View file

@ -1,14 +1,15 @@
package gvisor package gvisor
import ( import (
"fmt"
"time" "time"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"github.com/Dreamacro/clash/log"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
@ -43,8 +44,21 @@ const (
func withTCPHandler(handle adapter.TCPHandleFunc) option.Option { func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
return func(s *stack.Stack) error { return func(s *stack.Stack) error {
tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue var (
ep, err := r.CreateEndpoint(&wq) wq waiter.Queue
ep tcpip.Endpoint
err tcpip.Error
id = r.ID()
)
defer func() {
if err != nil {
log.Warnln("[STACK] forward tcp request %s:%d->%s:%d: %s", id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
}
}()
// Perform a TCP three-way handshake.
ep, err = r.CreateEndpoint(&wq)
if err != nil { if err != nil {
// RST: prevent potential half-open TCP connection leak. // RST: prevent potential half-open TCP connection leak.
r.Complete(true) r.Complete(true)
@ -52,11 +66,11 @@ func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
} }
defer r.Complete(false) defer r.Complete(false)
setKeepalive(ep) err = setSocketOptions(s, ep)
conn := &tcpConn{ conn := &tcpConn{
TCPConn: gonet.NewTCPConn(&wq, ep), TCPConn: gonet.NewTCPConn(&wq, ep),
id: r.ID(), id: id,
} }
handle(conn) handle(conn)
}) })
@ -65,21 +79,34 @@ func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
} }
} }
func setKeepalive(ep tcpip.Endpoint) error { func setSocketOptions(s *stack.Stack, ep tcpip.Endpoint) tcpip.Error {
ep.SocketOptions().SetKeepAlive(true) { /* TCP keepalive options */
ep.SocketOptions().SetKeepAlive(true)
idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle) idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle)
if err := ep.SetSockOpt(&idle); err != nil { if err := ep.SetSockOpt(&idle); err != nil {
return fmt.Errorf("set keepalive idle: %s", err) return err
}
interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval)
if err := ep.SetSockOpt(&interval); err != nil {
return err
}
if err := ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount); err != nil {
return err
}
} }
{ /* TCP recv/send buffer size */
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err == nil {
ep.SocketOptions().SetReceiveBufferSize(int64(ss.Default), false)
}
interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval) var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := ep.SetSockOpt(&interval); err != nil { if err := s.TransportProtocolOption(header.TCPProtocolNumber, &rs); err == nil {
return fmt.Errorf("set keepalive interval: %s", err) ep.SocketOptions().SetReceiveBufferSize(int64(rs.Default), false)
} }
if err := ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount); err != nil {
return fmt.Errorf("set keepalive count: %s", err)
} }
return nil return nil
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"github.com/Dreamacro/clash/log"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
@ -16,16 +17,19 @@ import (
func withUDPHandler(handle adapter.UDPHandleFunc) option.Option { func withUDPHandler(handle adapter.UDPHandleFunc) option.Option {
return func(s *stack.Stack) error { return func(s *stack.Stack) error {
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
var wq waiter.Queue var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq) ep, err := r.CreateEndpoint(&wq)
if err != nil { if err != nil {
// TODO: handler errors in the future. log.Warnln("[STACK] udp forwarder request %s:%d->%s:%d: %s", id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
return return
} }
conn := &udpConn{ conn := &udpConn{
UDPConn: gonet.NewUDPConn(s, &wq, ep), UDPConn: gonet.NewUDPConn(s, &wq, ep),
id: r.ID(), id: id,
} }
handle(conn) handle(conn)
}) })
@ -54,7 +58,7 @@ func (c *packet) Data() []byte {
} }
// WriteBack write UDP packet with source(ip, port) = `addr` // WriteBack write UDP packet with source(ip, port) = `addr`
func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) { func (c *packet) WriteBack(b []byte, _ net.Addr) (n int, err error) {
return c.pc.WriteTo(b, c.rAddr) return c.pc.WriteTo(b, c.rAddr)
} }
@ -64,5 +68,5 @@ func (c *packet) LocalAddr() net.Addr {
} }
func (c *packet) Drop() { func (c *packet) Drop() {
pool.Put(c.payload) _ = pool.Put(c.payload)
} }

View file

@ -17,6 +17,7 @@ import (
"github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/listener/tun/ipstack"
D "github.com/Dreamacro/clash/listener/tun/ipstack/commons" D "github.com/Dreamacro/clash/listener/tun/ipstack/commons"
"github.com/Dreamacro/clash/listener/tun/ipstack/system/mars" "github.com/Dreamacro/clash/listener/tun/ipstack/system/mars"
"github.com/Dreamacro/clash/listener/tun/ipstack/system/mars/nat"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
) )
@ -24,14 +25,20 @@ import (
type sysStack struct { type sysStack struct {
stack io.Closer stack io.Closer
device device.Device device device.Device
closed bool
} }
func (s sysStack) Close() error { func (s *sysStack) Close() error {
defer func() {
if s.device != nil {
_ = s.device.Close()
}
}()
s.closed = true
if s.stack != nil { if s.stack != nil {
_ = s.stack.Close() return s.stack.Close()
}
if s.device != nil {
_ = s.device.Close()
} }
return nil return nil
} }
@ -49,17 +56,25 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
return nil, err return nil, err
} }
ipStack := &sysStack{stack: stack, device: device}
dnsAddr := dnsHijack dnsAddr := dnsHijack
tcp := func() { tcp := func() {
defer stack.TCP().Close() defer func(tcp *nat.TCP) {
_ = tcp.Close()
}(stack.TCP())
defer log.Debugln("TCP: closed") defer log.Debugln("TCP: closed")
for stack.TCP().SetDeadline(time.Time{}) == nil { for !ipStack.closed {
if err = stack.TCP().SetDeadline(time.Time{}); err != nil {
break
}
conn, err := stack.TCP().Accept() conn, err := stack.TCP().Accept()
if err != nil { if err != nil {
log.Debugln("Accept connection: %v", err) log.Debugln("Accept connection: %v", err)
continue continue
} }
@ -73,13 +88,19 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
go func() { go func() {
log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String()) log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String())
defer conn.Close() defer func(conn net.Conn) {
_ = conn.Close()
}(conn)
buf := pool.Get(pool.UDPBufferSize) buf := pool.Get(pool.UDPBufferSize)
defer pool.Put(buf) defer func(buf []byte) {
_ = pool.Put(buf)
}(buf)
for { for {
conn.SetReadDeadline(time.Now().Add(C.DefaultTCPTimeout)) if err = conn.SetReadDeadline(time.Now().Add(C.DefaultTCPTimeout)); err != nil {
break
}
length := uint16(0) length := uint16(0)
if err := binary.Read(conn, binary.BigEndian, &length); err != nil { if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
@ -123,10 +144,13 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
} }
udp := func() { udp := func() {
defer stack.UDP().Close() defer func(udp *nat.UDP) {
_ = udp.Close()
}(stack.UDP())
defer log.Debugln("UDP: closed") defer log.Debugln("UDP: closed")
for { for !ipStack.closed {
buf := pool.Get(pool.UDPBufferSize) buf := pool.Get(pool.UDPBufferSize)
n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf) n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf)
@ -143,15 +167,16 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
if D.ShouldHijackDns(dnsAddr, rAddrPort) { if D.ShouldHijackDns(dnsAddr, rAddrPort) {
go func() { go func() {
defer pool.Put(buf)
msg, err := D.RelayDnsPacket(raw) msg, err := D.RelayDnsPacket(raw)
if err != nil { if err != nil {
_ = pool.Put(buf)
return return
} }
_, _ = stack.UDP().WriteTo(msg, rAddr, lAddr) _, _ = stack.UDP().WriteTo(msg, rAddr, lAddr)
_ = pool.Put(buf)
log.Debugln("[TUN] hijack dns udp: %s", rAddrPort.String()) log.Debugln("[TUN] hijack dns udp: %s", rAddrPort.String())
}() }()
@ -165,7 +190,7 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
return stack.UDP().WriteTo(b, rAddr, lAddr) return stack.UDP().WriteTo(b, rAddr, lAddr)
}, },
drop: func() { drop: func() {
pool.Put(buf) _ = pool.Put(buf)
}, },
} }
@ -186,5 +211,5 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
go udp() go udp()
} }
return &sysStack{stack: stack, device: device}, nil return ipStack, nil
} }