diff --git a/component/gun/gun.go b/component/gun/gun.go index d1eef3b8..4ff5efe0 100644 --- a/component/gun/gun.go +++ b/component/gun/gun.go @@ -4,6 +4,7 @@ package gun import ( + "bufio" "crypto/tls" "encoding/binary" "errors" @@ -15,14 +16,13 @@ import ( "sync" "time" - "github.com/Dreamacro/clash/common/pool" - "go.uber.org/atomic" "golang.org/x/net/http2" ) var ( ErrInvalidLength = errors.New("invalid length") + ErrSmallBuffer = errors.New("buffer too small") ) var ( @@ -42,9 +42,8 @@ type Conn struct { once sync.Once close *atomic.Bool err error - - buf []byte - offset int + remain int + br *bufio.Reader } type Config struct { @@ -73,53 +72,76 @@ func (g *Conn) Read(b []byte) (n int, err error) { return 0, g.err } - if g.buf != nil { - n = copy(b, g.buf[g.offset:]) - g.offset += n - if g.offset == len(g.buf) { - g.offset = 0 - g.buf = nil + if g.br != nil { + remain := g.br.Buffered() + if len(b) < remain { + remain = len(b) } + + n, err = g.br.Read(b[:remain]) + if g.br.Buffered() == 0 { + g.br = nil + } + return + } else if g.remain != 0 { + size := g.remain + if len(b) < size { + size = len(b) + } + + n, err = g.response.Body.Read(b[:size]) + g.remain -= n return } else if g.response == nil { return 0, net.ErrClosed } - buf := make([]byte, 5) + // 0x00 grpclength(uint32) 0x0A uleb128 payload + buf := make([]byte, 6) _, err = io.ReadFull(g.response.Body, buf) if err != nil { return 0, err } - grpcPayloadLen := binary.BigEndian.Uint32(buf[1:]) - if grpcPayloadLen > pool.RelayBufferSize { - return 0, ErrInvalidLength - } - buf = pool.Get(int(grpcPayloadLen)) - _, err = io.ReadFull(g.response.Body, buf) + br := bufio.NewReaderSize(g.response.Body, 16) + protobufPayloadLen, err := binary.ReadUvarint(br) if err != nil { - pool.Put(buf) - return 0, io.ErrUnexpectedEOF - } - protobufPayloadLen, protobufLengthLen := decodeUleb128(buf[1:]) - if protobufLengthLen == 0 { - pool.Put(buf) - return 0, ErrInvalidLength - } - if grpcPayloadLen != uint32(protobufPayloadLen)+uint32(protobufLengthLen)+1 { - pool.Put(buf) return 0, ErrInvalidLength } - if len(b) >= int(grpcPayloadLen)-1-int(protobufLengthLen) { - n = copy(b, buf[1+protobufLengthLen:]) - pool.Put(buf) + bufferedSize := br.Buffered() + if len(b) < bufferedSize { + n, err = br.Read(b) + g.br = br + g.remain = int(protobufPayloadLen) - n - g.br.Buffered() return } - n = copy(b, buf[1+protobufLengthLen:]) - g.offset = n + 1 + int(protobufLengthLen) - g.buf = buf - return + + _, err = br.Read(b[:bufferedSize]) + if err != nil { + return + } + + offset := int(protobufPayloadLen) + if len(b) < int(protobufPayloadLen) { + offset = len(b) + } + + if offset == bufferedSize { + return bufferedSize, nil + } + + n, err = io.ReadFull(g.response.Body, b[bufferedSize:offset]) + if err != nil { + return + } + + remain := int(protobufPayloadLen) - offset + if remain > 0 { + g.remain = remain + } + + return offset, nil } func (g *Conn) Write(b []byte) (n int, err error) {