diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index 57b761c3..a7ba1f5e 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "encoding/hex" "encoding/pem" + "errors" "fmt" "net" "os" @@ -24,7 +25,8 @@ import ( type Tuic struct { *Base - getClient func(udp bool, opts ...dialer.Option) *tuic.Client + getClient func(udp bool, opts ...dialer.Option) *tuic.Client + removeClient func(udp bool, opts ...dialer.Option) } type TuicOption struct { @@ -67,6 +69,10 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di return pc, addr, err }) if err != nil { + if errors.Is(err, tuic.TooManyOpenStreams) { + t.removeClient(false, opts...) + return t.DialContext(ctx, metadata, opts...) + } return nil, err } return NewConn(conn, t), err @@ -87,6 +93,10 @@ func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, op return pc, addr, err }) if err != nil { + if errors.Is(err, tuic.TooManyOpenStreams) { + t.removeClient(true, opts...) + return t.ListenPacketContext(ctx, metadata, opts...) + } return nil, err } return newPacketConn(pc, t), nil @@ -232,6 +242,20 @@ func NewTuic(option TuicOption) (*Tuic, error) { runtime.SetFinalizer(client, closeTuicClient) return client } + removeClient := func(udp bool, opts ...dialer.Option) { + clientMap := tcpClientMap + clientMapMutex := tcpClientMapMutex + if udp { + clientMap = udpClientMap + clientMapMutex = udpClientMapMutex + } + + o := *dialer.ApplyOptions(opts...) + + clientMapMutex.Lock() + defer clientMapMutex.Unlock() + delete(clientMap, o) + } return &Tuic{ Base: &Base{ @@ -242,10 +266,11 @@ func NewTuic(option TuicOption) (*Tuic, error) { iface: option.Interface, prefer: C.NewDNSPrefer(option.IPVersion), }, - getClient: getClient, + getClient: getClient, + removeClient: removeClient, }, nil } func closeTuicClient(client *tuic.Client) { - client.Close(nil) + client.Close(tuic.ClientClosed) } diff --git a/transport/tuic/client.go b/transport/tuic/client.go index 4293c9c4..a28e6ec2 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -11,6 +11,7 @@ import ( "net" "net/netip" "sync" + "sync/atomic" "time" "github.com/metacubex/quic-go" @@ -20,6 +21,13 @@ import ( "github.com/Dreamacro/clash/transport/tuic/congestion" ) +var ( + ClientClosed = errors.New("tuic: client closed") + TooManyOpenStreams = errors.New("tuic: too many open streams") +) + +const MaxOpenStreams = 100 - 10 + type Client struct { TlsConfig *tls.Config QuicConfig *quic.Config @@ -37,6 +45,8 @@ type Client struct { quicConn quic.Connection connMutex sync.Mutex + openStreams atomic.Int32 + udpInputMap sync.Map } @@ -137,6 +147,7 @@ func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Contex } } } + stream.CancelRead(0) }() reader := bufio.NewReader(stream) packet, err := ReadPacket(reader) @@ -231,6 +242,11 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f if err != nil { return nil, err } + openStreams := t.openStreams.Add(1) + if openStreams >= MaxOpenStreams { + t.openStreams.Add(-1) + return nil, TooManyOpenStreams + } stream, err := func() (stream *quicStreamConn, err error) { defer func() { t.deferQuicConn(quicConn, err) @@ -281,7 +297,9 @@ type quicStreamConn struct { } func (q *quicStreamConn) Close() error { + //defer q.client.openStreams.Add(-1) q.Stream.CancelRead(0) + q.Stream.CancelWrite(0) return q.Stream.Close() } @@ -419,12 +437,8 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro if err != nil { return } + defer stream.Close() _, err = buf.WriteTo(stream) - if err != nil { - _ = stream.Close() - return - } - err = stream.Close() if err != nil { return }