From 191243a1d2a8252c76092e65864e338f5e93db62 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Thu, 3 Aug 2023 23:07:30 +0800 Subject: [PATCH] chore: better tuicV5 deFragger --- common/cache/lrucache.go | 42 +++++++++++++++--- transport/tuic/v4/server.go | 8 ++-- transport/tuic/v5/frag.go | 85 +++++++++++++++++++++++++------------ transport/tuic/v5/packet.go | 14 +++--- transport/tuic/v5/server.go | 6 +-- 5 files changed, 105 insertions(+), 50 deletions(-) diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go index 73600e71..1373b0be 100644 --- a/common/cache/lrucache.go +++ b/common/cache/lrucache.go @@ -82,6 +82,9 @@ func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] { // Get returns the any representation of a cached response and a bool // set to true if the key was found. func (c *LruCache[K, V]) Get(key K) (V, bool) { + c.mu.Lock() + defer c.mu.Unlock() + el := c.get(key) if el == nil { return getZero[V](), false @@ -91,11 +94,29 @@ func (c *LruCache[K, V]) Get(key K) (V, bool) { return value, true } +func (c *LruCache[K, V]) GetOrStore(key K, constructor func() V) (V, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + el := c.get(key) + if el == nil { + value := constructor() + c.set(key, value) + return value, false + } + value := el.value + + return value, true +} + // GetWithExpire returns the any representation of a cached response, // a time.Time Give expected expires, // and a bool set to true if the key was found. // This method will NOT check the maxAge of element and will NOT update the expires. func (c *LruCache[K, V]) GetWithExpire(key K) (V, time.Time, bool) { + c.mu.Lock() + defer c.mu.Unlock() + el := c.get(key) if el == nil { return getZero[V](), time.Time{}, false @@ -115,11 +136,18 @@ func (c *LruCache[K, V]) Exist(key K) bool { // Set stores the any representation of a response for a given key. func (c *LruCache[K, V]) Set(key K, value V) { + c.mu.Lock() + defer c.mu.Unlock() + + c.set(key, value) +} + +func (c *LruCache[K, V]) set(key K, value V) { expires := int64(0) if c.maxAge > 0 { expires = time.Now().Unix() + c.maxAge } - c.SetWithExpire(key, value, time.Unix(expires, 0)) + c.setWithExpire(key, value, time.Unix(expires, 0)) } // SetWithExpire stores the any representation of a response for a given key and given expires. @@ -128,6 +156,10 @@ func (c *LruCache[K, V]) SetWithExpire(key K, value V, expires time.Time) { c.mu.Lock() defer c.mu.Unlock() + c.setWithExpire(key, value, expires) +} + +func (c *LruCache[K, V]) setWithExpire(key K, value V, expires time.Time) { if le, ok := c.cache[key]; ok { c.lru.MoveToBack(le) e := le.Value @@ -165,9 +197,6 @@ func (c *LruCache[K, V]) CloneTo(n *LruCache[K, V]) { } func (c *LruCache[K, V]) get(key K) *entry[K, V] { - c.mu.Lock() - defer c.mu.Unlock() - le, ok := c.cache[key] if !ok { return nil @@ -191,12 +220,11 @@ func (c *LruCache[K, V]) get(key K) *entry[K, V] { // Delete removes the value associated with a key. func (c *LruCache[K, V]) Delete(key K) { c.mu.Lock() + defer c.mu.Unlock() if le, ok := c.cache[key]; ok { c.deleteElement(le) } - - c.mu.Unlock() } func (c *LruCache[K, V]) maybeDeleteOldest() { @@ -219,10 +247,10 @@ func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) { func (c *LruCache[K, V]) Clear() error { c.mu.Lock() + defer c.mu.Unlock() c.cache = make(map[K]*list.Element[*entry[K, V]]) - c.mu.Unlock() return nil } diff --git a/transport/tuic/v4/server.go b/transport/tuic/v4/server.go index 9513ccfd..b0012d96 100644 --- a/transport/tuic/v4/server.go +++ b/transport/tuic/v4/server.go @@ -66,10 +66,10 @@ func (s *serverHandler) HandleMessage(message []byte) (err error) { if err != nil { return } - return s.parsePacket(packet, common.NATIVE) + return s.parsePacket(&packet, common.NATIVE) } -func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) { +func (s *serverHandler) parsePacket(packet *Packet, udpRelayMode common.UdpRelayMode) (err error) { <-s.authCh if !s.authOk.Load() { return @@ -97,7 +97,7 @@ func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayM return s.HandleUdpFn(packet.ADDR.SocksAddr(), &serverUDPPacket{ pc: pc, - packet: &packet, + packet: packet, rAddr: N.NewCustomAddr("tuic", fmt.Sprintf("tuic-%s-%d", s.uuid, assocId), s.quicConn.RemoteAddr()), // for tunnel's handleUDPConn }) } @@ -166,7 +166,7 @@ func (s *serverHandler) HandleUniStream(reader *bufio.Reader) (err error) { if err != nil { return } - return s.parsePacket(packet, common.QUIC) + return s.parsePacket(&packet, common.QUIC) case DissociateType: var disassociate Dissociate disassociate, err = ReadDissociateWithHead(commandHead, reader) diff --git a/transport/tuic/v5/frag.go b/transport/tuic/v5/frag.go index 30b7b3f5..ae9dbf10 100644 --- a/transport/tuic/v5/frag.go +++ b/transport/tuic/v5/frag.go @@ -2,6 +2,9 @@ package v5 import ( "bytes" + "sync" + + "github.com/Dreamacro/clash/common/cache" "github.com/metacubex/quic-go" ) @@ -39,42 +42,68 @@ func fragWriteNative(quicConn quic.Connection, packet Packet, buf *bytes.Buffer, } type deFragger struct { - pkgID uint16 - frags []*Packet - count uint8 + lru *cache.LruCache[uint16, *packetBag] + once sync.Once } -func (d *deFragger) Feed(m Packet) *Packet { +type packetBag struct { + frags []*Packet + count uint8 + mutex sync.Mutex +} + +func newPacketBag() *packetBag { + return new(packetBag) +} + +func (d *deFragger) init() { + if d.lru == nil { + d.lru = cache.New( + cache.WithAge[uint16, *packetBag](10), + cache.WithUpdateAgeOnGet[uint16, *packetBag](), + ) + } +} + +func (d *deFragger) Feed(m *Packet) *Packet { if m.FRAG_TOTAL <= 1 { - return &m + return m } if m.FRAG_ID >= m.FRAG_TOTAL { // wtf is this? return nil } - if d.count == 0 || m.PKT_ID != d.pkgID { + d.once.Do(d.init) // lazy init + bag, _ := d.lru.GetOrStore(m.PKT_ID, newPacketBag) + bag.mutex.Lock() + defer bag.mutex.Unlock() + if int(m.FRAG_TOTAL) != len(bag.frags) { // new message, clear previous state - d.pkgID = m.PKT_ID - d.frags = make([]*Packet, m.FRAG_TOTAL) - d.count = 1 - d.frags[m.FRAG_ID] = &m - } else if d.frags[m.FRAG_ID] == nil { - d.frags[m.FRAG_ID] = &m - d.count++ - if int(d.count) == len(d.frags) { - // all fragments received, assemble - var data []byte - for _, frag := range d.frags { - data = append(data, frag.DATA...) - } - p := d.frags[0] // recover from first fragment - p.SIZE = uint16(len(data)) - p.DATA = data - p.FRAG_ID = 0 - p.FRAG_TOTAL = 1 - d.count = 0 - return p - } + bag.frags = make([]*Packet, m.FRAG_TOTAL) + bag.count = 1 + bag.frags[m.FRAG_ID] = m + return nil } - return nil + if bag.frags[m.FRAG_ID] != nil { + return nil + } + bag.frags[m.FRAG_ID] = m + bag.count++ + if int(bag.count) != len(bag.frags) { + return nil + } + + // all fragments received, assemble + var data []byte + for _, frag := range bag.frags { + data = append(data, frag.DATA...) + } + p := *bag.frags[0] // recover from first fragment + p.SIZE = uint16(len(data)) + p.DATA = data + p.FRAG_ID = 0 + p.FRAG_TOTAL = 1 + bag.frags = nil + d.lru.Delete(m.PKT_ID) + return &p } diff --git a/transport/tuic/v5/packet.go b/transport/tuic/v5/packet.go index 4a11d671..cd3ed12b 100644 --- a/transport/tuic/v5/packet.go +++ b/transport/tuic/v5/packet.go @@ -103,7 +103,7 @@ func (q *quicStreamPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err err if err != nil { return } - if packetPtr := q.deFragger.Feed(packet); packetPtr != nil { + if packetPtr := q.deFragger.Feed(&packet); packetPtr != nil { n = copy(p, packet.DATA) addr = packetPtr.ADDR.UDPAddr() return @@ -123,7 +123,7 @@ func (q *quicStreamPacketConn) WaitReadFrom() (data []byte, put func(), addr net if err != nil { return } - if packetPtr := q.deFragger.Feed(packet); packetPtr != nil { + if packetPtr := q.deFragger.Feed(&packet); packetPtr != nil { data = packetPtr.DATA addr = packetPtr.ADDR.UDPAddr() return @@ -178,16 +178,14 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro default: // native if len(p) > q.maxUdpRelayPacketSize { err = fragWriteNative(q.quicConn, packet, buf, q.maxUdpRelayPacketSize) + } else { + err = packet.WriteTo(buf) if err != nil { return } + data := buf.Bytes() + err = q.quicConn.SendMessage(data) } - err = packet.WriteTo(buf) - if err != nil { - return - } - data := buf.Bytes() - err = q.quicConn.SendMessage(data) var tooLarge quic.ErrMessageTooLarge if errors.As(err, &tooLarge) { diff --git a/transport/tuic/v5/server.go b/transport/tuic/v5/server.go index 96b3d24f..30259583 100644 --- a/transport/tuic/v5/server.go +++ b/transport/tuic/v5/server.go @@ -73,7 +73,7 @@ func (s *serverHandler) HandleMessage(message []byte) (err error) { if err != nil { return } - return s.parsePacket(packet, common.NATIVE) + return s.parsePacket(&packet, common.NATIVE) case HeartbeatType: var heartbeat Heartbeat heartbeat, err = ReadHeartbeatWithHead(commandHead, reader) @@ -85,7 +85,7 @@ func (s *serverHandler) HandleMessage(message []byte) (err error) { return } -func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) { +func (s *serverHandler) parsePacket(packet *Packet, udpRelayMode common.UdpRelayMode) (err error) { <-s.authCh if !s.authOk.Load() { return @@ -179,7 +179,7 @@ func (s *serverHandler) HandleUniStream(reader *bufio.Reader) (err error) { if err != nil { return } - return s.parsePacket(packet, common.QUIC) + return s.parsePacket(&packet, common.QUIC) case DissociateType: var disassociate Dissociate disassociate, err = ReadDissociateWithHead(commandHead, reader)