From 69b223041c71da393586d0b761026d7338cb6163 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 25 Nov 2022 20:14:05 +0800 Subject: [PATCH] chore: tuic use a simple client pool --- adapter/outbound/tuic.go | 127 +++++++++++++++++++++++---------------- transport/tuic/client.go | 32 ++++++++-- 2 files changed, 102 insertions(+), 57 deletions(-) diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index a7ba1f5e..30ea0485 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -17,6 +17,7 @@ import ( "github.com/metacubex/quic-go" + "github.com/Dreamacro/clash/common/generics/list" "github.com/Dreamacro/clash/component/dialer" tlsC "github.com/Dreamacro/clash/component/tls" C "github.com/Dreamacro/clash/constant" @@ -25,8 +26,8 @@ import ( type Tuic struct { *Base - getClient func(udp bool, opts ...dialer.Option) *tuic.Client - removeClient func(udp bool, opts ...dialer.Option) + newClient func(udp bool, opts ...dialer.Option) *tuic.Client + getClient func(udp bool, opts ...dialer.Option) *tuic.Client } type TuicOption struct { @@ -57,7 +58,7 @@ type TuicOption struct { // DialContext implements C.ProxyAdapter func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { opts = t.Base.DialOptions(opts...) - conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { + dialFn := func(ctx context.Context) (net.PacketConn, net.Addr, error) { pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) if err != nil { return nil, nil, err @@ -67,12 +68,12 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di return nil, nil, err } return pc, addr, err - }) + } + conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, dialFn) + if errors.Is(err, tuic.TooManyOpenStreams) { + conn, err = t.newClient(false, opts...).DialContext(ctx, metadata, dialFn) + } if err != nil { - if errors.Is(err, tuic.TooManyOpenStreams) { - t.removeClient(false, opts...) - return t.DialContext(ctx, metadata, opts...) - } return nil, err } return NewConn(conn, t), err @@ -81,7 +82,7 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di // ListenPacketContext implements C.ProxyAdapter func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { opts = t.Base.DialOptions(opts...) - pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { + dialFn := func(ctx context.Context) (net.PacketConn, net.Addr, error) { pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) if err != nil { return nil, nil, err @@ -91,12 +92,12 @@ func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, op return nil, nil, err } return pc, addr, err - }) + } + pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, dialFn) + if errors.Is(err, tuic.TooManyOpenStreams) { + pc, err = t.newClient(false, opts...).ListenPacketContext(ctx, metadata, dialFn) + } if err != nil { - if errors.Is(err, tuic.TooManyOpenStreams) { - t.removeClient(true, opts...) - return t.ListenPacketContext(ctx, metadata, opts...) - } return nil, err } return newPacketConn(pc, t), nil @@ -194,37 +195,23 @@ func NewTuic(option TuicOption) (*Tuic, error) { tlsConfig.ServerName = "" } tkn := tuic.GenTKN(option.Token) - tcpClientMap := make(map[any]*tuic.Client) - tcpClientMapMutex := &sync.Mutex{} - udpClientMap := make(map[any]*tuic.Client) - udpClientMapMutex := &sync.Mutex{} - getClient := func(udp bool, opts ...dialer.Option) *tuic.Client { - clientMap := tcpClientMap - clientMapMutex := tcpClientMapMutex + tcpClients := list.New[*tuic.Client]() + tcpClientsMutex := &sync.Mutex{} + udpClients := list.New[*tuic.Client]() + udpClientsMutex := &sync.Mutex{} + newClient := func(udp bool, opts ...dialer.Option) *tuic.Client { + clients := tcpClients + clientsMutex := tcpClientsMutex if udp { - clientMap = udpClientMap - clientMapMutex = udpClientMapMutex + clients = udpClients + clientsMutex = udpClientsMutex } - o := *dialer.ApplyOptions(opts...) + var o any = *dialer.ApplyOptions(opts...) + + clientsMutex.Lock() + defer clientsMutex.Unlock() - clientMapMutex.Lock() - defer clientMapMutex.Unlock() - for key := range clientMap { - client := clientMap[key] - if client == nil { - delete(clientMap, key) // It is safe in Golang - continue - } - if key == o { - client.LastVisited = time.Now() - return client - } - if time.Now().Sub(client.LastVisited) > 30*time.Minute { - delete(clientMap, key) - continue - } - } client := &tuic.Client{ TlsConfig: tlsConfig, QuicConfig: quicConfig, @@ -235,26 +222,60 @@ func NewTuic(option TuicOption) (*Tuic, error) { ReduceRtt: option.ReduceRtt, RequestTimeout: option.RequestTimeout, MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, + Key: o, LastVisited: time.Now(), UDP: udp, } - clientMap[o] = client + clients.PushFront(client) runtime.SetFinalizer(client, closeTuicClient) return client } - removeClient := func(udp bool, opts ...dialer.Option) { - clientMap := tcpClientMap - clientMapMutex := tcpClientMapMutex + getClient := func(udp bool, opts ...dialer.Option) *tuic.Client { + clients := tcpClients + clientsMutex := tcpClientsMutex if udp { - clientMap = udpClientMap - clientMapMutex = udpClientMapMutex + clients = udpClients + clientsMutex = udpClientsMutex } - o := *dialer.ApplyOptions(opts...) + var o any = *dialer.ApplyOptions(opts...) + var bestClient *tuic.Client - clientMapMutex.Lock() - defer clientMapMutex.Unlock() - delete(clientMap, o) + func() { + clientsMutex.Lock() + defer clientsMutex.Unlock() + for it := clients.Front(); it != nil; { + client := it.Value + if client == nil { + next := it.Next() + clients.Remove(it) + it = next + continue + } + if client.Key == o { + if bestClient == nil { + bestClient = client + } else { + if client.OpenStreams.Load() < bestClient.OpenStreams.Load() { + bestClient = client + } + } + } + if time.Now().Sub(client.LastVisited) > 30*time.Minute { + next := it.Next() + clients.Remove(it) + it = next + continue + } + it = it.Next() + } + }() + + if bestClient == nil { + return newClient(udp, opts...) + } else { + return bestClient + } } return &Tuic{ @@ -266,8 +287,8 @@ func NewTuic(option TuicOption) (*Tuic, error) { iface: option.Interface, prefer: C.NewDNSPrefer(option.IPVersion), }, - getClient: getClient, - removeClient: removeClient, + newClient: newClient, + getClient: getClient, }, nil } diff --git a/transport/tuic/client.go b/transport/tuic/client.go index d5d2664d..332c664e 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -39,13 +39,14 @@ type Client struct { RequestTimeout int MaxUdpRelayPacketSize int + Key any LastVisited time.Time UDP bool quicConn quic.Connection connMutex sync.Mutex - openStreams atomic.Int32 + OpenStreams atomic.Int32 udpInputMap sync.Map } @@ -242,9 +243,9 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f if err != nil { return nil, err } - openStreams := t.openStreams.Add(1) + openStreams := t.OpenStreams.Add(1) if openStreams >= MaxOpenStreams { - t.openStreams.Add(-1) + t.OpenStreams.Add(-1) return nil, TooManyOpenStreams } stream, err := func() (stream *quicStreamConn, err error) { @@ -300,6 +301,9 @@ type quicStreamConn struct { lAddr net.Addr rAddr net.Addr client *Client + + closeOnce sync.Once + closeErr error } func (q *quicStreamConn) Write(p []byte) (n int, err error) { @@ -309,8 +313,15 @@ func (q *quicStreamConn) Write(p []byte) (n int, err error) { } func (q *quicStreamConn) Close() error { + q.closeOnce.Do(func() { + q.closeErr = q.close() + }) + return q.closeErr +} + +func (q *quicStreamConn) close() error { defer time.AfterFunc(C.DefaultTCPTimeout, func() { - q.client.openStreams.Add(-1) + q.client.OpenStreams.Add(-1) }) // https://github.com/cloudflare/cloudflared/commit/ed2bac026db46b239699ac5ce4fcf122d7cab2cd @@ -342,6 +353,11 @@ func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata, if err != nil { return nil, err } + openStreams := t.OpenStreams.Add(1) + if openStreams >= MaxOpenStreams { + t.OpenStreams.Add(-1) + return nil, TooManyOpenStreams + } pipe1, pipe2 := net.Pipe() inputCh := make(chan udpData) @@ -380,16 +396,21 @@ type quicStreamPacketConn struct { closeOnce sync.Once closeErr error + closed bool } func (q *quicStreamPacketConn) Close() error { q.closeOnce.Do(func() { + q.closed = true q.closeErr = q.close() }) return q.closeErr } func (q *quicStreamPacketConn) close() (err error) { + defer time.AfterFunc(C.DefaultTCPTimeout, func() { + q.client.OpenStreams.Add(-1) + }) defer func() { q.client.deferQuicConn(q.quicConn, err) }() @@ -441,6 +462,9 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro if len(p) > q.client.MaxUdpRelayPacketSize { return 0, fmt.Errorf("udp packet too large(%d > %d)", len(p), q.client.MaxUdpRelayPacketSize) } + if q.closed { + return 0, net.ErrClosed + } defer func() { q.client.deferQuicConn(q.quicConn, err) }()