From 8e4dfbd10d22b1c6f67753c89e11c90dcaee984d Mon Sep 17 00:00:00 2001 From: Ovear Date: Fri, 17 Feb 2023 16:31:15 +0800 Subject: [PATCH] feat: introduce a new robust approach to handle tproxy udp. (#389) --- component/nat/table.go | 73 ++++++++++++++++++++++++-- constant/adapters.go | 25 +++++++++ listener/shadowsocks/utils.go | 8 +++ listener/sing/sing.go | 8 +++ listener/socks/utils.go | 9 ++++ listener/tproxy/packet.go | 98 ++++++++++++++++++++++++++++++++--- listener/tunnel/packet.go | 9 ++++ transport/tuic/server.go | 8 +++ tunnel/connection.go | 15 ++++++ tunnel/tunnel.go | 5 +- 10 files changed, 246 insertions(+), 12 deletions(-) diff --git a/component/nat/table.go b/component/nat/table.go index fbb16dec..5dcd91ed 100644 --- a/component/nat/table.go +++ b/component/nat/table.go @@ -1,6 +1,7 @@ package nat import ( + "net" "sync" C "github.com/Dreamacro/clash/constant" @@ -10,16 +11,24 @@ type Table struct { mapping sync.Map } -func (t *Table) Set(key string, pc C.PacketConn) { - t.mapping.Store(key, pc) +type Entry struct { + PacketConn C.PacketConn + LocalUDPConnMap sync.Map +} + +func (t *Table) Set(key string, e C.PacketConn) { + t.mapping.Store(key, &Entry{ + PacketConn: e, + LocalUDPConnMap: sync.Map{}, + }) } func (t *Table) Get(key string) C.PacketConn { - item, exist := t.mapping.Load(key) + entry, exist := t.getEntry(key) if !exist { return nil } - return item.(C.PacketConn) + return entry.PacketConn } func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) { @@ -31,6 +40,62 @@ func (t *Table) Delete(key string) { t.mapping.Delete(key) } +func (t *Table) GetLocalConn(lAddr, rAddr string) *net.UDPConn { + entry, exist := t.getEntry(lAddr) + if !exist { + return nil + } + item, exist := entry.LocalUDPConnMap.Load(rAddr) + if !exist { + return nil + } + return item.(*net.UDPConn) +} + +func (t *Table) AddLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool { + entry, exist := t.getEntry(lAddr) + if !exist { + return false + } + entry.LocalUDPConnMap.Store(rAddr, conn) + return true +} + +func (t *Table) RangeLocalConn(lAddr string, f func(key, value any) bool) { + entry, exist := t.getEntry(lAddr) + if !exist { + return + } + entry.LocalUDPConnMap.Range(f) +} + +func (t *Table) GetOrCreateLockForLocalConn(lAddr, key string) (*sync.Cond, bool) { + entry, loaded := t.getEntry(lAddr) + if !loaded { + return nil, false + } + item, loaded := entry.LocalUDPConnMap.LoadOrStore(key, sync.NewCond(&sync.Mutex{})) + return item.(*sync.Cond), loaded +} + +func (t *Table) DeleteLocalConnMap(lAddr, key string) { + entry, loaded := t.getEntry(lAddr) + if !loaded { + return + } + entry.LocalUDPConnMap.Delete(key) +} + +func (t *Table) getEntry(key string) (*Entry, bool) { + item, ok := t.mapping.Load(key) + // This should not happen usually since this function called after PacketConn created + if !ok { + return nil, false + } + entry, ok := item.(*Entry) + return entry, ok +} + // New return *Cache func New() *Table { return &Table{} diff --git a/constant/adapters.go b/constant/adapters.go index 4480a953..879ee6d7 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "sync" "time" "github.com/Dreamacro/clash/component/dialer" @@ -216,6 +217,10 @@ type UDPPacket interface { // LocalAddr returns the source IP/Port of packet LocalAddr() net.Addr + + SetNatTable(natTable NatTable) + + SetUdpInChan(in chan<- PacketAdapter) } type UDPPacketInAddr interface { @@ -227,3 +232,23 @@ type PacketAdapter interface { UDPPacket Metadata() *Metadata } + +type NatTable interface { + Set(key string, e PacketConn) + + Get(key string) PacketConn + + GetOrCreateLock(key string) (*sync.Cond, bool) + + Delete(key string) + + GetLocalConn(lAddr, rAddr string) *net.UDPConn + + AddLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool + + RangeLocalConn(lAddr string, f func(key, value any) bool) + + GetOrCreateLockForLocalConn(lAddr, key string) (*sync.Cond, bool) + + DeleteLocalConnMap(lAddr, key string) +} diff --git a/listener/shadowsocks/utils.go b/listener/shadowsocks/utils.go index 2e9fd003..eee5660a 100644 --- a/listener/shadowsocks/utils.go +++ b/listener/shadowsocks/utils.go @@ -7,6 +7,7 @@ import ( "net/url" "github.com/Dreamacro/clash/common/pool" + C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/socks5" ) @@ -44,6 +45,13 @@ func (c *packet) InAddr() net.Addr { return c.pc.LocalAddr() } +func (c *packet) SetNatTable(natTable C.NatTable) { + // no need +} + +func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) { + // no need +} func ParseSSURL(s string) (addr, cipher, password string, err error) { u, err := url.Parse(s) if err != nil { diff --git a/listener/sing/sing.go b/listener/sing/sing.go index 27a9d6ac..a3e15154 100644 --- a/listener/sing/sing.go +++ b/listener/sing/sing.go @@ -166,3 +166,11 @@ func (c *packet) Drop() { func (c *packet) InAddr() net.Addr { return c.lAddr } + +func (c *packet) SetNatTable(natTable C.NatTable) { + // no need +} + +func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) { + // no need +} diff --git a/listener/socks/utils.go b/listener/socks/utils.go index 4c53b9e5..29898fda 100644 --- a/listener/socks/utils.go +++ b/listener/socks/utils.go @@ -4,6 +4,7 @@ import ( "net" "github.com/Dreamacro/clash/common/pool" + C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/socks5" ) @@ -39,3 +40,11 @@ func (c *packet) Drop() { func (c *packet) InAddr() net.Addr { return c.pc.LocalAddr() } + +func (c *packet) SetNatTable(natTable C.NatTable) { + // no need +} + +func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) { + // no need +} diff --git a/listener/tproxy/packet.go b/listener/tproxy/packet.go index e86a11ca..d66fac51 100644 --- a/listener/tproxy/packet.go +++ b/listener/tproxy/packet.go @@ -1,16 +1,22 @@ package tproxy import ( + "errors" + "fmt" + "github.com/Dreamacro/clash/adapter/inbound" + "github.com/Dreamacro/clash/common/pool" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" "net" "net/netip" - - "github.com/Dreamacro/clash/common/pool" ) type packet struct { - pc net.PacketConn - lAddr netip.AddrPort - buf []byte + pc net.PacketConn + lAddr netip.AddrPort + buf []byte + natTable C.NatTable + in chan<- C.PacketAdapter } func (c *packet) Data() []byte { @@ -19,13 +25,12 @@ func (c *packet) Data() []byte { // WriteBack opens a new socket binding `addr` to write UDP packet back func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) { - tc, err := dialUDP("udp", addr.(*net.UDPAddr).AddrPort(), c.lAddr) + tc, err := createOrGetLocalConn(addr, c.LocalAddr(), c.natTable, c.in) if err != nil { n = 0 return } n, err = tc.Write(b) - tc.Close() return } @@ -41,3 +46,82 @@ func (c *packet) Drop() { func (c *packet) InAddr() net.Addr { return c.pc.LocalAddr() } + +func (c *packet) SetNatTable(natTable C.NatTable) { + c.natTable = natTable +} + +func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) { + c.in = in +} + +// this function listen at rAddr and write to lAddr +// for here, rAddr is the ip/port client want to access +// lAddr is the ip/port client opened +func createOrGetLocalConn(rAddr, lAddr net.Addr, natTable C.NatTable, in chan<- C.PacketAdapter) (*net.UDPConn, error) { + remote := rAddr.String() + local := lAddr.String() + localConn := natTable.GetLocalConn(local, remote) + // localConn not exist + if localConn == nil { + lockKey := remote + "-lock" + cond, loaded := natTable.GetOrCreateLockForLocalConn(local, lockKey) + if loaded { + cond.L.Lock() + cond.Wait() + // we should get localConn here + localConn = natTable.GetLocalConn(local, remote) + if localConn == nil { + return nil, fmt.Errorf("localConn is nil, nat entry not exist") + } + cond.L.Unlock() + } else { + if cond == nil { + return nil, fmt.Errorf("cond is nil, nat entry not exist") + } + defer func() { + natTable.DeleteLocalConnMap(local, lockKey) + cond.Broadcast() + }() + conn, err := listenLocalConn(rAddr, lAddr, in) + if err != nil { + log.Errorln("listenLocalConn failed with error: %s, packet loss", err.Error()) + return nil, err + } + natTable.AddLocalConn(local, remote, conn) + localConn = conn + } + } + return localConn, nil +} + +// this function listen at rAddr +// and send what received to program itself, then send to real remote +func listenLocalConn(rAddr, lAddr net.Addr, in chan<- C.PacketAdapter) (*net.UDPConn, error) { + additions := []inbound.Addition{ + inbound.WithInName("DEFAULT-TPROXY"), + inbound.WithSpecialRules(""), + } + lc, err := dialUDP("udp", rAddr.(*net.UDPAddr).AddrPort(), lAddr.(*net.UDPAddr).AddrPort()) + if err != nil { + return nil, err + } + go func() { + log.Debugln("TProxy listenLocalConn rAddr=%s lAddr=%s", rAddr.String(), lAddr.String()) + for { + buf := pool.Get(pool.UDPBufferSize) + br, err := lc.Read(buf) + if err != nil { + pool.Put(buf) + if errors.Is(err, net.ErrClosed) { + log.Debugln("TProxy local conn listener exit.. rAddr=%s lAddr=%s", rAddr.String(), lAddr.String()) + return + } + } + // since following localPackets are pass through this socket which listen rAddr + // I choose current listener as packet's packet conn + handlePacketConn(lc, in, buf[:br], lAddr.(*net.UDPAddr).AddrPort(), rAddr.(*net.UDPAddr).AddrPort(), additions...) + } + }() + return lc, nil +} diff --git a/listener/tunnel/packet.go b/listener/tunnel/packet.go index 602f7675..fa85879f 100644 --- a/listener/tunnel/packet.go +++ b/listener/tunnel/packet.go @@ -4,6 +4,7 @@ import ( "net" "github.com/Dreamacro/clash/common/pool" + C "github.com/Dreamacro/clash/constant" ) type packet struct { @@ -33,3 +34,11 @@ func (c *packet) Drop() { func (c *packet) InAddr() net.Addr { return c.pc.LocalAddr() } + +func (c *packet) SetNatTable(natTable C.NatTable) { + // no need +} + +func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) { + // no need +} diff --git a/transport/tuic/server.go b/transport/tuic/server.go index 2830b324..fdea899d 100644 --- a/transport/tuic/server.go +++ b/transport/tuic/server.go @@ -316,5 +316,13 @@ func (s *serverUDPPacket) Drop() { s.packet.DATA = nil } +func (s *serverUDPPacket) SetNatTable(natTable C.NatTable) { + // no need +} + +func (s *serverUDPPacket) SetUdpInChan(in chan<- C.PacketAdapter) { + // no need +} + var _ C.UDPPacket = &serverUDPPacket{} var _ C.UDPPacketInAddr = &serverUDPPacket{} diff --git a/tunnel/connection.go b/tunnel/connection.go index d8bd26c9..687b2887 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -2,6 +2,7 @@ package tunnel import ( "errors" + "github.com/Dreamacro/clash/log" "net" "net/netip" "time" @@ -32,6 +33,7 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr, buf := pool.Get(pool.UDPBufferSize) defer func() { _ = pc.Close() + closeAllLocalCoon(key) natTable.Delete(key) _ = pool.Put(buf) }() @@ -60,6 +62,19 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr, } } +func closeAllLocalCoon(lAddr string) { + natTable.RangeLocalConn(lAddr, func(key, value any) bool { + conn, ok := value.(*net.UDPConn) + if !ok || conn == nil { + log.Debugln("Value %#v unknown value when closing TProxy local conn...", conn) + return true + } + conn.Close() + log.Debugln("Closing TProxy local conn... lAddr=%s rAddr=%s", lAddr, key) + return true + }) +} + func handleSocket(ctx C.ConnContext, outbound net.Conn) { N.Relay(ctx.Conn(), outbound) } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 9c6c155f..5c3814bc 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -337,9 +337,12 @@ func handleUDPConn(packet C.PacketAdapter) { } oAddr := metadata.DstIP + natTable.Set(key, pc) + packet.SetNatTable(natTable) + packet.SetUdpInChan(udpQueue) + go handleUDPToLocal(packet, pc, key, oAddr, fAddr) - natTable.Set(key, pc) handle() }() }