chore: decrease memory copy in quic sniffer
This commit is contained in:
parent
8e637a2ec7
commit
ea7e15b447
5 changed files with 41 additions and 35 deletions
|
@ -10,6 +10,7 @@ const BufferSize = buf.BufferSize
|
||||||
type Buffer = buf.Buffer
|
type Buffer = buf.Buffer
|
||||||
|
|
||||||
var New = buf.New
|
var New = buf.New
|
||||||
|
var NewPacket = buf.NewPacket
|
||||||
var NewSize = buf.NewSize
|
var NewSize = buf.NewSize
|
||||||
var With = buf.With
|
var With = buf.With
|
||||||
var As = buf.As
|
var As = buf.As
|
||||||
|
|
|
@ -51,10 +51,7 @@ func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
|
||||||
overrideDest := config.OverrideDest
|
overrideDest := config.OverrideDest
|
||||||
|
|
||||||
if inWhitelist {
|
if inWhitelist {
|
||||||
var copyBuf = make([]byte, len(packet.Data()))
|
host, err := sniffer.SniffData(packet.Data())
|
||||||
copy(copyBuf, packet.Data())
|
|
||||||
|
|
||||||
host, err := sniffer.SniffData(copyBuf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,10 +107,7 @@ func (quic QuicSniffer) SniffData(b []byte) (string, error) {
|
||||||
return "", errNotQuic
|
return "", errNotQuic
|
||||||
}
|
}
|
||||||
|
|
||||||
hdrLen := len(b) - int(buffer.Len())
|
hdrLen := len(b) - buffer.Len()
|
||||||
|
|
||||||
origPNBytes := make([]byte, 4)
|
|
||||||
copy(origPNBytes, b[hdrLen:hdrLen+4])
|
|
||||||
|
|
||||||
var salt []byte
|
var salt []byte
|
||||||
if versionNumber == version1 {
|
if versionNumber == version1 {
|
||||||
|
@ -126,31 +123,40 @@ func (quic QuicSniffer) SniffData(b []byte) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
cache := buf.New()
|
cache := buf.NewPacket()
|
||||||
defer cache.Release()
|
defer cache.Release()
|
||||||
|
|
||||||
mask := cache.Extend(int(block.BlockSize()))
|
mask := cache.Extend(block.BlockSize())
|
||||||
block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16])
|
block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16])
|
||||||
b[0] ^= mask[0] & 0xf
|
firstByte := b[0]
|
||||||
for i := range b[hdrLen : hdrLen+4] {
|
// Encrypt/decrypt first byte.
|
||||||
b[hdrLen+i] ^= mask[i+1]
|
if isLongHeader {
|
||||||
|
// Long header: 4 bits masked
|
||||||
|
// High 4 bits are not protected.
|
||||||
|
firstByte ^= mask[0] & 0x0f
|
||||||
|
} else {
|
||||||
|
// Short header: 5 bits masked
|
||||||
|
// High 3 bits are not protected.
|
||||||
|
firstByte ^= mask[0] & 0x1f
|
||||||
}
|
}
|
||||||
packetNumberLength := b[0]&0x3 + 1
|
packetNumberLength := int(firstByte&0x3 + 1) // max = 4 (64-bit sequence number)
|
||||||
var packetNumber uint32
|
extHdrLen := hdrLen + packetNumberLength
|
||||||
{
|
|
||||||
n, err := buffer.ReadByte()
|
// copy to avoid modify origin data
|
||||||
if err != nil {
|
extHdr := cache.Extend(extHdrLen)
|
||||||
return "", err
|
copy(extHdr, b)
|
||||||
}
|
extHdr[0] = firstByte
|
||||||
packetNumber = uint32(n)
|
|
||||||
|
packetNumber := extHdr[hdrLen:extHdrLen]
|
||||||
|
// Encrypt/decrypt packet number.
|
||||||
|
for i := range packetNumber {
|
||||||
|
packetNumber[i] ^= mask[1+i]
|
||||||
}
|
}
|
||||||
|
|
||||||
if packetNumber != 0 && packetNumber != 1 {
|
if packetNumber[0] != 0 && packetNumber[0] != 1 {
|
||||||
return "", errNotQuicInitial
|
return "", errNotQuicInitial
|
||||||
}
|
}
|
||||||
|
|
||||||
extHdrLen := hdrLen + int(packetNumberLength)
|
|
||||||
copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:])
|
|
||||||
data := b[extHdrLen : int(packetLen)+hdrLen]
|
data := b[extHdrLen : int(packetLen)+hdrLen]
|
||||||
|
|
||||||
key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
|
key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
|
||||||
|
@ -163,24 +169,20 @@ func (quic QuicSniffer) SniffData(b []byte) (string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
nonce := cache.Extend(8) // 64-bit sequence number
|
|
||||||
binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
|
|
||||||
// copy from crypto/tls.aeadAESGCMTLS13
|
|
||||||
for i, b := range nonce {
|
|
||||||
iv[4+i] ^= b
|
|
||||||
}
|
|
||||||
decrypted, err := aead.Open(b[extHdrLen:extHdrLen], iv, data, b[:extHdrLen])
|
|
||||||
// We only decrypt once, so we do not need to XOR it back.
|
// We only decrypt once, so we do not need to XOR it back.
|
||||||
//for i, b := range nonce {
|
// https://github.com/quic-go/qtls-go1-20/blob/e132a0e6cb45e20ac0b705454849a11d09ba5a54/cipher_suites.go#L496
|
||||||
// iv[4+i] ^= b
|
for i, b := range packetNumber {
|
||||||
//}
|
iv[len(iv)-len(packetNumber)+i] ^= b
|
||||||
|
}
|
||||||
|
dst := cache.Extend(len(data))
|
||||||
|
decrypted, err := aead.Open(dst[:0], iv, data, extHdr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
buffer = buf.As(decrypted)
|
buffer = buf.As(decrypted)
|
||||||
|
|
||||||
cryptoLen := uint(0)
|
cryptoLen := uint(0)
|
||||||
cryptoData := make([]byte, buffer.Len())
|
cryptoData := cache.Extend(buffer.Len())
|
||||||
for i := 0; !buffer.IsEmpty(); i++ {
|
for i := 0; !buffer.IsEmpty(); i++ {
|
||||||
frameType := byte(0x0) // Default to PADDING frame
|
frameType := byte(0x0) // Default to PADDING frame
|
||||||
for frameType == 0x0 && !buffer.IsEmpty() {
|
for frameType == 0x0 && !buffer.IsEmpty() {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package sniffer
|
package sniffer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -26,9 +27,11 @@ func TestQuicHeaders(t *testing.T) {
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
pkt, err := hex.DecodeString(test.input)
|
pkt, err := hex.DecodeString(test.input)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
oriPkt := bytes.Clone(pkt)
|
||||||
domain, err := q.SniffData(pkt)
|
domain, err := q.SniffData(pkt)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, test.domain, domain)
|
assert.Equal(t, test.domain, domain)
|
||||||
|
assert.Equal(t, oriPkt, pkt) // ensure input data not changed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,6 +173,7 @@ func TestTLSHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
|
input := bytes.Clone(test.input)
|
||||||
domain, err := SniffTLS(test.input)
|
domain, err := SniffTLS(test.input)
|
||||||
if test.err {
|
if test.err {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -183,5 +187,6 @@ func TestTLSHeaders(t *testing.T) {
|
||||||
t.Error("expect domain ", test.domain, " but got ", domain)
|
t.Error("expect domain ", test.domain, " but got ", domain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, input, test.input)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import "github.com/Dreamacro/clash/constant"
|
||||||
|
|
||||||
type Sniffer interface {
|
type Sniffer interface {
|
||||||
SupportNetwork() constant.NetWork
|
SupportNetwork() constant.NetWork
|
||||||
|
// SniffData must not change input bytes
|
||||||
SniffData(bytes []byte) (string, error)
|
SniffData(bytes []byte) (string, error)
|
||||||
Protocol() string
|
Protocol() string
|
||||||
SupportPort(port uint16) bool
|
SupportPort(port uint16) bool
|
||||||
|
|
Loading…
Reference in a new issue