From 9ea09b2b9441f942d8126c1d7d2fc2cfd9225a54 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Tue, 29 Nov 2022 00:42:26 +0800 Subject: [PATCH] fix: tuic protocol error --- transport/tuic/protocol.go | 10 ++++++++-- transport/tuic/server.go | 4 +++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/transport/tuic/protocol.go b/transport/tuic/protocol.go index f7b85a23..ab696e79 100644 --- a/transport/tuic/protocol.go +++ b/transport/tuic/protocol.go @@ -33,7 +33,7 @@ const ( PacketType = CommandType(0x02) DissociateType = CommandType(0x03) HeartbeatType = CommandType(0x04) - ResponseType = CommandType(0x05) + ResponseType = CommandType(0xff) ) func (c CommandType) String() string { @@ -119,6 +119,7 @@ func ReadAuthenticateWithHead(head CommandHead, reader BufferedReader) (c Authen } if c.CommandHead.TYPE != AuthenticateType { err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return } _, err = io.ReadFull(reader, c.TKN[:]) if err != nil { @@ -174,6 +175,7 @@ func ReadConnectWithHead(head CommandHead, reader BufferedReader) (c Connect, er } if c.CommandHead.TYPE != ConnectType { err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return } c.ADDR, err = ReadAddress(reader) if err != nil { @@ -231,6 +233,7 @@ func ReadPacketWithHead(head CommandHead, reader BufferedReader) (c Packet, err } if c.CommandHead.TYPE != PacketType { err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return } err = binary.Read(reader, binary.BigEndian, &c.ASSOC_ID) if err != nil { @@ -305,8 +308,9 @@ func ReadDissociateWithHead(head CommandHead, reader BufferedReader) (c Dissocia if err != nil { return } - if c.CommandHead.TYPE != PacketType { + if c.CommandHead.TYPE != DissociateType { err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return } err = binary.Read(reader, binary.BigEndian, &c.ASSOC_ID) if err != nil { @@ -353,6 +357,7 @@ func ReadHeartbeatWithHead(head CommandHead, reader BufferedReader) (c Heartbeat c.CommandHead = head if c.CommandHead.TYPE != HeartbeatType { err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return } return } @@ -389,6 +394,7 @@ func ReadResponseWithHead(head CommandHead, reader BufferedReader) (c Response, c.CommandHead = head if c.CommandHead.TYPE != ResponseType { err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return } c.REP, err = reader.ReadByte() if err != nil { diff --git a/transport/tuic/server.go b/transport/tuic/server.go index 78344589..c1213d68 100644 --- a/transport/tuic/server.go +++ b/transport/tuic/server.go @@ -190,8 +190,10 @@ func (s *serverHandler) handleStream() (err error) { err = s.HandleTcpFn(conn, connect.ADDR.SocksAddr()) if err != nil { err = NewResponseFailed().WriteTo(buf) + defer conn.Close() + } else { + err = NewResponseSucceed().WriteTo(buf) } - err = NewResponseSucceed().WriteTo(buf) if err != nil { _ = conn.Close() return err