chore: split tuic's tcp and udp client

This commit is contained in:
gVisor bot 2022-11-25 17:15:45 +08:00
parent 29dd58edaa
commit 947f029a4a
2 changed files with 29 additions and 13 deletions

View file

@ -24,7 +24,7 @@ import (
type Tuic struct { type Tuic struct {
*Base *Base
getClient func(opts ...dialer.Option) *tuic.Client getClient func(udp bool, opts ...dialer.Option) *tuic.Client
} }
type TuicOption struct { type TuicOption struct {
@ -55,7 +55,7 @@ type TuicOption struct {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
opts = t.Base.DialOptions(opts...) opts = t.Base.DialOptions(opts...)
conn, err := t.getClient(opts...).DialContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) {
pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) pc, err := dialer.ListenPacket(ctx, "udp", "", opts...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -75,7 +75,7 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
opts = t.Base.DialOptions(opts...) opts = t.Base.DialOptions(opts...)
pc, err := t.getClient(opts...).ListenPacketContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) {
pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) pc, err := dialer.ListenPacket(ctx, "udp", "", opts...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -184,9 +184,18 @@ func NewTuic(option TuicOption) (*Tuic, error) {
tlsConfig.ServerName = "" tlsConfig.ServerName = ""
} }
tkn := tuic.GenTKN(option.Token) tkn := tuic.GenTKN(option.Token)
clientMap := make(map[any]*tuic.Client) tcpClientMap := make(map[any]*tuic.Client)
clientMapMutex := sync.Mutex{} tcpClientMapMutex := &sync.Mutex{}
getClient := func(opts ...dialer.Option) *tuic.Client { udpClientMap := make(map[any]*tuic.Client)
udpClientMapMutex := &sync.Mutex{}
getClient := func(udp bool, opts ...dialer.Option) *tuic.Client {
clientMap := tcpClientMap
clientMapMutex := tcpClientMapMutex
if udp {
clientMap = udpClientMap
clientMapMutex = udpClientMapMutex
}
o := *dialer.ApplyOptions(opts...) o := *dialer.ApplyOptions(opts...)
clientMapMutex.Lock() clientMapMutex.Lock()
@ -217,6 +226,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
RequestTimeout: option.RequestTimeout, RequestTimeout: option.RequestTimeout,
MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize,
LastVisited: time.Now(), LastVisited: time.Now(),
UDP: udp,
} }
clientMap[o] = client clientMap[o] = client
runtime.SetFinalizer(client, closeTuicClient) runtime.SetFinalizer(client, closeTuicClient)

View file

@ -32,6 +32,7 @@ type Client struct {
MaxUdpRelayPacketSize int MaxUdpRelayPacketSize int
LastVisited time.Time LastVisited time.Time
UDP bool
quicConn quic.Connection quicConn quic.Connection
connMutex sync.Mutex connMutex sync.Mutex
@ -113,9 +114,7 @@ func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Contex
return nil return nil
} }
go sendAuthentication(quicConn) parseUDP := func(quicConn quic.Connection) (err error) {
go func(quicConn quic.Connection) (err error) {
defer func() { defer func() {
t.deferQuicConn(quicConn, err) t.deferQuicConn(quicConn, err)
}() }()
@ -189,7 +188,13 @@ func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Contex
}() }()
} }
} }
}(quicConn) }
go sendAuthentication(quicConn)
if t.UDP {
go parseUDP(quicConn)
}
t.quicConn = quicConn t.quicConn = quicConn
return quicConn, nil return quicConn, nil
@ -226,7 +231,7 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
if err != nil { if err != nil {
return nil, err return nil, err
} }
stream, err := func() (stream quic.Stream, err error) { stream, err := func() (stream *quicStreamConn, err error) {
defer func() { defer func() {
t.deferQuicConn(quicConn, err) t.deferQuicConn(quicConn, err)
}() }()
@ -235,10 +240,11 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
if err != nil { if err != nil {
return nil, err return nil, err
} }
stream, err = quicConn.OpenStream() quicStream, err := quicConn.OpenStream()
if err != nil { if err != nil {
return nil, err return nil, err
} }
stream = &quicStreamConn{quicStream, quicConn.LocalAddr(), quicConn.RemoteAddr(), t}
_, err = buf.WriteTo(stream) _, err = buf.WriteTo(stream)
if err != nil { if err != nil {
_ = stream.Close() _ = stream.Close()
@ -253,7 +259,7 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
if t.RequestTimeout > 0 { if t.RequestTimeout > 0 {
_ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond)) _ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond))
} }
conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr(), t}) conn := N.NewBufferedConn(stream)
response, err := ReadResponse(conn) response, err := ReadResponse(conn)
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()