From 68da19351ebf2ed66935296c3a3cd245eeb3e00c Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 19 Oct 2023 23:51:37 +0800 Subject: [PATCH] chore: decrease memory copy in quic sniffer --- common/buf/sing.go | 1 + component/sniffer/dispatcher.go | 5 +-- component/sniffer/quic_sniffer.go | 64 ++++++++++++++++--------------- component/sniffer/sniff_test.go | 5 +++ constant/sniffer/sniffer.go | 1 + 5 files changed, 41 insertions(+), 35 deletions(-) diff --git a/common/buf/sing.go b/common/buf/sing.go index d204ba11..0907a95c 100644 --- a/common/buf/sing.go +++ b/common/buf/sing.go @@ -10,6 +10,7 @@ const BufferSize = buf.BufferSize type Buffer = buf.Buffer var New = buf.New +var NewPacket = buf.NewPacket var NewSize = buf.NewSize var With = buf.With var As = buf.As diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index 8df6313c..271be8bb 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -51,10 +51,7 @@ func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool { overrideDest := config.OverrideDest if inWhitelist { - var copyBuf = make([]byte, len(packet.Data())) - copy(copyBuf, packet.Data()) - - host, err := sniffer.SniffData(copyBuf) + host, err := sniffer.SniffData(packet.Data()) if err != nil { continue } diff --git a/component/sniffer/quic_sniffer.go b/component/sniffer/quic_sniffer.go index 24e1bcc4..ef49e5ad 100644 --- a/component/sniffer/quic_sniffer.go +++ b/component/sniffer/quic_sniffer.go @@ -107,10 +107,7 @@ func (quic QuicSniffer) SniffData(b []byte) (string, error) { return "", errNotQuic } - hdrLen := len(b) - int(buffer.Len()) - - origPNBytes := make([]byte, 4) - copy(origPNBytes, b[hdrLen:hdrLen+4]) + hdrLen := len(b) - buffer.Len() var salt []byte if versionNumber == version1 { @@ -126,31 +123,40 @@ func (quic QuicSniffer) SniffData(b []byte) (string, error) { return "", err } - cache := buf.New() + cache := buf.NewPacket() defer cache.Release() - mask := cache.Extend(int(block.BlockSize())) + mask := cache.Extend(block.BlockSize()) block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16]) - b[0] ^= mask[0] & 0xf - for i := range b[hdrLen : hdrLen+4] { - b[hdrLen+i] ^= mask[i+1] + firstByte := b[0] + // Encrypt/decrypt first byte. + if isLongHeader { + // Long header: 4 bits masked + // High 4 bits are not protected. + firstByte ^= mask[0] & 0x0f + } else { + // Short header: 5 bits masked + // High 3 bits are not protected. + firstByte ^= mask[0] & 0x1f } - packetNumberLength := b[0]&0x3 + 1 - var packetNumber uint32 - { - n, err := buffer.ReadByte() - if err != nil { - return "", err - } - packetNumber = uint32(n) + packetNumberLength := int(firstByte&0x3 + 1) // max = 4 (64-bit sequence number) + extHdrLen := hdrLen + packetNumberLength + + // copy to avoid modify origin data + extHdr := cache.Extend(extHdrLen) + copy(extHdr, b) + extHdr[0] = firstByte + + packetNumber := extHdr[hdrLen:extHdrLen] + // Encrypt/decrypt packet number. + for i := range packetNumber { + packetNumber[i] ^= mask[1+i] } - if packetNumber != 0 && packetNumber != 1 { + if packetNumber[0] != 0 && packetNumber[0] != 1 { return "", errNotQuicInitial } - extHdrLen := hdrLen + int(packetNumberLength) - copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:]) data := b[extHdrLen : int(packetLen)+hdrLen] key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16) @@ -163,24 +169,20 @@ func (quic QuicSniffer) SniffData(b []byte) (string, error) { if err != nil { return "", err } - nonce := cache.Extend(8) // 64-bit sequence number - binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber)) - // copy from crypto/tls.aeadAESGCMTLS13 - for i, b := range nonce { - iv[4+i] ^= b - } - decrypted, err := aead.Open(b[extHdrLen:extHdrLen], iv, data, b[:extHdrLen]) // We only decrypt once, so we do not need to XOR it back. - //for i, b := range nonce { - // iv[4+i] ^= b - //} + // https://github.com/quic-go/qtls-go1-20/blob/e132a0e6cb45e20ac0b705454849a11d09ba5a54/cipher_suites.go#L496 + for i, b := range packetNumber { + iv[len(iv)-len(packetNumber)+i] ^= b + } + dst := cache.Extend(len(data)) + decrypted, err := aead.Open(dst[:0], iv, data, extHdr) if err != nil { return "", err } buffer = buf.As(decrypted) cryptoLen := uint(0) - cryptoData := make([]byte, buffer.Len()) + cryptoData := cache.Extend(buffer.Len()) for i := 0; !buffer.IsEmpty(); i++ { frameType := byte(0x0) // Default to PADDING frame for frameType == 0x0 && !buffer.IsEmpty() { diff --git a/component/sniffer/sniff_test.go b/component/sniffer/sniff_test.go index 4c59d432..18cc9152 100644 --- a/component/sniffer/sniff_test.go +++ b/component/sniffer/sniff_test.go @@ -1,6 +1,7 @@ package sniffer import ( + "bytes" "encoding/hex" "github.com/stretchr/testify/assert" "testing" @@ -26,9 +27,11 @@ func TestQuicHeaders(t *testing.T) { for _, test := range cases { pkt, err := hex.DecodeString(test.input) assert.NoError(t, err) + oriPkt := bytes.Clone(pkt) domain, err := q.SniffData(pkt) assert.NoError(t, err) assert.Equal(t, test.domain, domain) + assert.Equal(t, oriPkt, pkt) // ensure input data not changed } } @@ -170,6 +173,7 @@ func TestTLSHeaders(t *testing.T) { } for _, test := range cases { + input := bytes.Clone(test.input) domain, err := SniffTLS(test.input) if test.err { if err == nil { @@ -183,5 +187,6 @@ func TestTLSHeaders(t *testing.T) { t.Error("expect domain ", test.domain, " but got ", domain) } } + assert.Equal(t, input, test.input) } } diff --git a/constant/sniffer/sniffer.go b/constant/sniffer/sniffer.go index d5414b14..47dbd069 100644 --- a/constant/sniffer/sniffer.go +++ b/constant/sniffer/sniffer.go @@ -4,6 +4,7 @@ import "github.com/Dreamacro/clash/constant" type Sniffer interface { SupportNetwork() constant.NetWork + // SniffData must not change input bytes SniffData(bytes []byte) (string, error) Protocol() string SupportPort(port uint16) bool