fix: tuic set MaxOpenStreams

This commit is contained in:
wwqgtxx 2022-11-25 18:32:30 +08:00
parent 76d2838721
commit 21a91e88a1
2 changed files with 47 additions and 8 deletions

View file

@ -6,6 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/hex" "encoding/hex"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -25,6 +26,7 @@ import (
type Tuic struct { type Tuic struct {
*Base *Base
getClient func(udp bool, opts ...dialer.Option) *tuic.Client getClient func(udp bool, opts ...dialer.Option) *tuic.Client
removeClient func(udp bool, opts ...dialer.Option)
} }
type TuicOption struct { type TuicOption struct {
@ -67,6 +69,10 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di
return pc, addr, err return pc, addr, err
}) })
if err != nil { if err != nil {
if errors.Is(err, tuic.TooManyOpenStreams) {
t.removeClient(false, opts...)
return t.DialContext(ctx, metadata, opts...)
}
return nil, err return nil, err
} }
return NewConn(conn, t), err return NewConn(conn, t), err
@ -87,6 +93,10 @@ func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, op
return pc, addr, err return pc, addr, err
}) })
if err != nil { if err != nil {
if errors.Is(err, tuic.TooManyOpenStreams) {
t.removeClient(true, opts...)
return t.ListenPacketContext(ctx, metadata, opts...)
}
return nil, err return nil, err
} }
return newPacketConn(pc, t), nil return newPacketConn(pc, t), nil
@ -232,6 +242,20 @@ func NewTuic(option TuicOption) (*Tuic, error) {
runtime.SetFinalizer(client, closeTuicClient) runtime.SetFinalizer(client, closeTuicClient)
return client return client
} }
removeClient := func(udp bool, opts ...dialer.Option) {
clientMap := tcpClientMap
clientMapMutex := tcpClientMapMutex
if udp {
clientMap = udpClientMap
clientMapMutex = udpClientMapMutex
}
o := *dialer.ApplyOptions(opts...)
clientMapMutex.Lock()
defer clientMapMutex.Unlock()
delete(clientMap, o)
}
return &Tuic{ return &Tuic{
Base: &Base{ Base: &Base{
@ -243,9 +267,10 @@ func NewTuic(option TuicOption) (*Tuic, error) {
prefer: C.NewDNSPrefer(option.IPVersion), prefer: C.NewDNSPrefer(option.IPVersion),
}, },
getClient: getClient, getClient: getClient,
removeClient: removeClient,
}, nil }, nil
} }
func closeTuicClient(client *tuic.Client) { func closeTuicClient(client *tuic.Client) {
client.Close(nil) client.Close(tuic.ClientClosed)
} }

View file

@ -11,6 +11,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/metacubex/quic-go" "github.com/metacubex/quic-go"
@ -20,6 +21,13 @@ import (
"github.com/Dreamacro/clash/transport/tuic/congestion" "github.com/Dreamacro/clash/transport/tuic/congestion"
) )
var (
ClientClosed = errors.New("tuic: client closed")
TooManyOpenStreams = errors.New("tuic: too many open streams")
)
const MaxOpenStreams = 100 - 10
type Client struct { type Client struct {
TlsConfig *tls.Config TlsConfig *tls.Config
QuicConfig *quic.Config QuicConfig *quic.Config
@ -37,6 +45,8 @@ type Client struct {
quicConn quic.Connection quicConn quic.Connection
connMutex sync.Mutex connMutex sync.Mutex
openStreams atomic.Int32
udpInputMap sync.Map udpInputMap sync.Map
} }
@ -137,6 +147,7 @@ func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Contex
} }
} }
} }
stream.CancelRead(0)
}() }()
reader := bufio.NewReader(stream) reader := bufio.NewReader(stream)
packet, err := ReadPacket(reader) packet, err := ReadPacket(reader)
@ -231,6 +242,11 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
if err != nil { if err != nil {
return nil, err return nil, err
} }
openStreams := t.openStreams.Add(1)
if openStreams >= MaxOpenStreams {
t.openStreams.Add(-1)
return nil, TooManyOpenStreams
}
stream, err := func() (stream *quicStreamConn, err error) { stream, err := func() (stream *quicStreamConn, err error) {
defer func() { defer func() {
t.deferQuicConn(quicConn, err) t.deferQuicConn(quicConn, err)
@ -281,7 +297,9 @@ type quicStreamConn struct {
} }
func (q *quicStreamConn) Close() error { func (q *quicStreamConn) Close() error {
//defer q.client.openStreams.Add(-1)
q.Stream.CancelRead(0) q.Stream.CancelRead(0)
q.Stream.CancelWrite(0)
return q.Stream.Close() return q.Stream.Close()
} }
@ -419,12 +437,8 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro
if err != nil { if err != nil {
return return
} }
defer stream.Close()
_, err = buf.WriteTo(stream) _, err = buf.WriteTo(stream)
if err != nil {
_ = stream.Close()
return
}
err = stream.Close()
if err != nil { if err != nil {
return return
} }