diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index c480bede..58c79fa1 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -6,18 +6,14 @@ import ( "crypto/tls" "encoding/hex" "encoding/pem" - "errors" "fmt" "net" "os" - "runtime" "strconv" - "sync" "time" "github.com/metacubex/quic-go" - "github.com/Dreamacro/clash/common/generics/list" "github.com/Dreamacro/clash/component/dialer" tlsC "github.com/Dreamacro/clash/component/tls" C "github.com/Dreamacro/clash/constant" @@ -26,9 +22,7 @@ import ( type Tuic struct { *Base - dialFn func(ctx context.Context, t *Tuic, opts ...dialer.Option) (net.PacketConn, net.Addr, error) - newClient func(udp bool, opts ...dialer.Option) *tuic.Client - getClient func(udp bool, opts ...dialer.Option) *tuic.Client + client *tuic.PoolClient } type TuicOption struct { @@ -60,13 +54,7 @@ type TuicOption struct { // DialContext implements C.ProxyAdapter func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { opts = t.Base.DialOptions(opts...) - dialFn := func(ctx context.Context) (net.PacketConn, net.Addr, error) { - return t.dialFn(ctx, t, opts...) - } - conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, dialFn) - if errors.Is(err, tuic.TooManyOpenStreams) { - conn, err = t.newClient(false, opts...).DialContext(ctx, metadata, dialFn) - } + conn, err := t.client.DialContext(ctx, metadata, opts...) if err != nil { return nil, err } @@ -76,19 +64,25 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di // ListenPacketContext implements C.ProxyAdapter func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { opts = t.Base.DialOptions(opts...) - dialFn := func(ctx context.Context) (net.PacketConn, net.Addr, error) { - return t.dialFn(ctx, t, opts...) - } - pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, dialFn) - if errors.Is(err, tuic.TooManyOpenStreams) { - pc, err = t.newClient(false, opts...).ListenPacketContext(ctx, metadata, dialFn) - } + pc, err := t.client.ListenPacketContext(ctx, metadata, opts...) if err != nil { return nil, err } return newPacketConn(pc, t), nil } +func (t *Tuic) dial(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { + pc, err = dialer.ListenPacket(ctx, "udp", "", opts...) + if err != nil { + return nil, nil, err + } + addr, err = resolveUDPAddrWithPrefer(ctx, "udp", t.addr, t.prefer) + if err != nil { + return nil, nil, err + } + return +} + func NewTuic(option TuicOption) (*Tuic, error) { addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) serverName := option.Server @@ -192,139 +186,21 @@ func NewTuic(option TuicOption) (*Tuic, error) { prefer: C.NewDNSPrefer(option.IPVersion), }, } - - type dialResult struct { - pc net.PacketConn - addr net.Addr - err error + clientOption := &tuic.ClientOption{ + DialFn: t.dial, + TlsConfig: tlsConfig, + QuicConfig: quicConfig, + Host: host, + Token: tkn, + UdpRelayMode: option.UdpRelayMode, + CongestionController: option.CongestionController, + ReduceRtt: option.ReduceRtt, + RequestTimeout: option.RequestTimeout, + MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, + FastOpen: option.FastOpen, } - dialResultMap := make(map[any]dialResult) - dialResultMutex := &sync.Mutex{} - tcpClients := list.New[*tuic.Client]() - tcpClientsMutex := &sync.Mutex{} - udpClients := list.New[*tuic.Client]() - udpClientsMutex := &sync.Mutex{} - t.dialFn = func(ctx context.Context, t *Tuic, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { - var o any = *dialer.ApplyOptions(opts...) - dialResultMutex.Lock() - dr, ok := dialResultMap[o] - dialResultMutex.Unlock() - if ok { - return dr.pc, dr.addr, dr.err - } + t.client = tuic.NewClientPool(clientOption) - pc, err = dialer.ListenPacket(ctx, "udp", "", opts...) - if err != nil { - return nil, nil, err - } - addr, err = resolveUDPAddrWithPrefer(ctx, "udp", t.addr, t.prefer) - if err != nil { - return nil, nil, err - } - - dr.pc, dr.addr, dr.err = pc, addr, err - - dialResultMutex.Lock() - dialResultMap[o] = dr - dialResultMutex.Unlock() - return pc, addr, err - } - closeFn := func(t *Tuic) { - dialResultMutex.Lock() - defer dialResultMutex.Unlock() - for key := range dialResultMap { - pc := dialResultMap[key].pc - if pc != nil { - _ = pc.Close() - } - delete(dialResultMap, key) - } - } - t.newClient = func(udp bool, opts ...dialer.Option) *tuic.Client { - clients := tcpClients - clientsMutex := tcpClientsMutex - if udp { - clients = udpClients - clientsMutex = udpClientsMutex - } - - var o any = *dialer.ApplyOptions(opts...) - - clientsMutex.Lock() - defer clientsMutex.Unlock() - - client := &tuic.Client{ - TlsConfig: tlsConfig, - QuicConfig: quicConfig, - Host: host, - Token: tkn, - UdpRelayMode: option.UdpRelayMode, - CongestionController: option.CongestionController, - ReduceRtt: option.ReduceRtt, - RequestTimeout: option.RequestTimeout, - MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, - FastOpen: option.FastOpen, - Inference: t, - Key: o, - LastVisited: time.Now(), - UDP: udp, - } - clients.PushFront(client) - runtime.SetFinalizer(client, closeTuicClient) - return client - } - t.getClient = func(udp bool, opts ...dialer.Option) *tuic.Client { - clients := tcpClients - clientsMutex := tcpClientsMutex - if udp { - clients = udpClients - clientsMutex = udpClientsMutex - } - - var o any = *dialer.ApplyOptions(opts...) - var bestClient *tuic.Client - - func() { - clientsMutex.Lock() - defer clientsMutex.Unlock() - for it := clients.Front(); it != nil; { - client := it.Value - if client == nil { - next := it.Next() - clients.Remove(it) - it = next - continue - } - if client.Key == o { - if bestClient == nil { - bestClient = client - } else { - if client.OpenStreams.Load() < bestClient.OpenStreams.Load() { - bestClient = client - } - } - } - if client.OpenStreams.Load() == 0 && time.Now().Sub(client.LastVisited) > 30*time.Minute { - next := it.Next() - clients.Remove(it) - it = next - continue - } - it = it.Next() - } - }() - - if bestClient == nil { - return t.newClient(udp, opts...) - } else { - return bestClient - } - } - runtime.SetFinalizer(t, closeFn) return t, nil } - -func closeTuicClient(client *tuic.Client) { - client.Close(tuic.ClientClosed) -} diff --git a/transport/tuic/client.go b/transport/tuic/client.go index 95aeeac3..f4eaf913 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -10,6 +10,7 @@ import ( "math/rand" "net" "net/netip" + "runtime" "sync" "sync/atomic" "time" @@ -17,6 +18,7 @@ import ( "github.com/metacubex/quic-go" N "github.com/Dreamacro/clash/common/net" + "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/tuic/congestion" ) @@ -28,7 +30,9 @@ var ( const MaxOpenStreams = 100 - 90 -type Client struct { +type ClientOption struct { + DialFn func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) + TlsConfig *tls.Config QuicConfig *quic.Config Host string @@ -39,27 +43,32 @@ type Client struct { RequestTimeout int MaxUdpRelayPacketSize int FastOpen bool +} - Inference any - Key any - LastVisited time.Time - UDP bool +type Client struct { + *ClientOption + udp bool quicConn quic.Connection connMutex sync.Mutex - OpenStreams atomic.Int32 + openStreams atomic.Int32 udpInputMap sync.Map + + // only ready for PoolClient + poolRef *PoolClient + optionRef any + lastVisited time.Time } -func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (quic.Connection, error) { +func (t *Client) getQuicConn(ctx context.Context) (quic.Connection, error) { t.connMutex.Lock() defer t.connMutex.Unlock() if t.quicConn != nil { return t.quicConn, nil } - pc, addr, err := dialFn(ctx) + pc, addr, err := t.DialFn(ctx) if err != nil { return nil, err } @@ -206,7 +215,7 @@ func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Contex go sendAuthentication(quicConn) - if t.UDP { + if t.udp { go parseUDP(quicConn) } @@ -240,14 +249,14 @@ func (t *Client) Close(err error) { } } -func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.Conn, error) { - quicConn, err := t.getQuicConn(ctx, dialFn) +func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) { + quicConn, err := t.getQuicConn(ctx) if err != nil { return nil, err } - openStreams := t.OpenStreams.Add(1) + openStreams := t.openStreams.Add(1) if openStreams >= MaxOpenStreams { - t.OpenStreams.Add(-1) + t.openStreams.Add(-1) return nil, TooManyOpenStreams } stream, err := func() (stream *quicStreamConn, err error) { @@ -354,7 +363,7 @@ func (q *quicStreamConn) Close() error { func (q *quicStreamConn) close() error { defer time.AfterFunc(C.DefaultTCPTimeout, func() { - q.client.OpenStreams.Add(-1) + q.client.openStreams.Add(-1) }) // https://github.com/cloudflare/cloudflared/commit/ed2bac026db46b239699ac5ce4fcf122d7cab2cd @@ -381,14 +390,14 @@ func (q *quicStreamConn) RemoteAddr() net.Addr { var _ net.Conn = &quicStreamConn{} -func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.PacketConn, error) { - quicConn, err := t.getQuicConn(ctx, dialFn) +func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error) { + quicConn, err := t.getQuicConn(ctx) if err != nil { return nil, err } - openStreams := t.OpenStreams.Add(1) + openStreams := t.openStreams.Add(1) if openStreams >= MaxOpenStreams { - t.OpenStreams.Add(-1) + t.openStreams.Add(-1) return nil, TooManyOpenStreams } @@ -442,7 +451,7 @@ func (q *quicStreamPacketConn) Close() error { func (q *quicStreamPacketConn) close() (err error) { defer time.AfterFunc(C.DefaultTCPTimeout, func() { - q.client.OpenStreams.Add(-1) + q.client.openStreams.Add(-1) }) defer func() { q.client.deferQuicConn(q.quicConn, err) @@ -539,3 +548,16 @@ func (q *quicStreamPacketConn) LocalAddr() net.Addr { } var _ net.PacketConn = &quicStreamPacketConn{} + +func NewClient(clientOption *ClientOption, udp bool) *Client { + c := &Client{ + ClientOption: clientOption, + udp: udp, + } + runtime.SetFinalizer(c, closeClient) + return c +} + +func closeClient(client *Client) { + client.Close(ClientClosed) +} diff --git a/transport/tuic/pool_client.go b/transport/tuic/pool_client.go new file mode 100644 index 00000000..19ebc4eb --- /dev/null +++ b/transport/tuic/pool_client.go @@ -0,0 +1,177 @@ +package tuic + +import ( + "context" + "errors" + "net" + "runtime" + "sync" + "time" + + "github.com/Dreamacro/clash/common/generics/list" + "github.com/Dreamacro/clash/component/dialer" + C "github.com/Dreamacro/clash/constant" +) + +type dialResult struct { + pc net.PacketConn + addr net.Addr + err error +} + +type PoolClient struct { + *ClientOption + + dialResultMap map[any]dialResult + dialResultMutex *sync.Mutex + tcpClients *list.List[*Client] + tcpClientsMutex *sync.Mutex + udpClients *list.List[*Client] + udpClientsMutex *sync.Mutex +} + +func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (net.Conn, error) { + conn, err := t.getClient(false, opts...).DialContext(ctx, metadata) + if errors.Is(err, TooManyOpenStreams) { + conn, err = t.newClient(false, opts...).DialContext(ctx, metadata) + } + if err != nil { + return nil, err + } + return conn, err +} + +func (t *PoolClient) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (net.PacketConn, error) { + pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata) + if errors.Is(err, TooManyOpenStreams) { + pc, err = t.newClient(false, opts...).ListenPacketContext(ctx, metadata) + } + if err != nil { + return nil, err + } + return pc, nil +} + +func (t *PoolClient) dial(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { + var o any = *dialer.ApplyOptions(opts...) + + t.dialResultMutex.Lock() + dr, ok := t.dialResultMap[o] + t.dialResultMutex.Unlock() + if ok { + return dr.pc, dr.addr, dr.err + } + + pc, addr, err = t.DialFn(ctx, opts...) + if err != nil { + return nil, nil, err + } + + dr.pc, dr.addr, dr.err = pc, addr, err + + t.dialResultMutex.Lock() + t.dialResultMap[o] = dr + t.dialResultMutex.Unlock() + return pc, addr, err +} + +func (t *PoolClient) Close() { + t.dialResultMutex.Lock() + defer t.dialResultMutex.Unlock() + for key := range t.dialResultMap { + pc := t.dialResultMap[key].pc + if pc != nil { + _ = pc.Close() + } + delete(t.dialResultMap, key) + } +} + +func (t *PoolClient) newClient(udp bool, opts ...dialer.Option) *Client { + clients := t.tcpClients + clientsMutex := t.tcpClientsMutex + if udp { + clients = t.udpClients + clientsMutex = t.udpClientsMutex + } + + var o any = *dialer.ApplyOptions(opts...) + + clientsMutex.Lock() + defer clientsMutex.Unlock() + + client := NewClient(t.ClientOption, udp) + client.poolRef = t // make sure pool has a reference + client.optionRef = o + client.lastVisited = time.Now() + + clients.PushFront(client) + return client +} + +func (t *PoolClient) getClient(udp bool, opts ...dialer.Option) *Client { + clients := t.tcpClients + clientsMutex := t.tcpClientsMutex + if udp { + clients = t.udpClients + clientsMutex = t.udpClientsMutex + } + + var o any = *dialer.ApplyOptions(opts...) + var bestClient *Client + + func() { + clientsMutex.Lock() + defer clientsMutex.Unlock() + for it := clients.Front(); it != nil; { + client := it.Value + if client == nil { + next := it.Next() + clients.Remove(it) + it = next + continue + } + if client.optionRef == o { + if bestClient == nil { + bestClient = client + } else { + if client.openStreams.Load() < bestClient.openStreams.Load() { + bestClient = client + } + } + } + if client.openStreams.Load() == 0 && time.Now().Sub(client.lastVisited) > 30*time.Minute { + next := it.Next() + clients.Remove(it) + it = next + continue + } + it = it.Next() + } + }() + + if bestClient == nil { + return t.newClient(udp, opts...) + } else { + bestClient.lastVisited = time.Now() + return bestClient + } +} + +func NewClientPool(clientOption *ClientOption) *PoolClient { + p := &PoolClient{ + ClientOption: clientOption, + dialResultMap: make(map[any]dialResult), + dialResultMutex: &sync.Mutex{}, + tcpClients: list.New[*Client](), + tcpClientsMutex: &sync.Mutex{}, + udpClients: list.New[*Client](), + udpClientsMutex: &sync.Mutex{}, + } + runtime.SetFinalizer(p, closeClientPool) + return p +} + +func closeClientPool(client *PoolClient) { + client.Close() +} diff --git a/transport/tuic/protocol.go b/transport/tuic/protocol.go index 913f0d51..98f7cd96 100644 --- a/transport/tuic/protocol.go +++ b/transport/tuic/protocol.go @@ -178,8 +178,8 @@ func NewPacket(ASSOC_ID uint32, LEN uint16, ADDR Address, DATA []byte) Packet { } } -func ReadPacket(reader BufferedReader) (c Packet, err error) { - c.CommandHead, err = ReadCommandHead(reader) +func ReadPacketWithHead(head CommandHead, reader BufferedReader) (c Packet, err error) { + c.CommandHead = head if err != nil { return } @@ -206,6 +206,14 @@ func ReadPacket(reader BufferedReader) (c Packet, err error) { return } +func ReadPacket(reader BufferedReader) (c Packet, err error) { + head, err := ReadCommandHead(reader) + if err != nil { + return + } + return ReadPacketWithHead(head, reader) +} + func (c Packet) WriteTo(writer BufferedWriter) (err error) { err = c.CommandHead.WriteTo(writer) if err != nil { @@ -272,17 +280,22 @@ func NewHeartbeat() Heartbeat { } } -func ReadHeartbeat(reader BufferedReader) (c Response, err error) { - c.CommandHead, err = ReadCommandHead(reader) - if err != nil { - return - } +func ReadHeartbeatWithHead(head CommandHead, reader BufferedReader) (c Response, err error) { + c.CommandHead = head if c.CommandHead.TYPE != HeartbeatType { err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) } return } +func ReadHeartbeat(reader BufferedReader) (c Response, err error) { + head, err := ReadCommandHead(reader) + if err != nil { + return + } + return ReadHeartbeatWithHead(head, reader) +} + type Response struct { CommandHead REP byte @@ -295,11 +308,8 @@ func NewResponse(REP byte) Response { } } -func ReadResponse(reader BufferedReader) (c Response, err error) { - c.CommandHead, err = ReadCommandHead(reader) - if err != nil { - return - } +func ReadResponseWithHead(head CommandHead, reader BufferedReader) (c Response, err error) { + c.CommandHead = head if c.CommandHead.TYPE != ResponseType { err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) } @@ -310,6 +320,14 @@ func ReadResponse(reader BufferedReader) (c Response, err error) { return } +func ReadResponse(reader BufferedReader) (c Response, err error) { + head, err := ReadCommandHead(reader) + if err != nil { + return + } + return ReadResponseWithHead(head, reader) +} + func (c Response) WriteTo(writer BufferedWriter) (err error) { err = c.CommandHead.WriteTo(writer) if err != nil {