From 644abcf0717f1a382866da9044bc70e68ee35bb8 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Tue, 13 Jun 2023 17:50:10 +0800 Subject: [PATCH] fix: tuicV5's heartbeat should be a datagram packet --- adapter/outbound/tuic.go | 7 +- transport/tuic/common/type.go | 7 ++ transport/tuic/tuic.go | 7 ++ transport/tuic/v4/client.go | 149 +++++++++++++++++++------------- transport/tuic/v4/packet.go | 11 +-- transport/tuic/v4/server.go | 6 +- transport/tuic/v5/client.go | 156 ++++++++++++++++++++-------------- transport/tuic/v5/packet.go | 5 +- transport/tuic/v5/protocol.go | 68 --------------- transport/tuic/v5/server.go | 33 ++++--- 10 files changed, 233 insertions(+), 216 deletions(-) diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index e2aafca5..af0d3b30 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -175,8 +175,9 @@ func NewTuic(option TuicOption) (*Tuic, error) { option.HeartbeatInterval = 10000 } + udpRelayMode := tuic.QUIC if option.UdpRelayMode != "quic" { - option.UdpRelayMode = "native" + udpRelayMode = tuic.NATIVE } if option.MaxUdpRelayPacketSize == 0 { @@ -264,7 +265,7 @@ func NewTuic(option TuicOption) (*Tuic, error) { TlsConfig: tlsConfig, QuicConfig: quicConfig, Token: tkn, - UdpRelayMode: option.UdpRelayMode, + UdpRelayMode: udpRelayMode, CongestionController: option.CongestionController, ReduceRtt: option.ReduceRtt, RequestTimeout: time.Duration(option.RequestTimeout) * time.Millisecond, @@ -280,7 +281,7 @@ func NewTuic(option TuicOption) (*Tuic, error) { QuicConfig: quicConfig, Uuid: uuid.FromStringOrNil(option.UUID), Password: option.Password, - UdpRelayMode: option.UdpRelayMode, + UdpRelayMode: udpRelayMode, CongestionController: option.CongestionController, ReduceRtt: option.ReduceRtt, MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, diff --git a/transport/tuic/common/type.go b/transport/tuic/common/type.go index 16c6f49e..a5a60986 100644 --- a/transport/tuic/common/type.go +++ b/transport/tuic/common/type.go @@ -32,3 +32,10 @@ type Server interface { Serve() error Close() error } + +type UdpRelayMode uint8 + +const ( + QUIC UdpRelayMode = iota + NATIVE +) diff --git a/transport/tuic/tuic.go b/transport/tuic/tuic.go index 279cec95..7be6f450 100644 --- a/transport/tuic/tuic.go +++ b/transport/tuic/tuic.go @@ -45,3 +45,10 @@ const DefaultConnectionReceiveWindow = common.DefaultConnectionReceiveWindow var GenTKN = v4.GenTKN var PacketOverHeadV4 = v4.PacketOverHead var PacketOverHeadV5 = v5.PacketOverHead + +type UdpRelayMode = common.UdpRelayMode + +const ( + QUIC = common.QUIC + NATIVE = common.NATIVE +) diff --git a/transport/tuic/v4/client.go b/transport/tuic/v4/client.go index ae0cf473..7e5ed7e0 100644 --- a/transport/tuic/v4/client.go +++ b/transport/tuic/v4/client.go @@ -29,7 +29,7 @@ type ClientOption struct { TlsConfig *tls.Config QuicConfig *quic.Config Token [32]byte - UdpRelayMode string + UdpRelayMode common.UdpRelayMode CongestionController string ReduceRtt bool RequestTimeout time.Duration @@ -99,7 +99,12 @@ func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn co if t.udp { go func() { - _ = t.parseUDP(quicConn) + switch t.UdpRelayMode { + case common.QUIC: + _ = t.handleUniStream(quicConn) + default: // native + _ = t.handleMessage(quicConn) + } }() } @@ -133,80 +138,102 @@ func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) { return nil } -func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) { +func (t *clientImpl) handleUniStream(quicConn quic.Connection) (err error) { defer func() { t.deferQuicConn(quicConn, err) }() - switch t.UdpRelayMode { - case "quic": - for { - var stream quic.ReceiveStream - stream, err = quicConn.AcceptUniStream(context.Background()) - if err != nil { - return err - } - go func() (err error) { - var assocId uint32 - defer func() { - t.deferQuicConn(quicConn, err) - if err != nil && assocId != 0 { - if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { - if conn, ok := val.(net.Conn); ok { - _ = conn.Close() - } + for { + var stream quic.ReceiveStream + stream, err = quicConn.AcceptUniStream(context.Background()) + if err != nil { + return err + } + go func() (err error) { + var assocId uint32 + defer func() { + t.deferQuicConn(quicConn, err) + if err != nil && assocId != 0 { + if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _ = conn.Close() } } - stream.CancelRead(0) - }() - reader := bufio.NewReader(stream) - packet, err := ReadPacket(reader) + } + stream.CancelRead(0) + }() + reader := bufio.NewReader(stream) + commandHead, err := ReadCommandHead(reader) + if err != nil { + return + } + switch commandHead.TYPE { + case PacketType: + var packet Packet + packet, err = ReadPacketWithHead(commandHead, reader) if err != nil { return } - assocId = packet.ASSOC_ID - if val, ok := t.udpInputMap.Load(assocId); ok { - if conn, ok := val.(net.Conn); ok { - writer := bufio.NewWriterSize(conn, packet.BytesLen()) - _ = packet.WriteTo(writer) - _ = writer.Flush() - } - } - return - }() - } - default: // native - for { - var message []byte - message, err = quicConn.ReceiveMessage() - if err != nil { - return err - } - go func() (err error) { - var assocId uint32 - defer func() { - t.deferQuicConn(quicConn, err) - if err != nil && assocId != 0 { - if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { - if conn, ok := val.(net.Conn); ok { - _ = conn.Close() - } + if t.udp && t.UdpRelayMode == common.QUIC { + assocId = packet.ASSOC_ID + if val, ok := t.udpInputMap.Load(assocId); ok { + if conn, ok := val.(net.Conn); ok { + writer := bufio.NewWriterSize(conn, packet.BytesLen()) + _ = packet.WriteTo(writer) + _ = writer.Flush() } } - }() - buffer := bytes.NewBuffer(message) - packet, err := ReadPacket(buffer) + } + } + return + }() + } +} + +func (t *clientImpl) handleMessage(quicConn quic.Connection) (err error) { + defer func() { + t.deferQuicConn(quicConn, err) + }() + for { + var message []byte + message, err = quicConn.ReceiveMessage() + if err != nil { + return err + } + go func() (err error) { + var assocId uint32 + defer func() { + t.deferQuicConn(quicConn, err) + if err != nil && assocId != 0 { + if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _ = conn.Close() + } + } + } + }() + reader := bytes.NewBuffer(message) + commandHead, err := ReadCommandHead(reader) + if err != nil { + return + } + switch commandHead.TYPE { + case PacketType: + var packet Packet + packet, err = ReadPacketWithHead(commandHead, reader) if err != nil { return } - assocId = packet.ASSOC_ID - if val, ok := t.udpInputMap.Load(assocId); ok { - if conn, ok := val.(net.Conn); ok { - _, _ = conn.Write(message) + if t.udp && t.UdpRelayMode == common.NATIVE { + assocId = packet.ASSOC_ID + if val, ok := t.udpInputMap.Load(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _, _ = conn.Write(message) + } } } - return - }() - } + } + return + }() } } diff --git a/transport/tuic/v4/packet.go b/transport/tuic/v4/packet.go index edd872cc..2f808bef 100644 --- a/transport/tuic/v4/packet.go +++ b/transport/tuic/v4/packet.go @@ -6,10 +6,11 @@ import ( "sync/atomic" "time" - "github.com/metacubex/quic-go" - N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" + "github.com/Dreamacro/clash/transport/tuic/common" + + "github.com/metacubex/quic-go" ) type quicStreamPacketConn struct { @@ -17,7 +18,7 @@ type quicStreamPacketConn struct { quicConn quic.Connection inputConn *N.BufferedConn - udpRelayMode string + udpRelayMode common.UdpRelayMode maxUdpRelayPacketSize int deferQuicConnFn func(quicConn quic.Connection, err error) @@ -121,7 +122,7 @@ func (q *quicStreamPacketConn) WaitReadFrom() (data []byte, put func(), addr net } func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if q.udpRelayMode != "quic" && len(p) > q.maxUdpRelayPacketSize { + if q.udpRelayMode != common.QUIC && len(p) > q.maxUdpRelayPacketSize { return 0, quic.ErrMessageTooLarge(q.maxUdpRelayPacketSize) } if q.closed { @@ -147,7 +148,7 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro return } switch q.udpRelayMode { - case "quic": + case common.QUIC: var stream quic.SendStream stream, err = q.quicConn.OpenUniStream() if err != nil { diff --git a/transport/tuic/v4/server.go b/transport/tuic/v4/server.go index 525ead17..017494ea 100644 --- a/transport/tuic/v4/server.go +++ b/transport/tuic/v4/server.go @@ -118,12 +118,12 @@ func (s *serverHandler) handleMessage() (err error) { if err != nil { return } - return s.parsePacket(packet, "native") + return s.parsePacket(packet, common.NATIVE) }() } } -func (s *serverHandler) parsePacket(packet Packet, udpRelayMode string) (err error) { +func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) { <-s.authCh if !s.authOk { return @@ -247,7 +247,7 @@ func (s *serverHandler) handleUniStream() (err error) { if err != nil { return } - return s.parsePacket(packet, "quic") + return s.parsePacket(packet, common.QUIC) case DissociateType: var disassociate Dissociate disassociate, err = ReadDissociateWithHead(commandHead, reader) diff --git a/transport/tuic/v5/client.go b/transport/tuic/v5/client.go index 9b878177..7bc1c360 100644 --- a/transport/tuic/v5/client.go +++ b/transport/tuic/v5/client.go @@ -28,7 +28,7 @@ type ClientOption struct { QuicConfig *quic.Config Uuid [16]byte Password string - UdpRelayMode string + UdpRelayMode common.UdpRelayMode CongestionController string ReduceRtt bool MaxUdpRelayPacketSize int @@ -94,11 +94,14 @@ func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn co _ = t.sendAuthentication(quicConn) }() - if t.udp { + if t.udp && t.UdpRelayMode == common.QUIC { go func() { - _ = t.parseUDP(quicConn) + _ = t.handleUniStream(quicConn) }() } + go func() { + _ = t.handleMessage(quicConn) // always handleMessage because tuicV5 using datagram to send the Heartbeat + }() t.quicConn = quicConn t.openStreams.Store(0) @@ -134,80 +137,109 @@ func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) { return nil } -func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) { +func (t *clientImpl) handleUniStream(quicConn quic.Connection) (err error) { defer func() { t.deferQuicConn(quicConn, err) }() - switch t.UdpRelayMode { - case "quic": - for { - var stream quic.ReceiveStream - stream, err = quicConn.AcceptUniStream(context.Background()) - if err != nil { - return err - } - go func() (err error) { - var assocId uint16 - defer func() { - t.deferQuicConn(quicConn, err) - if err != nil && assocId != 0 { - if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { - if conn, ok := val.(net.Conn); ok { - _ = conn.Close() - } + for { + var stream quic.ReceiveStream + stream, err = quicConn.AcceptUniStream(context.Background()) + if err != nil { + return err + } + go func() (err error) { + var assocId uint16 + defer func() { + t.deferQuicConn(quicConn, err) + if err != nil && assocId != 0 { + if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _ = conn.Close() } } - stream.CancelRead(0) - }() - reader := bufio.NewReader(stream) - packet, err := ReadPacket(reader) + } + stream.CancelRead(0) + }() + reader := bufio.NewReader(stream) + commandHead, err := ReadCommandHead(reader) + if err != nil { + return + } + switch commandHead.TYPE { + case PacketType: + var packet Packet + packet, err = ReadPacketWithHead(commandHead, reader) if err != nil { return } - assocId = packet.ASSOC_ID - if val, ok := t.udpInputMap.Load(assocId); ok { - if conn, ok := val.(net.Conn); ok { - writer := bufio.NewWriterSize(conn, packet.BytesLen()) - _ = packet.WriteTo(writer) - _ = writer.Flush() - } - } - return - }() - } - default: // native - for { - var message []byte - message, err = quicConn.ReceiveMessage() - if err != nil { - return err - } - go func() (err error) { - var assocId uint16 - defer func() { - t.deferQuicConn(quicConn, err) - if err != nil && assocId != 0 { - if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { - if conn, ok := val.(net.Conn); ok { - _ = conn.Close() - } + if t.udp && t.UdpRelayMode == common.QUIC { + assocId = packet.ASSOC_ID + if val, ok := t.udpInputMap.Load(assocId); ok { + if conn, ok := val.(net.Conn); ok { + writer := bufio.NewWriterSize(conn, packet.BytesLen()) + _ = packet.WriteTo(writer) + _ = writer.Flush() } } - }() - buffer := bytes.NewBuffer(message) - packet, err := ReadPacket(buffer) + } + } + return + }() + } +} + +func (t *clientImpl) handleMessage(quicConn quic.Connection) (err error) { + defer func() { + t.deferQuicConn(quicConn, err) + }() + for { + var message []byte + message, err = quicConn.ReceiveMessage() + if err != nil { + return err + } + go func() (err error) { + var assocId uint16 + defer func() { + t.deferQuicConn(quicConn, err) + if err != nil && assocId != 0 { + if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _ = conn.Close() + } + } + } + }() + reader := bytes.NewBuffer(message) + commandHead, err := ReadCommandHead(reader) + if err != nil { + return + } + switch commandHead.TYPE { + case PacketType: + var packet Packet + packet, err = ReadPacketWithHead(commandHead, reader) if err != nil { return } - assocId = packet.ASSOC_ID - if val, ok := t.udpInputMap.Load(assocId); ok { - if conn, ok := val.(net.Conn); ok { - _, _ = conn.Write(message) + if t.udp && t.UdpRelayMode == common.NATIVE { + assocId = packet.ASSOC_ID + if val, ok := t.udpInputMap.Load(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _, _ = conn.Write(message) + } } } - return - }() - } + case HeartbeatType: + var heartbeat Heartbeat + heartbeat, err = ReadHeartbeatWithHead(commandHead, reader) + if err != nil { + return + } + heartbeat.BytesLen() + } + return + }() } } diff --git a/transport/tuic/v5/packet.go b/transport/tuic/v5/packet.go index 50c602eb..9f546400 100644 --- a/transport/tuic/v5/packet.go +++ b/transport/tuic/v5/packet.go @@ -9,6 +9,7 @@ import ( N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" + "github.com/Dreamacro/clash/transport/tuic/common" "github.com/metacubex/quic-go" "github.com/zhangyunhao116/fastrand" @@ -19,7 +20,7 @@ type quicStreamPacketConn struct { quicConn quic.Connection inputConn *N.BufferedConn - udpRelayMode string + udpRelayMode common.UdpRelayMode maxUdpRelayPacketSize int deferQuicConnFn func(quicConn quic.Connection, err error) @@ -159,7 +160,7 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro pktId := uint16(fastrand.Uint32()) packet := NewPacket(q.connId, pktId, 1, 0, uint16(len(p)), address, p) switch q.udpRelayMode { - case "quic": + case common.QUIC: err = packet.WriteTo(buf) if err != nil { return diff --git a/transport/tuic/v5/protocol.go b/transport/tuic/v5/protocol.go index f2849746..dc7062ea 100644 --- a/transport/tuic/v5/protocol.go +++ b/transport/tuic/v5/protocol.go @@ -33,7 +33,6 @@ const ( PacketType = CommandType(0x02) DissociateType = CommandType(0x03) HeartbeatType = CommandType(0x04) - ResponseType = CommandType(0xff) ) func (c CommandType) String() string { @@ -48,8 +47,6 @@ func (c CommandType) String() string { return "Dissociate" case HeartbeatType: return "Heartbeat" - case ResponseType: - return "Response" default: return fmt.Sprintf("UnknowCommand: %#x", byte(c)) } @@ -406,71 +403,6 @@ func ReadHeartbeat(reader BufferedReader) (c Heartbeat, err error) { return ReadHeartbeatWithHead(head, reader) } -type Response struct { - CommandHead - REP byte -} - -func NewResponse(REP byte) Response { - return Response{ - CommandHead: NewCommandHead(ResponseType), - REP: REP, - } -} - -func NewResponseSucceed() Response { - return NewResponse(0x00) -} - -func NewResponseFailed() Response { - return NewResponse(0xff) -} - -func ReadResponseWithHead(head CommandHead, reader BufferedReader) (c Response, err error) { - c.CommandHead = head - if c.CommandHead.TYPE != ResponseType { - err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) - return - } - c.REP, err = reader.ReadByte() - if err != nil { - return - } - return -} - -func ReadResponse(reader BufferedReader) (c Response, err error) { - head, err := ReadCommandHead(reader) - if err != nil { - return - } - return ReadResponseWithHead(head, reader) -} - -func (c Response) WriteTo(writer BufferedWriter) (err error) { - err = c.CommandHead.WriteTo(writer) - if err != nil { - return - } - err = writer.WriteByte(c.REP) - if err != nil { - return - } - return -} - -func (c Response) IsSucceed() bool { - return c.REP == 0x00 -} - -func (c Response) IsFailed() bool { - return c.REP == 0xff -} - -func (c Response) BytesLen() int { - return c.CommandHead.BytesLen() + 1 -} - // Addr types const ( AtypDomainName byte = 0 diff --git a/transport/tuic/v5/server.go b/transport/tuic/v5/server.go index 3e3dc52f..26965436 100644 --- a/transport/tuic/v5/server.go +++ b/transport/tuic/v5/server.go @@ -113,17 +113,33 @@ func (s *serverHandler) handleMessage() (err error) { return err } go func() (err error) { - buffer := bytes.NewBuffer(message) - packet, err := ReadPacket(buffer) + reader := bytes.NewBuffer(message) + commandHead, err := ReadCommandHead(reader) if err != nil { return } - return s.parsePacket(packet, "native") + switch commandHead.TYPE { + case PacketType: + var packet Packet + packet, err = ReadPacketWithHead(commandHead, reader) + if err != nil { + return + } + return s.parsePacket(packet, common.NATIVE) + case HeartbeatType: + var heartbeat Heartbeat + heartbeat, err = ReadHeartbeatWithHead(commandHead, reader) + if err != nil { + return + } + heartbeat.BytesLen() + } + return }() } } -func (s *serverHandler) parsePacket(packet Packet, udpRelayMode string) (err error) { +func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) { <-s.authCh if !s.authOk { return @@ -244,7 +260,7 @@ func (s *serverHandler) handleUniStream() (err error) { if err != nil { return } - return s.parsePacket(packet, "quic") + return s.parsePacket(packet, common.QUIC) case DissociateType: var disassociate Dissociate disassociate, err = ReadDissociateWithHead(commandHead, reader) @@ -255,13 +271,6 @@ func (s *serverHandler) handleUniStream() (err error) { input := v.(*serverUDPInput) input.writeClosed.Store(true) } - case HeartbeatType: - var heartbeat Heartbeat - heartbeat, err = ReadHeartbeatWithHead(commandHead, reader) - if err != nil { - return - } - heartbeat.BytesLen() } return }()