Chore: adjust ipstack

This commit is contained in:
gVisor bot 2022-04-12 22:33:10 +08:00
parent ebacc76433
commit e54d403a1e
10 changed files with 180 additions and 73 deletions

View file

@ -29,7 +29,4 @@ type Device interface {
// UseIOBased work for other ip stack
UseIOBased() error
// Wait waits for the device to close.
Wait()
}

View file

@ -36,6 +36,9 @@ type Endpoint struct {
// once is used to perform the init action once when attaching.
once sync.Once
// wg keeps track of running goroutines.
wg sync.WaitGroup
}
// New returns stack.LinkEndpoint(.*Endpoint) and error.
@ -60,19 +63,26 @@ func New(rw io.ReadWriter, mtu uint32, offset int) (*Endpoint, error) {
}, nil
}
func (e *Endpoint) Close() {
e.Endpoint.Close()
func (e *Endpoint) Wait() {
e.wg.Wait()
}
// Attach launches the goroutine that reads packets from io.Reader and
// dispatches them via the provided dispatcher.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.Endpoint.Attach(dispatcher)
e.once.Do(func() {
ctx, cancel := context.WithCancel(context.Background())
go e.dispatchLoop(cancel)
go e.outboundLoop(ctx)
e.wg.Add(2)
go func() {
e.outboundLoop(ctx)
e.wg.Done()
}()
go func() {
e.dispatchLoop(cancel)
e.wg.Done()
}()
})
e.Endpoint.Attach(dispatcher)
}
// dispatchLoop dispatches packets to upper layer.
@ -81,14 +91,19 @@ func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) {
// gracefully after (*Endpoint).dispatchLoop(context.CancelFunc) returns.
defer cancel()
mtu := int(e.mtu)
for {
data := make([]byte, int(e.mtu))
data := make([]byte, mtu)
n, err := e.rw.Read(data)
if err != nil {
break
}
if n == 0 || n > mtu {
continue
}
if !e.IsAttached() {
continue /* unattached, drop packet */
}

View file

@ -32,15 +32,6 @@ func Open(name string, mtu uint32) (_ device.Device, err error) {
}
}()
var (
offset = 4 /* 4 bytes TUN_PI */
defaultMTU = 1500
)
if runtime.GOOS == "windows" {
offset = 0
defaultMTU = 0 /* auto */
}
t := &TUN{
name: name,
mtu: mtu,
@ -101,9 +92,11 @@ func (t *TUN) Write(packet []byte) (int, error) {
}
func (t *TUN) Close() error {
if t.Endpoint != nil {
t.Endpoint.Close()
}
defer func(ep *iobased.Endpoint) {
if ep != nil {
ep.Close()
}
}(t.Endpoint)
return t.nt.Close()
}

View file

@ -0,0 +1,8 @@
//go:build !linux && !windows
package tun
const (
offset = 4 /* 4 bytes TUN_PI */
defaultMTU = 1500
)

View file

@ -5,6 +5,11 @@ import (
"golang.zx2c4.com/wireguard/tun"
)
const (
offset = 0
defaultMTU = 0 /* auto */
)
func init() {
guid, _ := windows.GUIDFromString("{330EAEF8-7578-5DF2-D97B-8DADC0EA85CB}")

View file

@ -37,7 +37,7 @@ const (
// tcpModerateReceiveBufferEnabled is the value used by stack to
// enable or disable tcp receive buffer auto-tuning option.
tcpModerateReceiveBufferEnabled = true
tcpModerateReceiveBufferEnabled = false
// tcpSACKEnabled is the value used by stack to enable or disable
// tcp selective ACK.
@ -47,14 +47,18 @@ const (
tcpRecovery = tcpip.TCPRACKLossDetection
// tcpMinBufferSize is the smallest size of a send/recv buffer.
tcpMinBufferSize = tcp.MinBufferSize // 4 KiB
tcpMinBufferSize = tcp.MinBufferSize
// tcpMaxBufferSize is the maximum permitted size of a send/recv buffer.
tcpMaxBufferSize = tcp.MaxBufferSize // 4 MiB
tcpMaxBufferSize = tcp.MaxBufferSize
// tcpDefaultBufferSize is the default size of the send/recv buffer for
// tcpDefaultBufferSize is the default size of the send buffer for
// a transport endpoint.
tcpDefaultBufferSize = 212 << 10 // 212 KiB
tcpDefaultSendBufferSize = tcp.DefaultSendBufferSize
// tcpDefaultReceiveBufferSize is the default size of the receive buffer
// for a transport endpoint.
tcpDefaultReceiveBufferSize = tcp.DefaultReceiveBufferSize
)
type Option func(*stack.Stack) error
@ -74,7 +78,8 @@ func WithDefault() Option {
// in too large buffers.
//
// Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go
WithTCPBufferSizeRange(tcpMinBufferSize, tcpDefaultBufferSize, tcpMaxBufferSize),
WithTCPSendBufferSizeRange(tcpMinBufferSize, tcpDefaultSendBufferSize, tcpMaxBufferSize),
WithTCPReceiveBufferSizeRange(tcpMinBufferSize, tcpDefaultReceiveBufferSize, tcpMaxBufferSize),
WithTCPCongestionControl(tcpCongestionControlAlgorithm),
WithTCPDelay(tcpDelayEnabled),
@ -154,17 +159,46 @@ func WithICMPLimit(limit rate.Limit) Option {
}
}
// WithTCPBufferSizeRange sets the receive and send buffer size range for TCP.
func WithTCPBufferSizeRange(a, b, c int) Option {
// WithTCPSendBufferSize sets default the send buffer size for TCP.
func WithTCPSendBufferSize(size int) Option {
return func(s *stack.Stack) error {
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: tcpMinBufferSize, Default: size, Max: tcpMaxBufferSize}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
return fmt.Errorf("set TCP send buffer size range: %s", err)
}
return nil
}
}
// WithTCPSendBufferSizeRange sets the send buffer size range for TCP.
func WithTCPSendBufferSizeRange(a, b, c int) Option {
return func(s *stack.Stack) error {
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: a, Default: b, Max: c}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
return fmt.Errorf("set TCP send buffer size range: %s", err)
}
return nil
}
}
// WithTCPReceiveBufferSize sets the default receive buffer size for TCP.
func WithTCPReceiveBufferSize(size int) Option {
return func(s *stack.Stack) error {
rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: tcpMinBufferSize, Default: size, Max: tcpMaxBufferSize}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil {
return fmt.Errorf("set TCP receive buffer size range: %s", err)
}
return nil
}
}
// WithTCPReceiveBufferSizeRange sets the receive buffer size range for TCP.
func WithTCPReceiveBufferSizeRange(a, b, c int) Option {
return func(s *stack.Stack) error {
rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil {
return fmt.Errorf("set TCP receive buffer size range: %s", err)
}
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: a, Default: b, Max: c}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
return fmt.Errorf("set TCP send buffer size range: %s", err)
}
return nil
}
}

View file

@ -25,7 +25,6 @@ func (s *gvStack) Close() error {
var err error
if s.device != nil {
err = s.device.Close()
s.device.Wait()
}
if s.Stack != nil {
s.Stack.Close()

View file

@ -1,14 +1,15 @@
package gvisor
import (
"fmt"
"time"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"github.com/Dreamacro/clash/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
@ -43,8 +44,21 @@ const (
func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
return func(s *stack.Stack) error {
tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
ep, err := r.CreateEndpoint(&wq)
var (
wq waiter.Queue
ep tcpip.Endpoint
err tcpip.Error
id = r.ID()
)
defer func() {
if err != nil {
log.Warnln("[STACK] forward tcp request %s:%d->%s:%d: %s", id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
}
}()
// Perform a TCP three-way handshake.
ep, err = r.CreateEndpoint(&wq)
if err != nil {
// RST: prevent potential half-open TCP connection leak.
r.Complete(true)
@ -52,11 +66,11 @@ func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
}
defer r.Complete(false)
setKeepalive(ep)
err = setSocketOptions(s, ep)
conn := &tcpConn{
TCPConn: gonet.NewTCPConn(&wq, ep),
id: r.ID(),
id: id,
}
handle(conn)
})
@ -65,21 +79,34 @@ func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
}
}
func setKeepalive(ep tcpip.Endpoint) error {
ep.SocketOptions().SetKeepAlive(true)
func setSocketOptions(s *stack.Stack, ep tcpip.Endpoint) tcpip.Error {
{ /* TCP keepalive options */
ep.SocketOptions().SetKeepAlive(true)
idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle)
if err := ep.SetSockOpt(&idle); err != nil {
return fmt.Errorf("set keepalive idle: %s", err)
idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle)
if err := ep.SetSockOpt(&idle); err != nil {
return err
}
interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval)
if err := ep.SetSockOpt(&interval); err != nil {
return err
}
if err := ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount); err != nil {
return err
}
}
{ /* TCP recv/send buffer size */
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err == nil {
ep.SocketOptions().SetReceiveBufferSize(int64(ss.Default), false)
}
interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval)
if err := ep.SetSockOpt(&interval); err != nil {
return fmt.Errorf("set keepalive interval: %s", err)
}
if err := ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount); err != nil {
return fmt.Errorf("set keepalive count: %s", err)
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := s.TransportProtocolOption(header.TCPProtocolNumber, &rs); err == nil {
ep.SocketOptions().SetReceiveBufferSize(int64(rs.Default), false)
}
}
return nil
}

View file

@ -6,6 +6,7 @@ import (
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"github.com/Dreamacro/clash/log"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@ -16,16 +17,19 @@ import (
func withUDPHandler(handle adapter.UDPHandleFunc) option.Option {
return func(s *stack.Stack) error {
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
var wq waiter.Queue
var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq)
if err != nil {
// TODO: handler errors in the future.
log.Warnln("[STACK] udp forwarder request %s:%d->%s:%d: %s", id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
return
}
conn := &udpConn{
UDPConn: gonet.NewUDPConn(s, &wq, ep),
id: r.ID(),
id: id,
}
handle(conn)
})
@ -54,7 +58,7 @@ func (c *packet) Data() []byte {
}
// WriteBack write UDP packet with source(ip, port) = `addr`
func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) {
func (c *packet) WriteBack(b []byte, _ net.Addr) (n int, err error) {
return c.pc.WriteTo(b, c.rAddr)
}
@ -64,5 +68,5 @@ func (c *packet) LocalAddr() net.Addr {
}
func (c *packet) Drop() {
pool.Put(c.payload)
_ = pool.Put(c.payload)
}

View file

@ -17,6 +17,7 @@ import (
"github.com/Dreamacro/clash/listener/tun/ipstack"
D "github.com/Dreamacro/clash/listener/tun/ipstack/commons"
"github.com/Dreamacro/clash/listener/tun/ipstack/system/mars"
"github.com/Dreamacro/clash/listener/tun/ipstack/system/mars/nat"
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/transport/socks5"
)
@ -24,14 +25,20 @@ import (
type sysStack struct {
stack io.Closer
device device.Device
closed bool
}
func (s sysStack) Close() error {
func (s *sysStack) Close() error {
defer func() {
if s.device != nil {
_ = s.device.Close()
}
}()
s.closed = true
if s.stack != nil {
_ = s.stack.Close()
}
if s.device != nil {
_ = s.device.Close()
return s.stack.Close()
}
return nil
}
@ -49,17 +56,25 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
return nil, err
}
ipStack := &sysStack{stack: stack, device: device}
dnsAddr := dnsHijack
tcp := func() {
defer stack.TCP().Close()
defer func(tcp *nat.TCP) {
_ = tcp.Close()
}(stack.TCP())
defer log.Debugln("TCP: closed")
for stack.TCP().SetDeadline(time.Time{}) == nil {
for !ipStack.closed {
if err = stack.TCP().SetDeadline(time.Time{}); err != nil {
break
}
conn, err := stack.TCP().Accept()
if err != nil {
log.Debugln("Accept connection: %v", err)
continue
}
@ -73,13 +88,19 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
go func() {
log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String())
defer conn.Close()
defer func(conn net.Conn) {
_ = conn.Close()
}(conn)
buf := pool.Get(pool.UDPBufferSize)
defer pool.Put(buf)
defer func(buf []byte) {
_ = pool.Put(buf)
}(buf)
for {
conn.SetReadDeadline(time.Now().Add(C.DefaultTCPTimeout))
if err = conn.SetReadDeadline(time.Now().Add(C.DefaultTCPTimeout)); err != nil {
break
}
length := uint16(0)
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
@ -123,10 +144,13 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
}
udp := func() {
defer stack.UDP().Close()
defer func(udp *nat.UDP) {
_ = udp.Close()
}(stack.UDP())
defer log.Debugln("UDP: closed")
for {
for !ipStack.closed {
buf := pool.Get(pool.UDPBufferSize)
n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf)
@ -143,15 +167,16 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
if D.ShouldHijackDns(dnsAddr, rAddrPort) {
go func() {
defer pool.Put(buf)
msg, err := D.RelayDnsPacket(raw)
if err != nil {
_ = pool.Put(buf)
return
}
_, _ = stack.UDP().WriteTo(msg, rAddr, lAddr)
_ = pool.Put(buf)
log.Debugln("[TUN] hijack dns udp: %s", rAddrPort.String())
}()
@ -165,7 +190,7 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
return stack.UDP().WriteTo(b, rAddr, lAddr)
},
drop: func() {
pool.Put(buf)
_ = pool.Put(buf)
},
}
@ -186,5 +211,5 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
go udp()
}
return &sysStack{stack: stack, device: device}, nil
return ipStack, nil
}