fix: sing-based listener panic

This commit is contained in:
gVisor bot 2023-05-12 09:14:27 +08:00
parent 5070a21626
commit 4c38b2f0bf
8 changed files with 57 additions and 16 deletions

View file

@ -193,7 +193,7 @@ func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dial
if err != nil { if err != nil {
return nil, err return nil, err
} }
pc = ss.method.DialPacketConn(N.NewBindPacketConn(N.NewEnhancePacketConn(pc), addr)) pc = ss.method.DialPacketConn(N.NewBindPacketConn(pc, addr))
return newPacketConn(pc, ss), nil return newPacketConn(pc, ss), nil
} }

View file

@ -37,9 +37,9 @@ func (c *bindPacketConn) Upstream() any {
return c.EnhancePacketConn return c.EnhancePacketConn
} }
func NewBindPacketConn(pc EnhancePacketConn, rAddr net.Addr) net.Conn { func NewBindPacketConn(pc net.PacketConn, rAddr net.Addr) net.Conn {
return &bindPacketConn{ return &bindPacketConn{
EnhancePacketConn: pc, EnhancePacketConn: NewEnhancePacketConn(pc),
rAddr: rAddr, rAddr: rAddr,
} }
} }

View file

@ -15,20 +15,24 @@ var NewDeadlinePacketConn = deadline.NewPacketConn
var NewDeadlineEnhancePacketConn = deadline.NewEnhancePacketConn var NewDeadlineEnhancePacketConn = deadline.NewEnhancePacketConn
type threadSafePacketConn struct { type threadSafePacketConn struct {
net.PacketConn EnhancePacketConn
access sync.Mutex access sync.Mutex
} }
func (c *threadSafePacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (c *threadSafePacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
c.access.Lock() c.access.Lock()
defer c.access.Unlock() defer c.access.Unlock()
return c.PacketConn.WriteTo(b, addr) return c.EnhancePacketConn.WriteTo(b, addr)
} }
func (c *threadSafePacketConn) Upstream() any { func (c *threadSafePacketConn) Upstream() any {
return c.PacketConn return c.EnhancePacketConn
}
func (c *threadSafePacketConn) ReaderReplaceable() bool {
return true
} }
func NewThreadSafePacketConn(pc net.PacketConn) net.PacketConn { func NewThreadSafePacketConn(pc net.PacketConn) net.PacketConn {
return &threadSafePacketConn{PacketConn: pc} return &threadSafePacketConn{EnhancePacketConn: NewEnhancePacketConn(pc)}
} }

View file

@ -82,10 +82,15 @@ func NewRefConn(conn net.Conn, ref any) net.Conn {
} }
type refPacketConn struct { type refPacketConn struct {
pc net.PacketConn pc EnhancePacketConn
ref any ref any
} }
func (pc *refPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
defer runtime.KeepAlive(pc.ref)
return pc.pc.WaitReadFrom()
}
func (pc *refPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (pc *refPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
defer runtime.KeepAlive(pc.ref) defer runtime.KeepAlive(pc.ref)
return pc.pc.ReadFrom(p) return pc.pc.ReadFrom(p)
@ -121,6 +126,18 @@ func (pc *refPacketConn) SetWriteDeadline(t time.Time) error {
return pc.pc.SetWriteDeadline(t) return pc.pc.SetWriteDeadline(t)
} }
func NewRefPacketConn(pc net.PacketConn, ref any) net.PacketConn { func (pc *refPacketConn) Upstream() any {
return &refPacketConn{pc: pc, ref: ref} return pc.pc
}
func (pc *refPacketConn) ReaderReplaceable() bool { // Relay() will handle reference
return true
}
func (pc *refPacketConn) WriterReplaceable() bool { // Relay() will handle reference
return true
}
func NewRefPacketConn(pc net.PacketConn, ref any) net.PacketConn {
return &refPacketConn{pc: NewEnhancePacketConn(pc), ref: ref}
} }

View file

@ -58,6 +58,14 @@ func (c *waitCloseConn) Upstream() any {
return c.ExtendedConn return c.ExtendedConn
} }
func (c *waitCloseConn) ReaderReplaceable() bool {
return true
}
func (c *waitCloseConn) WriterReplaceable() bool {
return true
}
func UpstreamMetadata(metadata M.Metadata) M.Metadata { func UpstreamMetadata(metadata M.Metadata) M.Metadata {
return M.Metadata{ return M.Metadata{
Source: metadata.Source, Source: metadata.Source,
@ -116,7 +124,7 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
defer mutex.Unlock() defer mutex.Unlock()
conn2 = nil conn2 = nil
}() }()
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn) connReader := network.UnwrapPacketReader(conn) // decrease runtime cost for bufio.CreatePacketReadWaiter
for { for {
var ( var (
buff *buf.Buffer buff *buf.Buffer
@ -127,7 +135,9 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
buff = buf.NewPacket() // do not use stack buffer buff = buf.NewPacket() // do not use stack buffer
return buff return buff
} }
if isReadWaiter { // syscallPacketReadWaiter.WaitReadPacket will cache newBuffer function
// so create new PacketReadWaiter in each loop
if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(connReader); isReadWaiter {
dest, err = readWaiter.WaitReadPacket(newBuffer) dest, err = readWaiter.WaitReadPacket(newBuffer)
} else { } else {
dest, err = conn.ReadPacket(newBuffer()) dest, err = conn.ReadPacket(newBuffer())

View file

@ -22,6 +22,7 @@ import (
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
) )
type Listener struct { type Listener struct {
@ -92,7 +93,7 @@ func New(config LC.ShadowsocksServer, tcpIn chan<- C.ConnContext, udpIn chan<- C
go func() { go func() {
conn := bufio.NewPacketConn(ul) conn := bufio.NewPacketConn(ul)
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn) connReader := network.UnwrapPacketReader(conn) // decrease runtime cost for bufio.CreatePacketReadWaiter
for { for {
var ( var (
buff *buf.Buffer buff *buf.Buffer
@ -103,7 +104,9 @@ func New(config LC.ShadowsocksServer, tcpIn chan<- C.ConnContext, udpIn chan<- C
buff = buf.NewPacket() // do not use stack buffer buff = buf.NewPacket() // do not use stack buffer
return buff return buff
} }
if isReadWaiter { // syscallPacketReadWaiter.WaitReadPacket will cache newBuffer function
// so create new PacketReadWaiter in each loop
if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(connReader); isReadWaiter {
dest, err = readWaiter.WaitReadPacket(newBuffer) dest, err = readWaiter.WaitReadPacket(newBuffer)
} else { } else {
dest, err = conn.ReadPacket(newBuffer()) dest, err = conn.ReadPacket(newBuffer())

View file

@ -109,7 +109,8 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
defer mutex.Unlock() defer mutex.Unlock()
conn2 = nil conn2 = nil
}() }()
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn)
connReader := network.UnwrapPacketReader(conn) // decrease runtime cost for bufio.CreatePacketReadWaiter
for { for {
var ( var (
buff *buf.Buffer buff *buf.Buffer
@ -123,7 +124,9 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
return buff return buff
} }
_ = conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) _ = conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout))
if isReadWaiter { // syscallPacketReadWaiter.WaitReadPacket will cache newBuffer function
// so create new PacketReadWaiter in each loop
if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(connReader); isReadWaiter {
dest, err = readWaiter.WaitReadPacket(newBuffer) dest, err = readWaiter.WaitReadPacket(newBuffer)
} else { } else {
dest, err = conn.ReadPacket(newBuffer()) dest, err = conn.ReadPacket(newBuffer())

View file

@ -211,6 +211,10 @@ func (ut *udpTracker) Close() error {
return ut.PacketConn.Close() return ut.PacketConn.Close()
} }
func (ut *udpTracker) Upstream() any {
return ut.PacketConn
}
func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule, uploadTotal int64, downloadTotal int64, pushToManager bool) *udpTracker { func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule, uploadTotal int64, downloadTotal int64, pushToManager bool) *udpTracker {
metadata.RemoteDst = parseRemoteDestination(nil, conn) metadata.RemoteDst = parseRemoteDestination(nil, conn)