Chore: adjust ipstack
This commit is contained in:
parent
4be17653e0
commit
b179d09efb
10 changed files with 180 additions and 73 deletions
|
@ -29,7 +29,4 @@ type Device interface {
|
|||
|
||||
// UseIOBased work for other ip stack
|
||||
UseIOBased() error
|
||||
|
||||
// Wait waits for the device to close.
|
||||
Wait()
|
||||
}
|
||||
|
|
|
@ -36,6 +36,9 @@ type Endpoint struct {
|
|||
|
||||
// once is used to perform the init action once when attaching.
|
||||
once sync.Once
|
||||
|
||||
// wg keeps track of running goroutines.
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// New returns stack.LinkEndpoint(.*Endpoint) and error.
|
||||
|
@ -60,19 +63,26 @@ func New(rw io.ReadWriter, mtu uint32, offset int) (*Endpoint, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (e *Endpoint) Close() {
|
||||
e.Endpoint.Close()
|
||||
func (e *Endpoint) Wait() {
|
||||
e.wg.Wait()
|
||||
}
|
||||
|
||||
// Attach launches the goroutine that reads packets from io.Reader and
|
||||
// dispatches them via the provided dispatcher.
|
||||
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
e.Endpoint.Attach(dispatcher)
|
||||
e.once.Do(func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go e.dispatchLoop(cancel)
|
||||
go e.outboundLoop(ctx)
|
||||
e.wg.Add(2)
|
||||
go func() {
|
||||
e.outboundLoop(ctx)
|
||||
e.wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
e.dispatchLoop(cancel)
|
||||
e.wg.Done()
|
||||
}()
|
||||
})
|
||||
e.Endpoint.Attach(dispatcher)
|
||||
}
|
||||
|
||||
// dispatchLoop dispatches packets to upper layer.
|
||||
|
@ -81,14 +91,19 @@ func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) {
|
|||
// gracefully after (*Endpoint).dispatchLoop(context.CancelFunc) returns.
|
||||
defer cancel()
|
||||
|
||||
mtu := int(e.mtu)
|
||||
for {
|
||||
data := make([]byte, int(e.mtu))
|
||||
data := make([]byte, mtu)
|
||||
|
||||
n, err := e.rw.Read(data)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if n == 0 || n > mtu {
|
||||
continue
|
||||
}
|
||||
|
||||
if !e.IsAttached() {
|
||||
continue /* unattached, drop packet */
|
||||
}
|
||||
|
|
|
@ -32,15 +32,6 @@ func Open(name string, mtu uint32) (_ device.Device, err error) {
|
|||
}
|
||||
}()
|
||||
|
||||
var (
|
||||
offset = 4 /* 4 bytes TUN_PI */
|
||||
defaultMTU = 1500
|
||||
)
|
||||
if runtime.GOOS == "windows" {
|
||||
offset = 0
|
||||
defaultMTU = 0 /* auto */
|
||||
}
|
||||
|
||||
t := &TUN{
|
||||
name: name,
|
||||
mtu: mtu,
|
||||
|
@ -101,9 +92,11 @@ func (t *TUN) Write(packet []byte) (int, error) {
|
|||
}
|
||||
|
||||
func (t *TUN) Close() error {
|
||||
if t.Endpoint != nil {
|
||||
t.Endpoint.Close()
|
||||
defer func(ep *iobased.Endpoint) {
|
||||
if ep != nil {
|
||||
ep.Close()
|
||||
}
|
||||
}(t.Endpoint)
|
||||
return t.nt.Close()
|
||||
}
|
||||
|
||||
|
|
8
listener/tun/device/tun/tun_wireguard_unix.go
Normal file
8
listener/tun/device/tun/tun_wireguard_unix.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
//go:build !linux && !windows
|
||||
|
||||
package tun
|
||||
|
||||
const (
|
||||
offset = 4 /* 4 bytes TUN_PI */
|
||||
defaultMTU = 1500
|
||||
)
|
|
@ -5,6 +5,11 @@ import (
|
|||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
const (
|
||||
offset = 0
|
||||
defaultMTU = 0 /* auto */
|
||||
)
|
||||
|
||||
func init() {
|
||||
guid, _ := windows.GUIDFromString("{330EAEF8-7578-5DF2-D97B-8DADC0EA85CB}")
|
||||
|
|
@ -37,7 +37,7 @@ const (
|
|||
|
||||
// tcpModerateReceiveBufferEnabled is the value used by stack to
|
||||
// enable or disable tcp receive buffer auto-tuning option.
|
||||
tcpModerateReceiveBufferEnabled = true
|
||||
tcpModerateReceiveBufferEnabled = false
|
||||
|
||||
// tcpSACKEnabled is the value used by stack to enable or disable
|
||||
// tcp selective ACK.
|
||||
|
@ -47,14 +47,18 @@ const (
|
|||
tcpRecovery = tcpip.TCPRACKLossDetection
|
||||
|
||||
// tcpMinBufferSize is the smallest size of a send/recv buffer.
|
||||
tcpMinBufferSize = tcp.MinBufferSize // 4 KiB
|
||||
tcpMinBufferSize = tcp.MinBufferSize
|
||||
|
||||
// tcpMaxBufferSize is the maximum permitted size of a send/recv buffer.
|
||||
tcpMaxBufferSize = tcp.MaxBufferSize // 4 MiB
|
||||
tcpMaxBufferSize = tcp.MaxBufferSize
|
||||
|
||||
// tcpDefaultBufferSize is the default size of the send/recv buffer for
|
||||
// tcpDefaultBufferSize is the default size of the send buffer for
|
||||
// a transport endpoint.
|
||||
tcpDefaultBufferSize = 212 << 10 // 212 KiB
|
||||
tcpDefaultSendBufferSize = tcp.DefaultSendBufferSize
|
||||
|
||||
// tcpDefaultReceiveBufferSize is the default size of the receive buffer
|
||||
// for a transport endpoint.
|
||||
tcpDefaultReceiveBufferSize = tcp.DefaultReceiveBufferSize
|
||||
)
|
||||
|
||||
type Option func(*stack.Stack) error
|
||||
|
@ -74,7 +78,8 @@ func WithDefault() Option {
|
|||
// in too large buffers.
|
||||
//
|
||||
// Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go
|
||||
WithTCPBufferSizeRange(tcpMinBufferSize, tcpDefaultBufferSize, tcpMaxBufferSize),
|
||||
WithTCPSendBufferSizeRange(tcpMinBufferSize, tcpDefaultSendBufferSize, tcpMaxBufferSize),
|
||||
WithTCPReceiveBufferSizeRange(tcpMinBufferSize, tcpDefaultReceiveBufferSize, tcpMaxBufferSize),
|
||||
|
||||
WithTCPCongestionControl(tcpCongestionControlAlgorithm),
|
||||
WithTCPDelay(tcpDelayEnabled),
|
||||
|
@ -154,17 +159,46 @@ func WithICMPLimit(limit rate.Limit) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithTCPBufferSizeRange sets the receive and send buffer size range for TCP.
|
||||
func WithTCPBufferSizeRange(a, b, c int) Option {
|
||||
// WithTCPSendBufferSize sets default the send buffer size for TCP.
|
||||
func WithTCPSendBufferSize(size int) Option {
|
||||
return func(s *stack.Stack) error {
|
||||
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: tcpMinBufferSize, Default: size, Max: tcpMaxBufferSize}
|
||||
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
|
||||
return fmt.Errorf("set TCP send buffer size range: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTCPSendBufferSizeRange sets the send buffer size range for TCP.
|
||||
func WithTCPSendBufferSizeRange(a, b, c int) Option {
|
||||
return func(s *stack.Stack) error {
|
||||
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: a, Default: b, Max: c}
|
||||
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
|
||||
return fmt.Errorf("set TCP send buffer size range: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTCPReceiveBufferSize sets the default receive buffer size for TCP.
|
||||
func WithTCPReceiveBufferSize(size int) Option {
|
||||
return func(s *stack.Stack) error {
|
||||
rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: tcpMinBufferSize, Default: size, Max: tcpMaxBufferSize}
|
||||
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil {
|
||||
return fmt.Errorf("set TCP receive buffer size range: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTCPReceiveBufferSizeRange sets the receive buffer size range for TCP.
|
||||
func WithTCPReceiveBufferSizeRange(a, b, c int) Option {
|
||||
return func(s *stack.Stack) error {
|
||||
rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c}
|
||||
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil {
|
||||
return fmt.Errorf("set TCP receive buffer size range: %s", err)
|
||||
}
|
||||
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: a, Default: b, Max: c}
|
||||
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
|
||||
return fmt.Errorf("set TCP send buffer size range: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,7 +25,6 @@ func (s *gvStack) Close() error {
|
|||
var err error
|
||||
if s.device != nil {
|
||||
err = s.device.Close()
|
||||
s.device.Wait()
|
||||
}
|
||||
if s.Stack != nil {
|
||||
s.Stack.Close()
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
package gvisor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter"
|
||||
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
@ -43,8 +44,21 @@ const (
|
|||
func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
|
||||
return func(s *stack.Stack) error {
|
||||
tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
|
||||
var wq waiter.Queue
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
var (
|
||||
wq waiter.Queue
|
||||
ep tcpip.Endpoint
|
||||
err tcpip.Error
|
||||
id = r.ID()
|
||||
)
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Warnln("[STACK] forward tcp request %s:%d->%s:%d: %s", id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Perform a TCP three-way handshake.
|
||||
ep, err = r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
// RST: prevent potential half-open TCP connection leak.
|
||||
r.Complete(true)
|
||||
|
@ -52,11 +66,11 @@ func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
|
|||
}
|
||||
defer r.Complete(false)
|
||||
|
||||
setKeepalive(ep)
|
||||
err = setSocketOptions(s, ep)
|
||||
|
||||
conn := &tcpConn{
|
||||
TCPConn: gonet.NewTCPConn(&wq, ep),
|
||||
id: r.ID(),
|
||||
id: id,
|
||||
}
|
||||
handle(conn)
|
||||
})
|
||||
|
@ -65,21 +79,34 @@ func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
|
|||
}
|
||||
}
|
||||
|
||||
func setKeepalive(ep tcpip.Endpoint) error {
|
||||
func setSocketOptions(s *stack.Stack, ep tcpip.Endpoint) tcpip.Error {
|
||||
{ /* TCP keepalive options */
|
||||
ep.SocketOptions().SetKeepAlive(true)
|
||||
|
||||
idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle)
|
||||
if err := ep.SetSockOpt(&idle); err != nil {
|
||||
return fmt.Errorf("set keepalive idle: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval)
|
||||
if err := ep.SetSockOpt(&interval); err != nil {
|
||||
return fmt.Errorf("set keepalive interval: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount); err != nil {
|
||||
return fmt.Errorf("set keepalive count: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
{ /* TCP recv/send buffer size */
|
||||
var ss tcpip.TCPSendBufferSizeRangeOption
|
||||
if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err == nil {
|
||||
ep.SocketOptions().SetReceiveBufferSize(int64(ss.Default), false)
|
||||
}
|
||||
|
||||
var rs tcpip.TCPReceiveBufferSizeRangeOption
|
||||
if err := s.TransportProtocolOption(header.TCPProtocolNumber, &rs); err == nil {
|
||||
ep.SocketOptions().SetReceiveBufferSize(int64(rs.Default), false)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"github.com/Dreamacro/clash/common/pool"
|
||||
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter"
|
||||
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
@ -16,16 +17,19 @@ import (
|
|||
func withUDPHandler(handle adapter.UDPHandleFunc) option.Option {
|
||||
return func(s *stack.Stack) error {
|
||||
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
|
||||
var wq waiter.Queue
|
||||
var (
|
||||
wq waiter.Queue
|
||||
id = r.ID()
|
||||
)
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
// TODO: handler errors in the future.
|
||||
log.Warnln("[STACK] udp forwarder request %s:%d->%s:%d: %s", id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
|
||||
return
|
||||
}
|
||||
|
||||
conn := &udpConn{
|
||||
UDPConn: gonet.NewUDPConn(s, &wq, ep),
|
||||
id: r.ID(),
|
||||
id: id,
|
||||
}
|
||||
handle(conn)
|
||||
})
|
||||
|
@ -54,7 +58,7 @@ func (c *packet) Data() []byte {
|
|||
}
|
||||
|
||||
// WriteBack write UDP packet with source(ip, port) = `addr`
|
||||
func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) {
|
||||
func (c *packet) WriteBack(b []byte, _ net.Addr) (n int, err error) {
|
||||
return c.pc.WriteTo(b, c.rAddr)
|
||||
}
|
||||
|
||||
|
@ -64,5 +68,5 @@ func (c *packet) LocalAddr() net.Addr {
|
|||
}
|
||||
|
||||
func (c *packet) Drop() {
|
||||
pool.Put(c.payload)
|
||||
_ = pool.Put(c.payload)
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/Dreamacro/clash/listener/tun/ipstack"
|
||||
D "github.com/Dreamacro/clash/listener/tun/ipstack/commons"
|
||||
"github.com/Dreamacro/clash/listener/tun/ipstack/system/mars"
|
||||
"github.com/Dreamacro/clash/listener/tun/ipstack/system/mars/nat"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
"github.com/Dreamacro/clash/transport/socks5"
|
||||
)
|
||||
|
@ -24,15 +25,21 @@ import (
|
|||
type sysStack struct {
|
||||
stack io.Closer
|
||||
device device.Device
|
||||
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (s sysStack) Close() error {
|
||||
if s.stack != nil {
|
||||
_ = s.stack.Close()
|
||||
}
|
||||
func (s *sysStack) Close() error {
|
||||
defer func() {
|
||||
if s.device != nil {
|
||||
_ = s.device.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
s.closed = true
|
||||
if s.stack != nil {
|
||||
return s.stack.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -49,17 +56,25 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
|
|||
return nil, err
|
||||
}
|
||||
|
||||
ipStack := &sysStack{stack: stack, device: device}
|
||||
|
||||
dnsAddr := dnsHijack
|
||||
|
||||
tcp := func() {
|
||||
defer stack.TCP().Close()
|
||||
defer func(tcp *nat.TCP) {
|
||||
_ = tcp.Close()
|
||||
}(stack.TCP())
|
||||
|
||||
defer log.Debugln("TCP: closed")
|
||||
|
||||
for stack.TCP().SetDeadline(time.Time{}) == nil {
|
||||
for !ipStack.closed {
|
||||
if err = stack.TCP().SetDeadline(time.Time{}); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
conn, err := stack.TCP().Accept()
|
||||
if err != nil {
|
||||
log.Debugln("Accept connection: %v", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -73,13 +88,19 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
|
|||
go func() {
|
||||
log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String())
|
||||
|
||||
defer conn.Close()
|
||||
defer func(conn net.Conn) {
|
||||
_ = conn.Close()
|
||||
}(conn)
|
||||
|
||||
buf := pool.Get(pool.UDPBufferSize)
|
||||
defer pool.Put(buf)
|
||||
defer func(buf []byte) {
|
||||
_ = pool.Put(buf)
|
||||
}(buf)
|
||||
|
||||
for {
|
||||
conn.SetReadDeadline(time.Now().Add(C.DefaultTCPTimeout))
|
||||
if err = conn.SetReadDeadline(time.Now().Add(C.DefaultTCPTimeout)); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
length := uint16(0)
|
||||
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
|
||||
|
@ -123,10 +144,13 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
|
|||
}
|
||||
|
||||
udp := func() {
|
||||
defer stack.UDP().Close()
|
||||
defer func(udp *nat.UDP) {
|
||||
_ = udp.Close()
|
||||
}(stack.UDP())
|
||||
|
||||
defer log.Debugln("UDP: closed")
|
||||
|
||||
for {
|
||||
for !ipStack.closed {
|
||||
buf := pool.Get(pool.UDPBufferSize)
|
||||
|
||||
n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf)
|
||||
|
@ -143,15 +167,16 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
|
|||
|
||||
if D.ShouldHijackDns(dnsAddr, rAddrPort) {
|
||||
go func() {
|
||||
defer pool.Put(buf)
|
||||
|
||||
msg, err := D.RelayDnsPacket(raw)
|
||||
if err != nil {
|
||||
_ = pool.Put(buf)
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = stack.UDP().WriteTo(msg, rAddr, lAddr)
|
||||
|
||||
_ = pool.Put(buf)
|
||||
|
||||
log.Debugln("[TUN] hijack dns udp: %s", rAddrPort.String())
|
||||
}()
|
||||
|
||||
|
@ -165,7 +190,7 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
|
|||
return stack.UDP().WriteTo(b, rAddr, lAddr)
|
||||
},
|
||||
drop: func() {
|
||||
pool.Put(buf)
|
||||
_ = pool.Put(buf)
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -186,5 +211,5 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
|
|||
go udp()
|
||||
}
|
||||
|
||||
return &sysStack{stack: stack, device: device}, nil
|
||||
return ipStack, nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue