From 947f029a4a0b1f364bdc0519695bfc3848a55360 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 25 Nov 2022 17:15:45 +0800 Subject: [PATCH] chore: split tuic's tcp and udp client --- adapter/outbound/tuic.go | 22 ++++++++++++++++------ transport/tuic/client.go | 20 +++++++++++++------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index 502a751d..57b761c3 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -24,7 +24,7 @@ import ( type Tuic struct { *Base - getClient func(opts ...dialer.Option) *tuic.Client + getClient func(udp bool, opts ...dialer.Option) *tuic.Client } type TuicOption struct { @@ -55,7 +55,7 @@ type TuicOption struct { // DialContext implements C.ProxyAdapter func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { opts = t.Base.DialOptions(opts...) - conn, err := t.getClient(opts...).DialContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { + conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) if err != nil { return nil, nil, err @@ -75,7 +75,7 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di // ListenPacketContext implements C.ProxyAdapter func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { opts = t.Base.DialOptions(opts...) - pc, err := t.getClient(opts...).ListenPacketContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { + pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) if err != nil { return nil, nil, err @@ -184,9 +184,18 @@ func NewTuic(option TuicOption) (*Tuic, error) { tlsConfig.ServerName = "" } tkn := tuic.GenTKN(option.Token) - clientMap := make(map[any]*tuic.Client) - clientMapMutex := sync.Mutex{} - getClient := func(opts ...dialer.Option) *tuic.Client { + tcpClientMap := make(map[any]*tuic.Client) + tcpClientMapMutex := &sync.Mutex{} + udpClientMap := make(map[any]*tuic.Client) + udpClientMapMutex := &sync.Mutex{} + getClient := func(udp bool, opts ...dialer.Option) *tuic.Client { + clientMap := tcpClientMap + clientMapMutex := tcpClientMapMutex + if udp { + clientMap = udpClientMap + clientMapMutex = udpClientMapMutex + } + o := *dialer.ApplyOptions(opts...) clientMapMutex.Lock() @@ -217,6 +226,7 @@ func NewTuic(option TuicOption) (*Tuic, error) { RequestTimeout: option.RequestTimeout, MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, LastVisited: time.Now(), + UDP: udp, } clientMap[o] = client runtime.SetFinalizer(client, closeTuicClient) diff --git a/transport/tuic/client.go b/transport/tuic/client.go index 18e623ab..4293c9c4 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -32,6 +32,7 @@ type Client struct { MaxUdpRelayPacketSize int LastVisited time.Time + UDP bool quicConn quic.Connection connMutex sync.Mutex @@ -113,9 +114,7 @@ func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Contex return nil } - go sendAuthentication(quicConn) - - go func(quicConn quic.Connection) (err error) { + parseUDP := func(quicConn quic.Connection) (err error) { defer func() { t.deferQuicConn(quicConn, err) }() @@ -189,7 +188,13 @@ func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Contex }() } } - }(quicConn) + } + + go sendAuthentication(quicConn) + + if t.UDP { + go parseUDP(quicConn) + } t.quicConn = quicConn return quicConn, nil @@ -226,7 +231,7 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f if err != nil { return nil, err } - stream, err := func() (stream quic.Stream, err error) { + stream, err := func() (stream *quicStreamConn, err error) { defer func() { t.deferQuicConn(quicConn, err) }() @@ -235,10 +240,11 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f if err != nil { return nil, err } - stream, err = quicConn.OpenStream() + quicStream, err := quicConn.OpenStream() if err != nil { return nil, err } + stream = &quicStreamConn{quicStream, quicConn.LocalAddr(), quicConn.RemoteAddr(), t} _, err = buf.WriteTo(stream) if err != nil { _ = stream.Close() @@ -253,7 +259,7 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f if t.RequestTimeout > 0 { _ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond)) } - conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr(), t}) + conn := N.NewBufferedConn(stream) response, err := ReadResponse(conn) if err != nil { _ = conn.Close()