diff --git a/listener/tun/device/iobased/endpoint.go b/listener/tun/device/iobased/endpoint.go index 90871ee0..35eb74b3 100644 --- a/listener/tun/device/iobased/endpoint.go +++ b/listener/tun/device/iobased/endpoint.go @@ -7,11 +7,12 @@ package iobased import ( "context" "errors" + "gvisor.dev/gvisor/pkg/bufferv2" "io" "sync" + "github.com/Dreamacro/clash/common/pool" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -95,31 +96,44 @@ func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) { mtu := int(e.mtu) for { - data := make([]byte, mtu) + data := pool.Get(mtu) n, err := e.rw.Read(data) if err != nil { + _ = pool.Put(data) break } if n == 0 || n > mtu { + _ = pool.Put(data) continue } if !e.IsAttached() { + _ = pool.Put(data) continue /* unattached, drop packet */ } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.View(data[:n]).ToVectorisedView(), - }) - + var p tcpip.NetworkProtocolNumber switch header.IPVersion(data) { case header.IPv4Version: - e.InjectInbound(header.IPv4ProtocolNumber, pkt) + p = header.IPv4ProtocolNumber case header.IPv6Version: - e.InjectInbound(header.IPv6ProtocolNumber, pkt) + p = header.IPv6ProtocolNumber + default: + _ = pool.Put(data) + continue } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: bufferv2.MakeWithData(data), + OnRelease: func() { + _ = pool.Put(data) + }, + }) + + e.InjectInbound(p, pkt) + pkt.DecRef() } } @@ -138,13 +152,14 @@ func (e *Endpoint) outboundLoop(ctx context.Context) { // writePacket writes outbound packets to the io.Writer. func (e *Endpoint) writePacket(pkt *stack.PacketBuffer) tcpip.Error { - defer pkt.DecRef() + pktView := pkt.ToView() - size := pkt.Size() - views := pkt.Views() + defer func() { + pktView.Release() + pkt.DecRef() + }() - vView := buffer.NewVectorisedView(size, views) - if _, err := e.rw.Write(vView.ToView()); err != nil { + if _, err := e.rw.Write(pktView.AsSlice()); err != nil { return &tcpip.ErrInvalidEndpointState{} } return nil