Improve: pool buffer alloc
This commit is contained in:
parent
d97ee63d0e
commit
0b60be9438
12 changed files with 158 additions and 48 deletions
65
common/pool/alloc.go
Normal file
65
common/pool/alloc.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package pool
|
||||
|
||||
// Inspired by https://github.com/xtaci/smux/blob/master/alloc.go
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/bits"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var defaultAllocator *Allocator
|
||||
|
||||
func init() {
|
||||
defaultAllocator = NewAllocator()
|
||||
}
|
||||
|
||||
// Allocator for incoming frames, optimized to prevent overwriting after zeroing
|
||||
type Allocator struct {
|
||||
buffers []sync.Pool
|
||||
}
|
||||
|
||||
// NewAllocator initiates a []byte allocator for frames less than 65536 bytes,
|
||||
// the waste(memory fragmentation) of space allocation is guaranteed to be
|
||||
// no more than 50%.
|
||||
func NewAllocator() *Allocator {
|
||||
alloc := new(Allocator)
|
||||
alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
|
||||
for k := range alloc.buffers {
|
||||
i := k
|
||||
alloc.buffers[k].New = func() interface{} {
|
||||
return make([]byte, 1<<uint32(i))
|
||||
}
|
||||
}
|
||||
return alloc
|
||||
}
|
||||
|
||||
// Get a []byte from pool with most appropriate cap
|
||||
func (alloc *Allocator) Get(size int) []byte {
|
||||
if size <= 0 || size > 65536 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bits := msb(size)
|
||||
if size == 1<<bits {
|
||||
return alloc.buffers[bits].Get().([]byte)[:size]
|
||||
}
|
||||
|
||||
return alloc.buffers[bits+1].Get().([]byte)[:size]
|
||||
}
|
||||
|
||||
// Put returns a []byte to pool for future use,
|
||||
// which the cap must be exactly 2^n
|
||||
func (alloc *Allocator) Put(buf []byte) error {
|
||||
bits := msb(cap(buf))
|
||||
if cap(buf) == 0 || cap(buf) > 65536 || cap(buf) != 1<<bits {
|
||||
return errors.New("allocator Put() incorrect buffer size")
|
||||
}
|
||||
alloc.buffers[bits].Put(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// msb return the pos of most significiant bit
|
||||
func msb(size int) uint16 {
|
||||
return uint16(bits.Len32(uint32(size)) - 1)
|
||||
}
|
48
common/pool/alloc_test.go
Normal file
48
common/pool/alloc_test.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAllocGet(t *testing.T) {
|
||||
alloc := NewAllocator()
|
||||
assert.Nil(t, alloc.Get(0))
|
||||
assert.Equal(t, 1, len(alloc.Get(1)))
|
||||
assert.Equal(t, 2, len(alloc.Get(2)))
|
||||
assert.Equal(t, 3, len(alloc.Get(3)))
|
||||
assert.Equal(t, 4, cap(alloc.Get(3)))
|
||||
assert.Equal(t, 4, cap(alloc.Get(4)))
|
||||
assert.Equal(t, 1023, len(alloc.Get(1023)))
|
||||
assert.Equal(t, 1024, cap(alloc.Get(1023)))
|
||||
assert.Equal(t, 1024, len(alloc.Get(1024)))
|
||||
assert.Equal(t, 65536, len(alloc.Get(65536)))
|
||||
assert.Nil(t, alloc.Get(65537))
|
||||
}
|
||||
|
||||
func TestAllocPut(t *testing.T) {
|
||||
alloc := NewAllocator()
|
||||
assert.NotNil(t, alloc.Put(nil), "put nil misbehavior")
|
||||
assert.NotNil(t, alloc.Put(make([]byte, 3, 3)), "put elem:3 []bytes misbehavior")
|
||||
assert.Nil(t, alloc.Put(make([]byte, 4, 4)), "put elem:4 []bytes misbehavior")
|
||||
assert.Nil(t, alloc.Put(make([]byte, 1023, 1024)), "put elem:1024 []bytes misbehavior")
|
||||
assert.Nil(t, alloc.Put(make([]byte, 65536, 65536)), "put elem:65536 []bytes misbehavior")
|
||||
assert.NotNil(t, alloc.Put(make([]byte, 65537, 65537)), "put elem:65537 []bytes misbehavior")
|
||||
}
|
||||
|
||||
func TestAllocPutThenGet(t *testing.T) {
|
||||
alloc := NewAllocator()
|
||||
data := alloc.Get(4)
|
||||
alloc.Put(data)
|
||||
newData := alloc.Get(4)
|
||||
|
||||
assert.Equal(t, cap(data), cap(newData), "different cap while alloc.Get()")
|
||||
}
|
||||
|
||||
func BenchmarkMSB(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
msb(rand.Int())
|
||||
}
|
||||
}
|
|
@ -1,15 +1,16 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// io.Copy default buffer size is 32 KiB
|
||||
// but the maximum packet size of vmess/shadowsocks is about 16 KiB
|
||||
// so define a buffer of 20 KiB to reduce the memory of each TCP relay
|
||||
bufferSize = 20 * 1024
|
||||
RelayBufferSize = 20 * 1024
|
||||
)
|
||||
|
||||
// BufPool provide buffer for relay
|
||||
var BufPool = sync.Pool{New: func() interface{} { return make([]byte, bufferSize) }}
|
||||
func Get(size int) []byte {
|
||||
return defaultAllocator.Get(size)
|
||||
}
|
||||
|
||||
func Put(buf []byte) error {
|
||||
return defaultAllocator.Put(buf)
|
||||
}
|
||||
|
|
|
@ -34,15 +34,15 @@ func (ho *HTTPObfs) Read(b []byte) (int, error) {
|
|||
}
|
||||
|
||||
if ho.firstResponse {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
buf := pool.Get(pool.RelayBufferSize)
|
||||
n, err := ho.Conn.Read(buf)
|
||||
if err != nil {
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
return 0, err
|
||||
}
|
||||
idx := bytes.Index(buf[:n], []byte("\r\n\r\n"))
|
||||
if idx == -1 {
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
return 0, io.EOF
|
||||
}
|
||||
ho.firstResponse = false
|
||||
|
@ -52,7 +52,7 @@ func (ho *HTTPObfs) Read(b []byte) (int, error) {
|
|||
ho.buf = buf[:idx+4+length]
|
||||
ho.offset = idx + 4 + n
|
||||
} else {
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
|
|
@ -29,12 +29,12 @@ type TLSObfs struct {
|
|||
}
|
||||
|
||||
func (to *TLSObfs) read(b []byte, discardN int) (int, error) {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
_, err := io.ReadFull(to.Conn, buf[:discardN])
|
||||
buf := pool.Get(discardN)
|
||||
_, err := io.ReadFull(to.Conn, buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
|
||||
sizeBuf := make([]byte, 2)
|
||||
_, err = io.ReadFull(to.Conn, sizeBuf)
|
||||
|
@ -102,15 +102,11 @@ func (to *TLSObfs) write(b []byte) (int, error) {
|
|||
return len(b), err
|
||||
}
|
||||
|
||||
size := pool.BufPool.Get().([]byte)
|
||||
binary.BigEndian.PutUint16(size[:2], uint16(len(b)))
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
buf.Write([]byte{0x17, 0x03, 0x03})
|
||||
buf.Write(size[:2])
|
||||
binary.Write(buf, binary.BigEndian, uint16(len(b)))
|
||||
buf.Write(b)
|
||||
_, err := to.Conn.Write(buf.Bytes())
|
||||
pool.BufPool.Put(size[:cap(size)])
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
|
|
|
@ -22,8 +22,8 @@ func newAEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) *aeadWriter {
|
|||
}
|
||||
|
||||
func (w *aeadWriter) Write(b []byte) (n int, err error) {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
defer pool.BufPool.Put(buf[:cap(buf)])
|
||||
buf := pool.Get(pool.RelayBufferSize)
|
||||
defer pool.Put(buf)
|
||||
length := len(b)
|
||||
for {
|
||||
if length == 0 {
|
||||
|
@ -73,7 +73,7 @@ func (r *aeadReader) Read(b []byte) (int, error) {
|
|||
n := copy(b, r.buf[r.offset:])
|
||||
r.offset += n
|
||||
if r.offset == len(r.buf) {
|
||||
pool.BufPool.Put(r.buf[:cap(r.buf)])
|
||||
pool.Put(r.buf)
|
||||
r.buf = nil
|
||||
}
|
||||
return n, nil
|
||||
|
@ -89,10 +89,10 @@ func (r *aeadReader) Read(b []byte) (int, error) {
|
|||
return 0, errors.New("Buffer is larger than standard")
|
||||
}
|
||||
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
buf := pool.Get(size)
|
||||
_, err = io.ReadFull(r.Reader, buf[:size])
|
||||
if err != nil {
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
@ -107,7 +107,7 @@ func (r *aeadReader) Read(b []byte) (int, error) {
|
|||
realLen := size - r.Overhead()
|
||||
n := copy(b, buf[:realLen])
|
||||
if len(b) >= realLen {
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ func (cr *chunkReader) Read(b []byte) (int, error) {
|
|||
n := copy(b, cr.buf[cr.offset:])
|
||||
cr.offset += n
|
||||
if cr.offset == len(cr.buf) {
|
||||
pool.BufPool.Put(cr.buf[:cap(cr.buf)])
|
||||
pool.Put(cr.buf)
|
||||
cr.buf = nil
|
||||
}
|
||||
return n, nil
|
||||
|
@ -59,15 +59,15 @@ func (cr *chunkReader) Read(b []byte) (int, error) {
|
|||
return size, nil
|
||||
}
|
||||
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
_, err = io.ReadFull(cr.Reader, buf[:size])
|
||||
buf := pool.Get(size)
|
||||
_, err = io.ReadFull(cr.Reader, buf)
|
||||
if err != nil {
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
return 0, err
|
||||
}
|
||||
n := copy(b, cr.buf[:])
|
||||
n := copy(b, buf)
|
||||
cr.offset = n
|
||||
cr.buf = buf[:size]
|
||||
cr.buf = buf
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
@ -76,8 +76,8 @@ type chunkWriter struct {
|
|||
}
|
||||
|
||||
func (cw *chunkWriter) Write(b []byte) (n int, err error) {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
defer pool.BufPool.Put(buf[:cap(buf)])
|
||||
buf := pool.Get(pool.RelayBufferSize)
|
||||
defer pool.Put(buf)
|
||||
length := len(b)
|
||||
for {
|
||||
if length == 0 {
|
||||
|
|
|
@ -34,10 +34,10 @@ func NewRedirUDPProxy(addr string) (*RedirUDPListener, error) {
|
|||
go func() {
|
||||
oob := make([]byte, 1024)
|
||||
for {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
buf := pool.Get(pool.RelayBufferSize)
|
||||
n, oobn, _, lAddr, err := c.ReadMsgUDP(buf, oob)
|
||||
if err != nil {
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
if rl.closed {
|
||||
break
|
||||
}
|
||||
|
|
|
@ -33,6 +33,6 @@ func (c *packet) LocalAddr() net.Addr {
|
|||
}
|
||||
|
||||
func (c *packet) Drop() {
|
||||
pool.BufPool.Put(c.buf[:cap(c.buf)])
|
||||
pool.Put(c.buf)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -31,10 +31,10 @@ func NewSocksUDPProxy(addr string) (*SockUDPListener, error) {
|
|||
sl := &SockUDPListener{l, addr, false}
|
||||
go func() {
|
||||
for {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
buf := pool.Get(pool.RelayBufferSize)
|
||||
n, remoteAddr, err := l.ReadFrom(buf)
|
||||
if err != nil {
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
if sl.closed {
|
||||
break
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) {
|
|||
target, payload, err := socks5.DecodeUDPPacket(buf)
|
||||
if err != nil {
|
||||
// Unresolved UDP packet, return buffer to the pool
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
return
|
||||
}
|
||||
packet := &packet{
|
||||
|
|
|
@ -33,6 +33,6 @@ func (c *packet) LocalAddr() net.Addr {
|
|||
}
|
||||
|
||||
func (c *packet) Drop() {
|
||||
pool.BufPool.Put(c.bufRef[:cap(c.bufRef)])
|
||||
pool.Put(c.bufRef)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -61,9 +61,9 @@ func handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) {
|
|||
}
|
||||
|
||||
// even if resp.Write write body to the connection, but some http request have to Copy to close it
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
buf := pool.Get(pool.RelayBufferSize)
|
||||
_, err = io.CopyBuffer(request, resp.Body, buf)
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
break
|
||||
}
|
||||
|
@ -90,8 +90,8 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata
|
|||
}
|
||||
|
||||
func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr net.Addr) {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
defer pool.BufPool.Put(buf[:cap(buf)])
|
||||
buf := pool.Get(pool.RelayBufferSize)
|
||||
defer pool.Put(buf)
|
||||
defer natTable.Delete(key)
|
||||
defer pc.Close()
|
||||
|
||||
|
@ -122,16 +122,16 @@ func relay(leftConn, rightConn net.Conn) {
|
|||
ch := make(chan error)
|
||||
|
||||
go func() {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
buf := pool.Get(pool.RelayBufferSize)
|
||||
_, err := io.CopyBuffer(leftConn, rightConn, buf)
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
leftConn.SetReadDeadline(time.Now())
|
||||
ch <- err
|
||||
}()
|
||||
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
buf := pool.Get(pool.RelayBufferSize)
|
||||
io.CopyBuffer(rightConn, leftConn, buf)
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
pool.Put(buf)
|
||||
rightConn.SetReadDeadline(time.Now())
|
||||
<-ch
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue