feat: introduce a new robust approach to handle tproxy udp. (#389)

This commit is contained in:
Ovear 2023-02-17 16:31:15 +08:00 committed by GitHub
parent b2d1cea759
commit 8e4dfbd10d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 246 additions and 12 deletions

View file

@ -1,6 +1,7 @@
package nat package nat
import ( import (
"net"
"sync" "sync"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
@ -10,16 +11,24 @@ type Table struct {
mapping sync.Map mapping sync.Map
} }
func (t *Table) Set(key string, pc C.PacketConn) { type Entry struct {
t.mapping.Store(key, pc) PacketConn C.PacketConn
LocalUDPConnMap sync.Map
}
func (t *Table) Set(key string, e C.PacketConn) {
t.mapping.Store(key, &Entry{
PacketConn: e,
LocalUDPConnMap: sync.Map{},
})
} }
func (t *Table) Get(key string) C.PacketConn { func (t *Table) Get(key string) C.PacketConn {
item, exist := t.mapping.Load(key) entry, exist := t.getEntry(key)
if !exist { if !exist {
return nil return nil
} }
return item.(C.PacketConn) return entry.PacketConn
} }
func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) { func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {
@ -31,6 +40,62 @@ func (t *Table) Delete(key string) {
t.mapping.Delete(key) t.mapping.Delete(key)
} }
func (t *Table) GetLocalConn(lAddr, rAddr string) *net.UDPConn {
entry, exist := t.getEntry(lAddr)
if !exist {
return nil
}
item, exist := entry.LocalUDPConnMap.Load(rAddr)
if !exist {
return nil
}
return item.(*net.UDPConn)
}
func (t *Table) AddLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool {
entry, exist := t.getEntry(lAddr)
if !exist {
return false
}
entry.LocalUDPConnMap.Store(rAddr, conn)
return true
}
func (t *Table) RangeLocalConn(lAddr string, f func(key, value any) bool) {
entry, exist := t.getEntry(lAddr)
if !exist {
return
}
entry.LocalUDPConnMap.Range(f)
}
func (t *Table) GetOrCreateLockForLocalConn(lAddr, key string) (*sync.Cond, bool) {
entry, loaded := t.getEntry(lAddr)
if !loaded {
return nil, false
}
item, loaded := entry.LocalUDPConnMap.LoadOrStore(key, sync.NewCond(&sync.Mutex{}))
return item.(*sync.Cond), loaded
}
func (t *Table) DeleteLocalConnMap(lAddr, key string) {
entry, loaded := t.getEntry(lAddr)
if !loaded {
return
}
entry.LocalUDPConnMap.Delete(key)
}
func (t *Table) getEntry(key string) (*Entry, bool) {
item, ok := t.mapping.Load(key)
// This should not happen usually since this function called after PacketConn created
if !ok {
return nil, false
}
entry, ok := item.(*Entry)
return entry, ok
}
// New return *Cache // New return *Cache
func New() *Table { func New() *Table {
return &Table{} return &Table{}

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sync"
"time" "time"
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/dialer"
@ -216,6 +217,10 @@ type UDPPacket interface {
// LocalAddr returns the source IP/Port of packet // LocalAddr returns the source IP/Port of packet
LocalAddr() net.Addr LocalAddr() net.Addr
SetNatTable(natTable NatTable)
SetUdpInChan(in chan<- PacketAdapter)
} }
type UDPPacketInAddr interface { type UDPPacketInAddr interface {
@ -227,3 +232,23 @@ type PacketAdapter interface {
UDPPacket UDPPacket
Metadata() *Metadata Metadata() *Metadata
} }
type NatTable interface {
Set(key string, e PacketConn)
Get(key string) PacketConn
GetOrCreateLock(key string) (*sync.Cond, bool)
Delete(key string)
GetLocalConn(lAddr, rAddr string) *net.UDPConn
AddLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool
RangeLocalConn(lAddr string, f func(key, value any) bool)
GetOrCreateLockForLocalConn(lAddr, key string) (*sync.Cond, bool)
DeleteLocalConnMap(lAddr, key string)
}

View file

@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
) )
@ -44,6 +45,13 @@ func (c *packet) InAddr() net.Addr {
return c.pc.LocalAddr() return c.pc.LocalAddr()
} }
func (c *packet) SetNatTable(natTable C.NatTable) {
// no need
}
func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}
func ParseSSURL(s string) (addr, cipher, password string, err error) { func ParseSSURL(s string) (addr, cipher, password string, err error) {
u, err := url.Parse(s) u, err := url.Parse(s)
if err != nil { if err != nil {

View file

@ -166,3 +166,11 @@ func (c *packet) Drop() {
func (c *packet) InAddr() net.Addr { func (c *packet) InAddr() net.Addr {
return c.lAddr return c.lAddr
} }
func (c *packet) SetNatTable(natTable C.NatTable) {
// no need
}
func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}

View file

@ -4,6 +4,7 @@ import (
"net" "net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
) )
@ -39,3 +40,11 @@ func (c *packet) Drop() {
func (c *packet) InAddr() net.Addr { func (c *packet) InAddr() net.Addr {
return c.pc.LocalAddr() return c.pc.LocalAddr()
} }
func (c *packet) SetNatTable(natTable C.NatTable) {
// no need
}
func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}

View file

@ -1,16 +1,22 @@
package tproxy package tproxy
import ( import (
"errors"
"fmt"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
"net" "net"
"net/netip" "net/netip"
"github.com/Dreamacro/clash/common/pool"
) )
type packet struct { type packet struct {
pc net.PacketConn pc net.PacketConn
lAddr netip.AddrPort lAddr netip.AddrPort
buf []byte buf []byte
natTable C.NatTable
in chan<- C.PacketAdapter
} }
func (c *packet) Data() []byte { func (c *packet) Data() []byte {
@ -19,13 +25,12 @@ func (c *packet) Data() []byte {
// WriteBack opens a new socket binding `addr` to write UDP packet back // WriteBack opens a new socket binding `addr` to write UDP packet back
func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) { func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) {
tc, err := dialUDP("udp", addr.(*net.UDPAddr).AddrPort(), c.lAddr) tc, err := createOrGetLocalConn(addr, c.LocalAddr(), c.natTable, c.in)
if err != nil { if err != nil {
n = 0 n = 0
return return
} }
n, err = tc.Write(b) n, err = tc.Write(b)
tc.Close()
return return
} }
@ -41,3 +46,82 @@ func (c *packet) Drop() {
func (c *packet) InAddr() net.Addr { func (c *packet) InAddr() net.Addr {
return c.pc.LocalAddr() return c.pc.LocalAddr()
} }
func (c *packet) SetNatTable(natTable C.NatTable) {
c.natTable = natTable
}
func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
c.in = in
}
// this function listen at rAddr and write to lAddr
// for here, rAddr is the ip/port client want to access
// lAddr is the ip/port client opened
func createOrGetLocalConn(rAddr, lAddr net.Addr, natTable C.NatTable, in chan<- C.PacketAdapter) (*net.UDPConn, error) {
remote := rAddr.String()
local := lAddr.String()
localConn := natTable.GetLocalConn(local, remote)
// localConn not exist
if localConn == nil {
lockKey := remote + "-lock"
cond, loaded := natTable.GetOrCreateLockForLocalConn(local, lockKey)
if loaded {
cond.L.Lock()
cond.Wait()
// we should get localConn here
localConn = natTable.GetLocalConn(local, remote)
if localConn == nil {
return nil, fmt.Errorf("localConn is nil, nat entry not exist")
}
cond.L.Unlock()
} else {
if cond == nil {
return nil, fmt.Errorf("cond is nil, nat entry not exist")
}
defer func() {
natTable.DeleteLocalConnMap(local, lockKey)
cond.Broadcast()
}()
conn, err := listenLocalConn(rAddr, lAddr, in)
if err != nil {
log.Errorln("listenLocalConn failed with error: %s, packet loss", err.Error())
return nil, err
}
natTable.AddLocalConn(local, remote, conn)
localConn = conn
}
}
return localConn, nil
}
// this function listen at rAddr
// and send what received to program itself, then send to real remote
func listenLocalConn(rAddr, lAddr net.Addr, in chan<- C.PacketAdapter) (*net.UDPConn, error) {
additions := []inbound.Addition{
inbound.WithInName("DEFAULT-TPROXY"),
inbound.WithSpecialRules(""),
}
lc, err := dialUDP("udp", rAddr.(*net.UDPAddr).AddrPort(), lAddr.(*net.UDPAddr).AddrPort())
if err != nil {
return nil, err
}
go func() {
log.Debugln("TProxy listenLocalConn rAddr=%s lAddr=%s", rAddr.String(), lAddr.String())
for {
buf := pool.Get(pool.UDPBufferSize)
br, err := lc.Read(buf)
if err != nil {
pool.Put(buf)
if errors.Is(err, net.ErrClosed) {
log.Debugln("TProxy local conn listener exit.. rAddr=%s lAddr=%s", rAddr.String(), lAddr.String())
return
}
}
// since following localPackets are pass through this socket which listen rAddr
// I choose current listener as packet's packet conn
handlePacketConn(lc, in, buf[:br], lAddr.(*net.UDPAddr).AddrPort(), rAddr.(*net.UDPAddr).AddrPort(), additions...)
}
}()
return lc, nil
}

View file

@ -4,6 +4,7 @@ import (
"net" "net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
) )
type packet struct { type packet struct {
@ -33,3 +34,11 @@ func (c *packet) Drop() {
func (c *packet) InAddr() net.Addr { func (c *packet) InAddr() net.Addr {
return c.pc.LocalAddr() return c.pc.LocalAddr()
} }
func (c *packet) SetNatTable(natTable C.NatTable) {
// no need
}
func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}

View file

@ -316,5 +316,13 @@ func (s *serverUDPPacket) Drop() {
s.packet.DATA = nil s.packet.DATA = nil
} }
func (s *serverUDPPacket) SetNatTable(natTable C.NatTable) {
// no need
}
func (s *serverUDPPacket) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}
var _ C.UDPPacket = &serverUDPPacket{} var _ C.UDPPacket = &serverUDPPacket{}
var _ C.UDPPacketInAddr = &serverUDPPacket{} var _ C.UDPPacketInAddr = &serverUDPPacket{}

View file

@ -2,6 +2,7 @@ package tunnel
import ( import (
"errors" "errors"
"github.com/Dreamacro/clash/log"
"net" "net"
"net/netip" "net/netip"
"time" "time"
@ -32,6 +33,7 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr,
buf := pool.Get(pool.UDPBufferSize) buf := pool.Get(pool.UDPBufferSize)
defer func() { defer func() {
_ = pc.Close() _ = pc.Close()
closeAllLocalCoon(key)
natTable.Delete(key) natTable.Delete(key)
_ = pool.Put(buf) _ = pool.Put(buf)
}() }()
@ -60,6 +62,19 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr,
} }
} }
func closeAllLocalCoon(lAddr string) {
natTable.RangeLocalConn(lAddr, func(key, value any) bool {
conn, ok := value.(*net.UDPConn)
if !ok || conn == nil {
log.Debugln("Value %#v unknown value when closing TProxy local conn...", conn)
return true
}
conn.Close()
log.Debugln("Closing TProxy local conn... lAddr=%s rAddr=%s", lAddr, key)
return true
})
}
func handleSocket(ctx C.ConnContext, outbound net.Conn) { func handleSocket(ctx C.ConnContext, outbound net.Conn) {
N.Relay(ctx.Conn(), outbound) N.Relay(ctx.Conn(), outbound)
} }

View file

@ -337,9 +337,12 @@ func handleUDPConn(packet C.PacketAdapter) {
} }
oAddr := metadata.DstIP oAddr := metadata.DstIP
natTable.Set(key, pc)
packet.SetNatTable(natTable)
packet.SetUdpInChan(udpQueue)
go handleUDPToLocal(packet, pc, key, oAddr, fAddr) go handleUDPToLocal(packet, pc, key, oAddr, fAddr)
natTable.Set(key, pc)
handle() handle()
}() }()
} }