chore: rebuild tuic client's code

This commit is contained in:
gVisor bot 2022-11-26 23:53:59 +08:00
parent 3afe8226e3
commit d3bfbe06dc
4 changed files with 276 additions and 183 deletions

View file

@ -6,18 +6,14 @@ import (
"crypto/tls"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"net"
"os"
"runtime"
"strconv"
"sync"
"time"
"github.com/metacubex/quic-go"
"github.com/Dreamacro/clash/common/generics/list"
"github.com/Dreamacro/clash/component/dialer"
tlsC "github.com/Dreamacro/clash/component/tls"
C "github.com/Dreamacro/clash/constant"
@ -26,9 +22,7 @@ import (
type Tuic struct {
*Base
dialFn func(ctx context.Context, t *Tuic, opts ...dialer.Option) (net.PacketConn, net.Addr, error)
newClient func(udp bool, opts ...dialer.Option) *tuic.Client
getClient func(udp bool, opts ...dialer.Option) *tuic.Client
client *tuic.PoolClient
}
type TuicOption struct {
@ -60,13 +54,7 @@ type TuicOption struct {
// DialContext implements C.ProxyAdapter
func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
opts = t.Base.DialOptions(opts...)
dialFn := func(ctx context.Context) (net.PacketConn, net.Addr, error) {
return t.dialFn(ctx, t, opts...)
}
conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, dialFn)
if errors.Is(err, tuic.TooManyOpenStreams) {
conn, err = t.newClient(false, opts...).DialContext(ctx, metadata, dialFn)
}
conn, err := t.client.DialContext(ctx, metadata, opts...)
if err != nil {
return nil, err
}
@ -76,19 +64,25 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di
// ListenPacketContext implements C.ProxyAdapter
func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
opts = t.Base.DialOptions(opts...)
dialFn := func(ctx context.Context) (net.PacketConn, net.Addr, error) {
return t.dialFn(ctx, t, opts...)
}
pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, dialFn)
if errors.Is(err, tuic.TooManyOpenStreams) {
pc, err = t.newClient(false, opts...).ListenPacketContext(ctx, metadata, dialFn)
}
pc, err := t.client.ListenPacketContext(ctx, metadata, opts...)
if err != nil {
return nil, err
}
return newPacketConn(pc, t), nil
}
func (t *Tuic) dial(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) {
pc, err = dialer.ListenPacket(ctx, "udp", "", opts...)
if err != nil {
return nil, nil, err
}
addr, err = resolveUDPAddrWithPrefer(ctx, "udp", t.addr, t.prefer)
if err != nil {
return nil, nil, err
}
return
}
func NewTuic(option TuicOption) (*Tuic, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
serverName := option.Server
@ -192,139 +186,21 @@ func NewTuic(option TuicOption) (*Tuic, error) {
prefer: C.NewDNSPrefer(option.IPVersion),
},
}
type dialResult struct {
pc net.PacketConn
addr net.Addr
err error
clientOption := &tuic.ClientOption{
DialFn: t.dial,
TlsConfig: tlsConfig,
QuicConfig: quicConfig,
Host: host,
Token: tkn,
UdpRelayMode: option.UdpRelayMode,
CongestionController: option.CongestionController,
ReduceRtt: option.ReduceRtt,
RequestTimeout: option.RequestTimeout,
MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize,
FastOpen: option.FastOpen,
}
dialResultMap := make(map[any]dialResult)
dialResultMutex := &sync.Mutex{}
tcpClients := list.New[*tuic.Client]()
tcpClientsMutex := &sync.Mutex{}
udpClients := list.New[*tuic.Client]()
udpClientsMutex := &sync.Mutex{}
t.dialFn = func(ctx context.Context, t *Tuic, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) {
var o any = *dialer.ApplyOptions(opts...)
dialResultMutex.Lock()
dr, ok := dialResultMap[o]
dialResultMutex.Unlock()
if ok {
return dr.pc, dr.addr, dr.err
}
t.client = tuic.NewClientPool(clientOption)
pc, err = dialer.ListenPacket(ctx, "udp", "", opts...)
if err != nil {
return nil, nil, err
}
addr, err = resolveUDPAddrWithPrefer(ctx, "udp", t.addr, t.prefer)
if err != nil {
return nil, nil, err
}
dr.pc, dr.addr, dr.err = pc, addr, err
dialResultMutex.Lock()
dialResultMap[o] = dr
dialResultMutex.Unlock()
return pc, addr, err
}
closeFn := func(t *Tuic) {
dialResultMutex.Lock()
defer dialResultMutex.Unlock()
for key := range dialResultMap {
pc := dialResultMap[key].pc
if pc != nil {
_ = pc.Close()
}
delete(dialResultMap, key)
}
}
t.newClient = func(udp bool, opts ...dialer.Option) *tuic.Client {
clients := tcpClients
clientsMutex := tcpClientsMutex
if udp {
clients = udpClients
clientsMutex = udpClientsMutex
}
var o any = *dialer.ApplyOptions(opts...)
clientsMutex.Lock()
defer clientsMutex.Unlock()
client := &tuic.Client{
TlsConfig: tlsConfig,
QuicConfig: quicConfig,
Host: host,
Token: tkn,
UdpRelayMode: option.UdpRelayMode,
CongestionController: option.CongestionController,
ReduceRtt: option.ReduceRtt,
RequestTimeout: option.RequestTimeout,
MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize,
FastOpen: option.FastOpen,
Inference: t,
Key: o,
LastVisited: time.Now(),
UDP: udp,
}
clients.PushFront(client)
runtime.SetFinalizer(client, closeTuicClient)
return client
}
t.getClient = func(udp bool, opts ...dialer.Option) *tuic.Client {
clients := tcpClients
clientsMutex := tcpClientsMutex
if udp {
clients = udpClients
clientsMutex = udpClientsMutex
}
var o any = *dialer.ApplyOptions(opts...)
var bestClient *tuic.Client
func() {
clientsMutex.Lock()
defer clientsMutex.Unlock()
for it := clients.Front(); it != nil; {
client := it.Value
if client == nil {
next := it.Next()
clients.Remove(it)
it = next
continue
}
if client.Key == o {
if bestClient == nil {
bestClient = client
} else {
if client.OpenStreams.Load() < bestClient.OpenStreams.Load() {
bestClient = client
}
}
}
if client.OpenStreams.Load() == 0 && time.Now().Sub(client.LastVisited) > 30*time.Minute {
next := it.Next()
clients.Remove(it)
it = next
continue
}
it = it.Next()
}
}()
if bestClient == nil {
return t.newClient(udp, opts...)
} else {
return bestClient
}
}
runtime.SetFinalizer(t, closeFn)
return t, nil
}
func closeTuicClient(client *tuic.Client) {
client.Close(tuic.ClientClosed)
}

View file

@ -10,6 +10,7 @@ import (
"math/rand"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"time"
@ -17,6 +18,7 @@ import (
"github.com/metacubex/quic-go"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/tuic/congestion"
)
@ -28,7 +30,9 @@ var (
const MaxOpenStreams = 100 - 90
type Client struct {
type ClientOption struct {
DialFn func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error)
TlsConfig *tls.Config
QuicConfig *quic.Config
Host string
@ -39,27 +43,32 @@ type Client struct {
RequestTimeout int
MaxUdpRelayPacketSize int
FastOpen bool
}
Inference any
Key any
LastVisited time.Time
UDP bool
type Client struct {
*ClientOption
udp bool
quicConn quic.Connection
connMutex sync.Mutex
OpenStreams atomic.Int32
openStreams atomic.Int32
udpInputMap sync.Map
// only ready for PoolClient
poolRef *PoolClient
optionRef any
lastVisited time.Time
}
func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (quic.Connection, error) {
func (t *Client) getQuicConn(ctx context.Context) (quic.Connection, error) {
t.connMutex.Lock()
defer t.connMutex.Unlock()
if t.quicConn != nil {
return t.quicConn, nil
}
pc, addr, err := dialFn(ctx)
pc, addr, err := t.DialFn(ctx)
if err != nil {
return nil, err
}
@ -206,7 +215,7 @@ func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Contex
go sendAuthentication(quicConn)
if t.UDP {
if t.udp {
go parseUDP(quicConn)
}
@ -240,14 +249,14 @@ func (t *Client) Close(err error) {
}
}
func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.Conn, error) {
quicConn, err := t.getQuicConn(ctx, dialFn)
func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) {
quicConn, err := t.getQuicConn(ctx)
if err != nil {
return nil, err
}
openStreams := t.OpenStreams.Add(1)
openStreams := t.openStreams.Add(1)
if openStreams >= MaxOpenStreams {
t.OpenStreams.Add(-1)
t.openStreams.Add(-1)
return nil, TooManyOpenStreams
}
stream, err := func() (stream *quicStreamConn, err error) {
@ -354,7 +363,7 @@ func (q *quicStreamConn) Close() error {
func (q *quicStreamConn) close() error {
defer time.AfterFunc(C.DefaultTCPTimeout, func() {
q.client.OpenStreams.Add(-1)
q.client.openStreams.Add(-1)
})
// https://github.com/cloudflare/cloudflared/commit/ed2bac026db46b239699ac5ce4fcf122d7cab2cd
@ -381,14 +390,14 @@ func (q *quicStreamConn) RemoteAddr() net.Addr {
var _ net.Conn = &quicStreamConn{}
func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.PacketConn, error) {
quicConn, err := t.getQuicConn(ctx, dialFn)
func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error) {
quicConn, err := t.getQuicConn(ctx)
if err != nil {
return nil, err
}
openStreams := t.OpenStreams.Add(1)
openStreams := t.openStreams.Add(1)
if openStreams >= MaxOpenStreams {
t.OpenStreams.Add(-1)
t.openStreams.Add(-1)
return nil, TooManyOpenStreams
}
@ -442,7 +451,7 @@ func (q *quicStreamPacketConn) Close() error {
func (q *quicStreamPacketConn) close() (err error) {
defer time.AfterFunc(C.DefaultTCPTimeout, func() {
q.client.OpenStreams.Add(-1)
q.client.openStreams.Add(-1)
})
defer func() {
q.client.deferQuicConn(q.quicConn, err)
@ -539,3 +548,16 @@ func (q *quicStreamPacketConn) LocalAddr() net.Addr {
}
var _ net.PacketConn = &quicStreamPacketConn{}
func NewClient(clientOption *ClientOption, udp bool) *Client {
c := &Client{
ClientOption: clientOption,
udp: udp,
}
runtime.SetFinalizer(c, closeClient)
return c
}
func closeClient(client *Client) {
client.Close(ClientClosed)
}

View file

@ -0,0 +1,177 @@
package tuic
import (
"context"
"errors"
"net"
"runtime"
"sync"
"time"
"github.com/Dreamacro/clash/common/generics/list"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
)
type dialResult struct {
pc net.PacketConn
addr net.Addr
err error
}
type PoolClient struct {
*ClientOption
dialResultMap map[any]dialResult
dialResultMutex *sync.Mutex
tcpClients *list.List[*Client]
tcpClientsMutex *sync.Mutex
udpClients *list.List[*Client]
udpClientsMutex *sync.Mutex
}
func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (net.Conn, error) {
conn, err := t.getClient(false, opts...).DialContext(ctx, metadata)
if errors.Is(err, TooManyOpenStreams) {
conn, err = t.newClient(false, opts...).DialContext(ctx, metadata)
}
if err != nil {
return nil, err
}
return conn, err
}
func (t *PoolClient) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (net.PacketConn, error) {
pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata)
if errors.Is(err, TooManyOpenStreams) {
pc, err = t.newClient(false, opts...).ListenPacketContext(ctx, metadata)
}
if err != nil {
return nil, err
}
return pc, nil
}
func (t *PoolClient) dial(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) {
var o any = *dialer.ApplyOptions(opts...)
t.dialResultMutex.Lock()
dr, ok := t.dialResultMap[o]
t.dialResultMutex.Unlock()
if ok {
return dr.pc, dr.addr, dr.err
}
pc, addr, err = t.DialFn(ctx, opts...)
if err != nil {
return nil, nil, err
}
dr.pc, dr.addr, dr.err = pc, addr, err
t.dialResultMutex.Lock()
t.dialResultMap[o] = dr
t.dialResultMutex.Unlock()
return pc, addr, err
}
func (t *PoolClient) Close() {
t.dialResultMutex.Lock()
defer t.dialResultMutex.Unlock()
for key := range t.dialResultMap {
pc := t.dialResultMap[key].pc
if pc != nil {
_ = pc.Close()
}
delete(t.dialResultMap, key)
}
}
func (t *PoolClient) newClient(udp bool, opts ...dialer.Option) *Client {
clients := t.tcpClients
clientsMutex := t.tcpClientsMutex
if udp {
clients = t.udpClients
clientsMutex = t.udpClientsMutex
}
var o any = *dialer.ApplyOptions(opts...)
clientsMutex.Lock()
defer clientsMutex.Unlock()
client := NewClient(t.ClientOption, udp)
client.poolRef = t // make sure pool has a reference
client.optionRef = o
client.lastVisited = time.Now()
clients.PushFront(client)
return client
}
func (t *PoolClient) getClient(udp bool, opts ...dialer.Option) *Client {
clients := t.tcpClients
clientsMutex := t.tcpClientsMutex
if udp {
clients = t.udpClients
clientsMutex = t.udpClientsMutex
}
var o any = *dialer.ApplyOptions(opts...)
var bestClient *Client
func() {
clientsMutex.Lock()
defer clientsMutex.Unlock()
for it := clients.Front(); it != nil; {
client := it.Value
if client == nil {
next := it.Next()
clients.Remove(it)
it = next
continue
}
if client.optionRef == o {
if bestClient == nil {
bestClient = client
} else {
if client.openStreams.Load() < bestClient.openStreams.Load() {
bestClient = client
}
}
}
if client.openStreams.Load() == 0 && time.Now().Sub(client.lastVisited) > 30*time.Minute {
next := it.Next()
clients.Remove(it)
it = next
continue
}
it = it.Next()
}
}()
if bestClient == nil {
return t.newClient(udp, opts...)
} else {
bestClient.lastVisited = time.Now()
return bestClient
}
}
func NewClientPool(clientOption *ClientOption) *PoolClient {
p := &PoolClient{
ClientOption: clientOption,
dialResultMap: make(map[any]dialResult),
dialResultMutex: &sync.Mutex{},
tcpClients: list.New[*Client](),
tcpClientsMutex: &sync.Mutex{},
udpClients: list.New[*Client](),
udpClientsMutex: &sync.Mutex{},
}
runtime.SetFinalizer(p, closeClientPool)
return p
}
func closeClientPool(client *PoolClient) {
client.Close()
}

View file

@ -178,8 +178,8 @@ func NewPacket(ASSOC_ID uint32, LEN uint16, ADDR Address, DATA []byte) Packet {
}
}
func ReadPacket(reader BufferedReader) (c Packet, err error) {
c.CommandHead, err = ReadCommandHead(reader)
func ReadPacketWithHead(head CommandHead, reader BufferedReader) (c Packet, err error) {
c.CommandHead = head
if err != nil {
return
}
@ -206,6 +206,14 @@ func ReadPacket(reader BufferedReader) (c Packet, err error) {
return
}
func ReadPacket(reader BufferedReader) (c Packet, err error) {
head, err := ReadCommandHead(reader)
if err != nil {
return
}
return ReadPacketWithHead(head, reader)
}
func (c Packet) WriteTo(writer BufferedWriter) (err error) {
err = c.CommandHead.WriteTo(writer)
if err != nil {
@ -272,17 +280,22 @@ func NewHeartbeat() Heartbeat {
}
}
func ReadHeartbeat(reader BufferedReader) (c Response, err error) {
c.CommandHead, err = ReadCommandHead(reader)
if err != nil {
return
}
func ReadHeartbeatWithHead(head CommandHead, reader BufferedReader) (c Response, err error) {
c.CommandHead = head
if c.CommandHead.TYPE != HeartbeatType {
err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE)
}
return
}
func ReadHeartbeat(reader BufferedReader) (c Response, err error) {
head, err := ReadCommandHead(reader)
if err != nil {
return
}
return ReadHeartbeatWithHead(head, reader)
}
type Response struct {
CommandHead
REP byte
@ -295,11 +308,8 @@ func NewResponse(REP byte) Response {
}
}
func ReadResponse(reader BufferedReader) (c Response, err error) {
c.CommandHead, err = ReadCommandHead(reader)
if err != nil {
return
}
func ReadResponseWithHead(head CommandHead, reader BufferedReader) (c Response, err error) {
c.CommandHead = head
if c.CommandHead.TYPE != ResponseType {
err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE)
}
@ -310,6 +320,14 @@ func ReadResponse(reader BufferedReader) (c Response, err error) {
return
}
func ReadResponse(reader BufferedReader) (c Response, err error) {
head, err := ReadCommandHead(reader)
if err != nil {
return
}
return ReadResponseWithHead(head, reader)
}
func (c Response) WriteTo(writer BufferedWriter) (err error) {
err = c.CommandHead.WriteTo(writer)
if err != nil {