fix: tuicV5's heartbeat should be a datagram packet

This commit is contained in:
wwqgtxx 2023-06-13 17:50:10 +08:00 committed by Larvan2
parent c2cdf43239
commit f4b734c74c
10 changed files with 233 additions and 216 deletions

View file

@ -175,8 +175,9 @@ func NewTuic(option TuicOption) (*Tuic, error) {
option.HeartbeatInterval = 10000 option.HeartbeatInterval = 10000
} }
udpRelayMode := tuic.QUIC
if option.UdpRelayMode != "quic" { if option.UdpRelayMode != "quic" {
option.UdpRelayMode = "native" udpRelayMode = tuic.NATIVE
} }
if option.MaxUdpRelayPacketSize == 0 { if option.MaxUdpRelayPacketSize == 0 {
@ -264,7 +265,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
TlsConfig: tlsConfig, TlsConfig: tlsConfig,
QuicConfig: quicConfig, QuicConfig: quicConfig,
Token: tkn, Token: tkn,
UdpRelayMode: option.UdpRelayMode, UdpRelayMode: udpRelayMode,
CongestionController: option.CongestionController, CongestionController: option.CongestionController,
ReduceRtt: option.ReduceRtt, ReduceRtt: option.ReduceRtt,
RequestTimeout: time.Duration(option.RequestTimeout) * time.Millisecond, RequestTimeout: time.Duration(option.RequestTimeout) * time.Millisecond,
@ -280,7 +281,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
QuicConfig: quicConfig, QuicConfig: quicConfig,
Uuid: uuid.FromStringOrNil(option.UUID), Uuid: uuid.FromStringOrNil(option.UUID),
Password: option.Password, Password: option.Password,
UdpRelayMode: option.UdpRelayMode, UdpRelayMode: udpRelayMode,
CongestionController: option.CongestionController, CongestionController: option.CongestionController,
ReduceRtt: option.ReduceRtt, ReduceRtt: option.ReduceRtt,
MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize,

View file

@ -32,3 +32,10 @@ type Server interface {
Serve() error Serve() error
Close() error Close() error
} }
type UdpRelayMode uint8
const (
QUIC UdpRelayMode = iota
NATIVE
)

View file

@ -45,3 +45,10 @@ const DefaultConnectionReceiveWindow = common.DefaultConnectionReceiveWindow
var GenTKN = v4.GenTKN var GenTKN = v4.GenTKN
var PacketOverHeadV4 = v4.PacketOverHead var PacketOverHeadV4 = v4.PacketOverHead
var PacketOverHeadV5 = v5.PacketOverHead var PacketOverHeadV5 = v5.PacketOverHead
type UdpRelayMode = common.UdpRelayMode
const (
QUIC = common.QUIC
NATIVE = common.NATIVE
)

View file

@ -29,7 +29,7 @@ type ClientOption struct {
TlsConfig *tls.Config TlsConfig *tls.Config
QuicConfig *quic.Config QuicConfig *quic.Config
Token [32]byte Token [32]byte
UdpRelayMode string UdpRelayMode common.UdpRelayMode
CongestionController string CongestionController string
ReduceRtt bool ReduceRtt bool
RequestTimeout time.Duration RequestTimeout time.Duration
@ -99,7 +99,12 @@ func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn co
if t.udp { if t.udp {
go func() { go func() {
_ = t.parseUDP(quicConn) switch t.UdpRelayMode {
case common.QUIC:
_ = t.handleUniStream(quicConn)
default: // native
_ = t.handleMessage(quicConn)
}
}() }()
} }
@ -133,12 +138,10 @@ func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) {
return nil return nil
} }
func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) { func (t *clientImpl) handleUniStream(quicConn quic.Connection) (err error) {
defer func() { defer func() {
t.deferQuicConn(quicConn, err) t.deferQuicConn(quicConn, err)
}() }()
switch t.UdpRelayMode {
case "quic":
for { for {
var stream quic.ReceiveStream var stream quic.ReceiveStream
stream, err = quicConn.AcceptUniStream(context.Background()) stream, err = quicConn.AcceptUniStream(context.Background())
@ -159,10 +162,18 @@ func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) {
stream.CancelRead(0) stream.CancelRead(0)
}() }()
reader := bufio.NewReader(stream) reader := bufio.NewReader(stream)
packet, err := ReadPacket(reader) commandHead, err := ReadCommandHead(reader)
if err != nil { if err != nil {
return return
} }
switch commandHead.TYPE {
case PacketType:
var packet Packet
packet, err = ReadPacketWithHead(commandHead, reader)
if err != nil {
return
}
if t.udp && t.UdpRelayMode == common.QUIC {
assocId = packet.ASSOC_ID assocId = packet.ASSOC_ID
if val, ok := t.udpInputMap.Load(assocId); ok { if val, ok := t.udpInputMap.Load(assocId); ok {
if conn, ok := val.(net.Conn); ok { if conn, ok := val.(net.Conn); ok {
@ -171,10 +182,17 @@ func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) {
_ = writer.Flush() _ = writer.Flush()
} }
} }
}
}
return return
}() }()
} }
default: // native }
func (t *clientImpl) handleMessage(quicConn quic.Connection) (err error) {
defer func() {
t.deferQuicConn(quicConn, err)
}()
for { for {
var message []byte var message []byte
message, err = quicConn.ReceiveMessage() message, err = quicConn.ReceiveMessage()
@ -193,21 +211,30 @@ func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) {
} }
} }
}() }()
buffer := bytes.NewBuffer(message) reader := bytes.NewBuffer(message)
packet, err := ReadPacket(buffer) commandHead, err := ReadCommandHead(reader)
if err != nil { if err != nil {
return return
} }
switch commandHead.TYPE {
case PacketType:
var packet Packet
packet, err = ReadPacketWithHead(commandHead, reader)
if err != nil {
return
}
if t.udp && t.UdpRelayMode == common.NATIVE {
assocId = packet.ASSOC_ID assocId = packet.ASSOC_ID
if val, ok := t.udpInputMap.Load(assocId); ok { if val, ok := t.udpInputMap.Load(assocId); ok {
if conn, ok := val.(net.Conn); ok { if conn, ok := val.(net.Conn); ok {
_, _ = conn.Write(message) _, _ = conn.Write(message)
} }
} }
}
}
return return
}() }()
} }
}
} }
func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) { func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) {

View file

@ -6,10 +6,11 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/metacubex/quic-go"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/transport/tuic/common"
"github.com/metacubex/quic-go"
) )
type quicStreamPacketConn struct { type quicStreamPacketConn struct {
@ -17,7 +18,7 @@ type quicStreamPacketConn struct {
quicConn quic.Connection quicConn quic.Connection
inputConn *N.BufferedConn inputConn *N.BufferedConn
udpRelayMode string udpRelayMode common.UdpRelayMode
maxUdpRelayPacketSize int maxUdpRelayPacketSize int
deferQuicConnFn func(quicConn quic.Connection, err error) deferQuicConnFn func(quicConn quic.Connection, err error)
@ -121,7 +122,7 @@ func (q *quicStreamPacketConn) WaitReadFrom() (data []byte, put func(), addr net
} }
func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if q.udpRelayMode != "quic" && len(p) > q.maxUdpRelayPacketSize { if q.udpRelayMode != common.QUIC && len(p) > q.maxUdpRelayPacketSize {
return 0, quic.ErrMessageTooLarge(q.maxUdpRelayPacketSize) return 0, quic.ErrMessageTooLarge(q.maxUdpRelayPacketSize)
} }
if q.closed { if q.closed {
@ -147,7 +148,7 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro
return return
} }
switch q.udpRelayMode { switch q.udpRelayMode {
case "quic": case common.QUIC:
var stream quic.SendStream var stream quic.SendStream
stream, err = q.quicConn.OpenUniStream() stream, err = q.quicConn.OpenUniStream()
if err != nil { if err != nil {

View file

@ -118,12 +118,12 @@ func (s *serverHandler) handleMessage() (err error) {
if err != nil { if err != nil {
return return
} }
return s.parsePacket(packet, "native") return s.parsePacket(packet, common.NATIVE)
}() }()
} }
} }
func (s *serverHandler) parsePacket(packet Packet, udpRelayMode string) (err error) { func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) {
<-s.authCh <-s.authCh
if !s.authOk { if !s.authOk {
return return
@ -247,7 +247,7 @@ func (s *serverHandler) handleUniStream() (err error) {
if err != nil { if err != nil {
return return
} }
return s.parsePacket(packet, "quic") return s.parsePacket(packet, common.QUIC)
case DissociateType: case DissociateType:
var disassociate Dissociate var disassociate Dissociate
disassociate, err = ReadDissociateWithHead(commandHead, reader) disassociate, err = ReadDissociateWithHead(commandHead, reader)

View file

@ -28,7 +28,7 @@ type ClientOption struct {
QuicConfig *quic.Config QuicConfig *quic.Config
Uuid [16]byte Uuid [16]byte
Password string Password string
UdpRelayMode string UdpRelayMode common.UdpRelayMode
CongestionController string CongestionController string
ReduceRtt bool ReduceRtt bool
MaxUdpRelayPacketSize int MaxUdpRelayPacketSize int
@ -94,11 +94,14 @@ func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn co
_ = t.sendAuthentication(quicConn) _ = t.sendAuthentication(quicConn)
}() }()
if t.udp { if t.udp && t.UdpRelayMode == common.QUIC {
go func() { go func() {
_ = t.parseUDP(quicConn) _ = t.handleUniStream(quicConn)
}() }()
} }
go func() {
_ = t.handleMessage(quicConn) // always handleMessage because tuicV5 using datagram to send the Heartbeat
}()
t.quicConn = quicConn t.quicConn = quicConn
t.openStreams.Store(0) t.openStreams.Store(0)
@ -134,12 +137,10 @@ func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) {
return nil return nil
} }
func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) { func (t *clientImpl) handleUniStream(quicConn quic.Connection) (err error) {
defer func() { defer func() {
t.deferQuicConn(quicConn, err) t.deferQuicConn(quicConn, err)
}() }()
switch t.UdpRelayMode {
case "quic":
for { for {
var stream quic.ReceiveStream var stream quic.ReceiveStream
stream, err = quicConn.AcceptUniStream(context.Background()) stream, err = quicConn.AcceptUniStream(context.Background())
@ -160,10 +161,18 @@ func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) {
stream.CancelRead(0) stream.CancelRead(0)
}() }()
reader := bufio.NewReader(stream) reader := bufio.NewReader(stream)
packet, err := ReadPacket(reader) commandHead, err := ReadCommandHead(reader)
if err != nil { if err != nil {
return return
} }
switch commandHead.TYPE {
case PacketType:
var packet Packet
packet, err = ReadPacketWithHead(commandHead, reader)
if err != nil {
return
}
if t.udp && t.UdpRelayMode == common.QUIC {
assocId = packet.ASSOC_ID assocId = packet.ASSOC_ID
if val, ok := t.udpInputMap.Load(assocId); ok { if val, ok := t.udpInputMap.Load(assocId); ok {
if conn, ok := val.(net.Conn); ok { if conn, ok := val.(net.Conn); ok {
@ -172,10 +181,17 @@ func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) {
_ = writer.Flush() _ = writer.Flush()
} }
} }
}
}
return return
}() }()
} }
default: // native }
func (t *clientImpl) handleMessage(quicConn quic.Connection) (err error) {
defer func() {
t.deferQuicConn(quicConn, err)
}()
for { for {
var message []byte var message []byte
message, err = quicConn.ReceiveMessage() message, err = quicConn.ReceiveMessage()
@ -194,21 +210,37 @@ func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) {
} }
} }
}() }()
buffer := bytes.NewBuffer(message) reader := bytes.NewBuffer(message)
packet, err := ReadPacket(buffer) commandHead, err := ReadCommandHead(reader)
if err != nil { if err != nil {
return return
} }
switch commandHead.TYPE {
case PacketType:
var packet Packet
packet, err = ReadPacketWithHead(commandHead, reader)
if err != nil {
return
}
if t.udp && t.UdpRelayMode == common.NATIVE {
assocId = packet.ASSOC_ID assocId = packet.ASSOC_ID
if val, ok := t.udpInputMap.Load(assocId); ok { if val, ok := t.udpInputMap.Load(assocId); ok {
if conn, ok := val.(net.Conn); ok { if conn, ok := val.(net.Conn); ok {
_, _ = conn.Write(message) _, _ = conn.Write(message)
} }
} }
}
case HeartbeatType:
var heartbeat Heartbeat
heartbeat, err = ReadHeartbeatWithHead(commandHead, reader)
if err != nil {
return
}
heartbeat.BytesLen()
}
return return
}() }()
} }
}
} }
func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) { func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) {

View file

@ -9,6 +9,7 @@ import (
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/transport/tuic/common"
"github.com/metacubex/quic-go" "github.com/metacubex/quic-go"
"github.com/zhangyunhao116/fastrand" "github.com/zhangyunhao116/fastrand"
@ -19,7 +20,7 @@ type quicStreamPacketConn struct {
quicConn quic.Connection quicConn quic.Connection
inputConn *N.BufferedConn inputConn *N.BufferedConn
udpRelayMode string udpRelayMode common.UdpRelayMode
maxUdpRelayPacketSize int maxUdpRelayPacketSize int
deferQuicConnFn func(quicConn quic.Connection, err error) deferQuicConnFn func(quicConn quic.Connection, err error)
@ -159,7 +160,7 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro
pktId := uint16(fastrand.Uint32()) pktId := uint16(fastrand.Uint32())
packet := NewPacket(q.connId, pktId, 1, 0, uint16(len(p)), address, p) packet := NewPacket(q.connId, pktId, 1, 0, uint16(len(p)), address, p)
switch q.udpRelayMode { switch q.udpRelayMode {
case "quic": case common.QUIC:
err = packet.WriteTo(buf) err = packet.WriteTo(buf)
if err != nil { if err != nil {
return return

View file

@ -33,7 +33,6 @@ const (
PacketType = CommandType(0x02) PacketType = CommandType(0x02)
DissociateType = CommandType(0x03) DissociateType = CommandType(0x03)
HeartbeatType = CommandType(0x04) HeartbeatType = CommandType(0x04)
ResponseType = CommandType(0xff)
) )
func (c CommandType) String() string { func (c CommandType) String() string {
@ -48,8 +47,6 @@ func (c CommandType) String() string {
return "Dissociate" return "Dissociate"
case HeartbeatType: case HeartbeatType:
return "Heartbeat" return "Heartbeat"
case ResponseType:
return "Response"
default: default:
return fmt.Sprintf("UnknowCommand: %#x", byte(c)) return fmt.Sprintf("UnknowCommand: %#x", byte(c))
} }
@ -406,71 +403,6 @@ func ReadHeartbeat(reader BufferedReader) (c Heartbeat, err error) {
return ReadHeartbeatWithHead(head, reader) return ReadHeartbeatWithHead(head, reader)
} }
type Response struct {
CommandHead
REP byte
}
func NewResponse(REP byte) Response {
return Response{
CommandHead: NewCommandHead(ResponseType),
REP: REP,
}
}
func NewResponseSucceed() Response {
return NewResponse(0x00)
}
func NewResponseFailed() Response {
return NewResponse(0xff)
}
func ReadResponseWithHead(head CommandHead, reader BufferedReader) (c Response, err error) {
c.CommandHead = head
if c.CommandHead.TYPE != ResponseType {
err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE)
return
}
c.REP, err = reader.ReadByte()
if err != nil {
return
}
return
}
func ReadResponse(reader BufferedReader) (c Response, err error) {
head, err := ReadCommandHead(reader)
if err != nil {
return
}
return ReadResponseWithHead(head, reader)
}
func (c Response) WriteTo(writer BufferedWriter) (err error) {
err = c.CommandHead.WriteTo(writer)
if err != nil {
return
}
err = writer.WriteByte(c.REP)
if err != nil {
return
}
return
}
func (c Response) IsSucceed() bool {
return c.REP == 0x00
}
func (c Response) IsFailed() bool {
return c.REP == 0xff
}
func (c Response) BytesLen() int {
return c.CommandHead.BytesLen() + 1
}
// Addr types // Addr types
const ( const (
AtypDomainName byte = 0 AtypDomainName byte = 0

View file

@ -113,17 +113,33 @@ func (s *serverHandler) handleMessage() (err error) {
return err return err
} }
go func() (err error) { go func() (err error) {
buffer := bytes.NewBuffer(message) reader := bytes.NewBuffer(message)
packet, err := ReadPacket(buffer) commandHead, err := ReadCommandHead(reader)
if err != nil { if err != nil {
return return
} }
return s.parsePacket(packet, "native") switch commandHead.TYPE {
case PacketType:
var packet Packet
packet, err = ReadPacketWithHead(commandHead, reader)
if err != nil {
return
}
return s.parsePacket(packet, common.NATIVE)
case HeartbeatType:
var heartbeat Heartbeat
heartbeat, err = ReadHeartbeatWithHead(commandHead, reader)
if err != nil {
return
}
heartbeat.BytesLen()
}
return
}() }()
} }
} }
func (s *serverHandler) parsePacket(packet Packet, udpRelayMode string) (err error) { func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) {
<-s.authCh <-s.authCh
if !s.authOk { if !s.authOk {
return return
@ -244,7 +260,7 @@ func (s *serverHandler) handleUniStream() (err error) {
if err != nil { if err != nil {
return return
} }
return s.parsePacket(packet, "quic") return s.parsePacket(packet, common.QUIC)
case DissociateType: case DissociateType:
var disassociate Dissociate var disassociate Dissociate
disassociate, err = ReadDissociateWithHead(commandHead, reader) disassociate, err = ReadDissociateWithHead(commandHead, reader)
@ -255,13 +271,6 @@ func (s *serverHandler) handleUniStream() (err error) {
input := v.(*serverUDPInput) input := v.(*serverUDPInput)
input.writeClosed.Store(true) input.writeClosed.Store(true)
} }
case HeartbeatType:
var heartbeat Heartbeat
heartbeat, err = ReadHeartbeatWithHead(commandHead, reader)
if err != nil {
return
}
heartbeat.BytesLen()
} }
return return
}() }()