chore: tuic use a simple client pool

This commit is contained in:
wwqgtxx 2022-11-25 20:14:05 +08:00
parent c7bad89af3
commit 7b44cde4bd
2 changed files with 102 additions and 57 deletions

View file

@ -17,6 +17,7 @@ import (
"github.com/metacubex/quic-go" "github.com/metacubex/quic-go"
"github.com/Dreamacro/clash/common/generics/list"
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/dialer"
tlsC "github.com/Dreamacro/clash/component/tls" tlsC "github.com/Dreamacro/clash/component/tls"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
@ -25,8 +26,8 @@ import (
type Tuic struct { type Tuic struct {
*Base *Base
newClient func(udp bool, opts ...dialer.Option) *tuic.Client
getClient func(udp bool, opts ...dialer.Option) *tuic.Client getClient func(udp bool, opts ...dialer.Option) *tuic.Client
removeClient func(udp bool, opts ...dialer.Option)
} }
type TuicOption struct { type TuicOption struct {
@ -57,7 +58,7 @@ type TuicOption struct {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
opts = t.Base.DialOptions(opts...) opts = t.Base.DialOptions(opts...)
conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { dialFn := func(ctx context.Context) (net.PacketConn, net.Addr, error) {
pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) pc, err := dialer.ListenPacket(ctx, "udp", "", opts...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -67,12 +68,12 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di
return nil, nil, err return nil, nil, err
} }
return pc, addr, err return pc, addr, err
})
if err != nil {
if errors.Is(err, tuic.TooManyOpenStreams) {
t.removeClient(false, opts...)
return t.DialContext(ctx, metadata, opts...)
} }
conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, dialFn)
if errors.Is(err, tuic.TooManyOpenStreams) {
conn, err = t.newClient(false, opts...).DialContext(ctx, metadata, dialFn)
}
if err != nil {
return nil, err return nil, err
} }
return NewConn(conn, t), err return NewConn(conn, t), err
@ -81,7 +82,7 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
opts = t.Base.DialOptions(opts...) opts = t.Base.DialOptions(opts...)
pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { dialFn := func(ctx context.Context) (net.PacketConn, net.Addr, error) {
pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) pc, err := dialer.ListenPacket(ctx, "udp", "", opts...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -91,12 +92,12 @@ func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, op
return nil, nil, err return nil, nil, err
} }
return pc, addr, err return pc, addr, err
})
if err != nil {
if errors.Is(err, tuic.TooManyOpenStreams) {
t.removeClient(true, opts...)
return t.ListenPacketContext(ctx, metadata, opts...)
} }
pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, dialFn)
if errors.Is(err, tuic.TooManyOpenStreams) {
pc, err = t.newClient(false, opts...).ListenPacketContext(ctx, metadata, dialFn)
}
if err != nil {
return nil, err return nil, err
} }
return newPacketConn(pc, t), nil return newPacketConn(pc, t), nil
@ -194,37 +195,23 @@ func NewTuic(option TuicOption) (*Tuic, error) {
tlsConfig.ServerName = "" tlsConfig.ServerName = ""
} }
tkn := tuic.GenTKN(option.Token) tkn := tuic.GenTKN(option.Token)
tcpClientMap := make(map[any]*tuic.Client) tcpClients := list.New[*tuic.Client]()
tcpClientMapMutex := &sync.Mutex{} tcpClientsMutex := &sync.Mutex{}
udpClientMap := make(map[any]*tuic.Client) udpClients := list.New[*tuic.Client]()
udpClientMapMutex := &sync.Mutex{} udpClientsMutex := &sync.Mutex{}
getClient := func(udp bool, opts ...dialer.Option) *tuic.Client { newClient := func(udp bool, opts ...dialer.Option) *tuic.Client {
clientMap := tcpClientMap clients := tcpClients
clientMapMutex := tcpClientMapMutex clientsMutex := tcpClientsMutex
if udp { if udp {
clientMap = udpClientMap clients = udpClients
clientMapMutex = udpClientMapMutex clientsMutex = udpClientsMutex
} }
o := *dialer.ApplyOptions(opts...) var o any = *dialer.ApplyOptions(opts...)
clientsMutex.Lock()
defer clientsMutex.Unlock()
clientMapMutex.Lock()
defer clientMapMutex.Unlock()
for key := range clientMap {
client := clientMap[key]
if client == nil {
delete(clientMap, key) // It is safe in Golang
continue
}
if key == o {
client.LastVisited = time.Now()
return client
}
if time.Now().Sub(client.LastVisited) > 30*time.Minute {
delete(clientMap, key)
continue
}
}
client := &tuic.Client{ client := &tuic.Client{
TlsConfig: tlsConfig, TlsConfig: tlsConfig,
QuicConfig: quicConfig, QuicConfig: quicConfig,
@ -235,26 +222,60 @@ func NewTuic(option TuicOption) (*Tuic, error) {
ReduceRtt: option.ReduceRtt, ReduceRtt: option.ReduceRtt,
RequestTimeout: option.RequestTimeout, RequestTimeout: option.RequestTimeout,
MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize,
Key: o,
LastVisited: time.Now(), LastVisited: time.Now(),
UDP: udp, UDP: udp,
} }
clientMap[o] = client clients.PushFront(client)
runtime.SetFinalizer(client, closeTuicClient) runtime.SetFinalizer(client, closeTuicClient)
return client return client
} }
removeClient := func(udp bool, opts ...dialer.Option) { getClient := func(udp bool, opts ...dialer.Option) *tuic.Client {
clientMap := tcpClientMap clients := tcpClients
clientMapMutex := tcpClientMapMutex clientsMutex := tcpClientsMutex
if udp { if udp {
clientMap = udpClientMap clients = udpClients
clientMapMutex = udpClientMapMutex clientsMutex = udpClientsMutex
} }
o := *dialer.ApplyOptions(opts...) var o any = *dialer.ApplyOptions(opts...)
var bestClient *tuic.Client
clientMapMutex.Lock() func() {
defer clientMapMutex.Unlock() clientsMutex.Lock()
delete(clientMap, o) defer clientsMutex.Unlock()
for it := clients.Front(); it != nil; {
client := it.Value
if client == nil {
next := it.Next()
clients.Remove(it)
it = next
continue
}
if client.Key == o {
if bestClient == nil {
bestClient = client
} else {
if client.OpenStreams.Load() < bestClient.OpenStreams.Load() {
bestClient = client
}
}
}
if time.Now().Sub(client.LastVisited) > 30*time.Minute {
next := it.Next()
clients.Remove(it)
it = next
continue
}
it = it.Next()
}
}()
if bestClient == nil {
return newClient(udp, opts...)
} else {
return bestClient
}
} }
return &Tuic{ return &Tuic{
@ -266,8 +287,8 @@ func NewTuic(option TuicOption) (*Tuic, error) {
iface: option.Interface, iface: option.Interface,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: C.NewDNSPrefer(option.IPVersion),
}, },
newClient: newClient,
getClient: getClient, getClient: getClient,
removeClient: removeClient,
}, nil }, nil
} }

View file

@ -39,13 +39,14 @@ type Client struct {
RequestTimeout int RequestTimeout int
MaxUdpRelayPacketSize int MaxUdpRelayPacketSize int
Key any
LastVisited time.Time LastVisited time.Time
UDP bool UDP bool
quicConn quic.Connection quicConn quic.Connection
connMutex sync.Mutex connMutex sync.Mutex
openStreams atomic.Int32 OpenStreams atomic.Int32
udpInputMap sync.Map udpInputMap sync.Map
} }
@ -242,9 +243,9 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
if err != nil { if err != nil {
return nil, err return nil, err
} }
openStreams := t.openStreams.Add(1) openStreams := t.OpenStreams.Add(1)
if openStreams >= MaxOpenStreams { if openStreams >= MaxOpenStreams {
t.openStreams.Add(-1) t.OpenStreams.Add(-1)
return nil, TooManyOpenStreams return nil, TooManyOpenStreams
} }
stream, err := func() (stream *quicStreamConn, err error) { stream, err := func() (stream *quicStreamConn, err error) {
@ -300,6 +301,9 @@ type quicStreamConn struct {
lAddr net.Addr lAddr net.Addr
rAddr net.Addr rAddr net.Addr
client *Client client *Client
closeOnce sync.Once
closeErr error
} }
func (q *quicStreamConn) Write(p []byte) (n int, err error) { func (q *quicStreamConn) Write(p []byte) (n int, err error) {
@ -309,8 +313,15 @@ func (q *quicStreamConn) Write(p []byte) (n int, err error) {
} }
func (q *quicStreamConn) Close() error { func (q *quicStreamConn) Close() error {
q.closeOnce.Do(func() {
q.closeErr = q.close()
})
return q.closeErr
}
func (q *quicStreamConn) close() error {
defer time.AfterFunc(C.DefaultTCPTimeout, func() { defer time.AfterFunc(C.DefaultTCPTimeout, func() {
q.client.openStreams.Add(-1) q.client.OpenStreams.Add(-1)
}) })
// https://github.com/cloudflare/cloudflared/commit/ed2bac026db46b239699ac5ce4fcf122d7cab2cd // https://github.com/cloudflare/cloudflared/commit/ed2bac026db46b239699ac5ce4fcf122d7cab2cd
@ -342,6 +353,11 @@ func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
if err != nil { if err != nil {
return nil, err return nil, err
} }
openStreams := t.OpenStreams.Add(1)
if openStreams >= MaxOpenStreams {
t.OpenStreams.Add(-1)
return nil, TooManyOpenStreams
}
pipe1, pipe2 := net.Pipe() pipe1, pipe2 := net.Pipe()
inputCh := make(chan udpData) inputCh := make(chan udpData)
@ -380,16 +396,21 @@ type quicStreamPacketConn struct {
closeOnce sync.Once closeOnce sync.Once
closeErr error closeErr error
closed bool
} }
func (q *quicStreamPacketConn) Close() error { func (q *quicStreamPacketConn) Close() error {
q.closeOnce.Do(func() { q.closeOnce.Do(func() {
q.closed = true
q.closeErr = q.close() q.closeErr = q.close()
}) })
return q.closeErr return q.closeErr
} }
func (q *quicStreamPacketConn) close() (err error) { func (q *quicStreamPacketConn) close() (err error) {
defer time.AfterFunc(C.DefaultTCPTimeout, func() {
q.client.OpenStreams.Add(-1)
})
defer func() { defer func() {
q.client.deferQuicConn(q.quicConn, err) q.client.deferQuicConn(q.quicConn, err)
}() }()
@ -441,6 +462,9 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro
if len(p) > q.client.MaxUdpRelayPacketSize { if len(p) > q.client.MaxUdpRelayPacketSize {
return 0, fmt.Errorf("udp packet too large(%d > %d)", len(p), q.client.MaxUdpRelayPacketSize) return 0, fmt.Errorf("udp packet too large(%d > %d)", len(p), q.client.MaxUdpRelayPacketSize)
} }
if q.closed {
return 0, net.ErrClosed
}
defer func() { defer func() {
q.client.deferQuicConn(q.quicConn, err) q.client.deferQuicConn(q.quicConn, err)
}() }()