chore: update proxy's udpConn when received a new packet

This commit is contained in:
gVisor bot 2023-06-03 21:40:09 +08:00
parent 5f25633a0c
commit 35ead7d20b
11 changed files with 80 additions and 28 deletions

26
component/nat/proxy.go Normal file
View file

@ -0,0 +1,26 @@
package nat
import (
"net"
"github.com/Dreamacro/clash/common/atomic"
C "github.com/Dreamacro/clash/constant"
)
type writeBackProxy struct {
wb atomic.TypedValue[C.WriteBack]
}
func (w *writeBackProxy) WriteBack(b []byte, addr net.Addr) (n int, err error) {
return w.wb.Load().WriteBack(b, addr)
}
func (w *writeBackProxy) UpdateWriteBack(wb C.WriteBack) {
w.wb.Store(wb)
}
func NewWriteBackProxy(wb C.WriteBack) C.WriteBackProxy {
w := &writeBackProxy{}
w.UpdateWriteBack(wb)
return w
}

View file

@ -13,22 +13,24 @@ type Table struct {
type Entry struct {
PacketConn C.PacketConn
WriteBackProxy C.WriteBackProxy
LocalUDPConnMap sync.Map
}
func (t *Table) Set(key string, e C.PacketConn) {
func (t *Table) Set(key string, e C.PacketConn, w C.WriteBackProxy) {
t.mapping.Store(key, &Entry{
PacketConn: e,
WriteBackProxy: w,
LocalUDPConnMap: sync.Map{},
})
}
func (t *Table) Get(key string) C.PacketConn {
func (t *Table) Get(key string) (C.PacketConn, C.WriteBackProxy) {
entry, exist := t.getEntry(key)
if !exist {
return nil
return nil, nil
}
return entry.PacketConn
return entry.PacketConn, entry.WriteBackProxy
}
func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {

View file

@ -217,7 +217,7 @@ type UDPPacket interface {
// - variable source IP/Port is important to STUN
// - if addr is not provided, WriteBack will write out UDP packet with SourceIP/Port equals to original Target,
// this is important when using Fake-IP.
WriteBack(b []byte, addr net.Addr) (n int, err error)
WriteBack
// Drop call after packet is used, could recycle buffer in this function.
Drop()
@ -236,10 +236,19 @@ type PacketAdapter interface {
Metadata() *Metadata
}
type NatTable interface {
Set(key string, e PacketConn)
type WriteBack interface {
WriteBack(b []byte, addr net.Addr) (n int, err error)
}
Get(key string) PacketConn
type WriteBackProxy interface {
WriteBack
UpdateWriteBack(wb WriteBack)
}
type NatTable interface {
Set(key string, e PacketConn, w WriteBackProxy)
Get(key string) (PacketConn, WriteBackProxy)
GetOrCreateLock(key string) (*sync.Cond, bool)

View file

@ -58,7 +58,7 @@ func (l *UDPListener) LocalAddr() net.Addr {
return l.packetConn.LocalAddr()
}
func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr) {
func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr, additions ...inbound.Addition) {
tgtAddr := socks5.SplitAddr(buf)
if tgtAddr == nil {
// Unresolved UDP packet, return buffer to the pool
@ -77,7 +77,7 @@ func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, pu
put: put,
}
select {
case in <- inbound.NewPacket(target, packet, C.SHADOWSOCKS):
case in <- inbound.NewPacket(target, packet, C.SHADOWSOCKS, additions...):
default:
}
}

View file

@ -38,7 +38,9 @@ func (c *packet) LocalAddr() net.Addr {
func (c *packet) Drop() {
if c.put != nil {
c.put()
c.put = nil
}
c.payload = nil
}
func (c *packet) InAddr() net.Addr {

View file

@ -4,7 +4,7 @@ import (
"net"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/pool"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/sockopt"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
@ -53,36 +53,40 @@ func NewUDP(addr string, in chan<- C.PacketAdapter, additions ...inbound.Additio
packetConn: l,
addr: addr,
}
conn := N.NewEnhancePacketConn(l)
go func() {
for {
buf := pool.Get(pool.UDPBufferSize)
n, remoteAddr, err := l.ReadFrom(buf)
data, put, remoteAddr, err := conn.WaitReadFrom()
if err != nil {
pool.Put(buf)
if put != nil {
put()
}
if sl.closed {
break
}
continue
}
handleSocksUDP(l, in, buf[:n], remoteAddr, additions...)
handleSocksUDP(l, in, data, put, remoteAddr, additions...)
}
}()
return sl, nil
}
func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, addr net.Addr, additions ...inbound.Addition) {
func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr, additions ...inbound.Addition) {
target, payload, err := socks5.DecodeUDPPacket(buf)
if err != nil {
// Unresolved UDP packet, return buffer to the pool
pool.Put(buf)
if put != nil {
put()
}
return
}
packet := &packet{
pc: pc,
rAddr: addr,
payload: payload,
bufRef: buf,
put: put,
}
select {
case in <- inbound.NewPacket(target, packet, C.SOCKS5, additions...):

View file

@ -3,7 +3,6 @@ package socks
import (
"net"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/transport/socks5"
)
@ -11,7 +10,7 @@ type packet struct {
pc net.PacketConn
rAddr net.Addr
payload []byte
bufRef []byte
put func()
}
func (c *packet) Data() []byte {
@ -33,7 +32,11 @@ func (c *packet) LocalAddr() net.Addr {
}
func (c *packet) Drop() {
pool.Put(c.bufRef)
if c.put != nil {
c.put()
c.put = nil
}
c.payload = nil
}
func (c *packet) InAddr() net.Addr {

View file

@ -41,7 +41,8 @@ func (c *packet) LocalAddr() net.Addr {
}
func (c *packet) Drop() {
pool.Put(c.buf)
_ = pool.Put(c.buf)
c.buf = nil
}
func (c *packet) InAddr() net.Addr {

View file

@ -27,7 +27,8 @@ func (c *packet) LocalAddr() net.Addr {
}
func (c *packet) Drop() {
pool.Put(c.payload)
_ = pool.Put(c.payload)
c.payload = nil
}
func (c *packet) InAddr() net.Addr {

View file

@ -26,7 +26,7 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata
return nil
}
func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) {
func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) {
defer func() {
_ = pc.Close()
closeAllLocalCoon(key)
@ -59,7 +59,7 @@ func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oA
log.Warnln("server return a [%T](%s) which isn't a *net.UDPAddr, force replace to (%s), this may be caused by a wrongly implemented server", from, from, oAddrPort)
}
_, err = packet.WriteBack(data, fromUDPAddr)
_, err = writeBack.WriteBack(data, fromUDPAddr)
if put != nil {
put()
}

View file

@ -303,8 +303,11 @@ func handleUDPConn(packet C.PacketAdapter) {
key := packet.LocalAddr().String()
handle := func() bool {
pc := natTable.Get(key)
pc, proxy := natTable.Get(key)
if pc != nil {
if proxy != nil {
proxy.UpdateWriteBack(packet)
}
_ = handleUDPToRemote(packet, pc, metadata)
return true
}
@ -384,9 +387,10 @@ func handleUDPConn(packet C.PacketAdapter) {
}
oAddrPort := metadata.AddrPort()
natTable.Set(key, pc)
writeBackProxy := nat.NewWriteBackProxy(packet)
natTable.Set(key, pc, writeBackProxy)
go handleUDPToLocal(packet, pc, key, oAddrPort, fAddr)
go handleUDPToLocal(writeBackProxy, pc, key, oAddrPort, fAddr)
handle()
}()