From a22b1cd69e37b6739512079ca517eb748bf59e63 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 12 May 2023 09:14:27 +0800 Subject: [PATCH] fix: sing-based listener panic --- adapter/outbound/shadowsocks.go | 2 +- common/net/bind.go | 4 ++-- common/net/packet.go | 12 ++++++++---- common/net/refconn.go | 23 ++++++++++++++++++++--- listener/sing/sing.go | 14 ++++++++++++-- listener/sing_shadowsocks/server.go | 7 +++++-- listener/sing_tun/dns.go | 7 +++++-- tunnel/statistic/tracker.go | 4 ++++ 8 files changed, 57 insertions(+), 16 deletions(-) diff --git a/adapter/outbound/shadowsocks.go b/adapter/outbound/shadowsocks.go index 44bfe64f..02e975ef 100644 --- a/adapter/outbound/shadowsocks.go +++ b/adapter/outbound/shadowsocks.go @@ -193,7 +193,7 @@ func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dial if err != nil { return nil, err } - pc = ss.method.DialPacketConn(N.NewBindPacketConn(N.NewEnhancePacketConn(pc), addr)) + pc = ss.method.DialPacketConn(N.NewBindPacketConn(pc, addr)) return newPacketConn(pc, ss), nil } diff --git a/common/net/bind.go b/common/net/bind.go index edf51ccb..231c24c2 100644 --- a/common/net/bind.go +++ b/common/net/bind.go @@ -37,9 +37,9 @@ func (c *bindPacketConn) Upstream() any { return c.EnhancePacketConn } -func NewBindPacketConn(pc EnhancePacketConn, rAddr net.Addr) net.Conn { +func NewBindPacketConn(pc net.PacketConn, rAddr net.Addr) net.Conn { return &bindPacketConn{ - EnhancePacketConn: pc, + EnhancePacketConn: NewEnhancePacketConn(pc), rAddr: rAddr, } } diff --git a/common/net/packet.go b/common/net/packet.go index 261c721c..d01c9efe 100644 --- a/common/net/packet.go +++ b/common/net/packet.go @@ -15,20 +15,24 @@ var NewDeadlinePacketConn = deadline.NewPacketConn var NewDeadlineEnhancePacketConn = deadline.NewEnhancePacketConn type threadSafePacketConn struct { - net.PacketConn + EnhancePacketConn access sync.Mutex } func (c *threadSafePacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { c.access.Lock() defer c.access.Unlock() - return c.PacketConn.WriteTo(b, addr) + return c.EnhancePacketConn.WriteTo(b, addr) } func (c *threadSafePacketConn) Upstream() any { - return c.PacketConn + return c.EnhancePacketConn +} + +func (c *threadSafePacketConn) ReaderReplaceable() bool { + return true } func NewThreadSafePacketConn(pc net.PacketConn) net.PacketConn { - return &threadSafePacketConn{PacketConn: pc} + return &threadSafePacketConn{EnhancePacketConn: NewEnhancePacketConn(pc)} } diff --git a/common/net/refconn.go b/common/net/refconn.go index 537cb839..0f32ebc1 100644 --- a/common/net/refconn.go +++ b/common/net/refconn.go @@ -82,10 +82,15 @@ func NewRefConn(conn net.Conn, ref any) net.Conn { } type refPacketConn struct { - pc net.PacketConn + pc EnhancePacketConn ref any } +func (pc *refPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + defer runtime.KeepAlive(pc.ref) + return pc.pc.WaitReadFrom() +} + func (pc *refPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { defer runtime.KeepAlive(pc.ref) return pc.pc.ReadFrom(p) @@ -121,6 +126,18 @@ func (pc *refPacketConn) SetWriteDeadline(t time.Time) error { return pc.pc.SetWriteDeadline(t) } -func NewRefPacketConn(pc net.PacketConn, ref any) net.PacketConn { - return &refPacketConn{pc: pc, ref: ref} +func (pc *refPacketConn) Upstream() any { + return pc.pc +} + +func (pc *refPacketConn) ReaderReplaceable() bool { // Relay() will handle reference + return true +} + +func (pc *refPacketConn) WriterReplaceable() bool { // Relay() will handle reference + return true +} + +func NewRefPacketConn(pc net.PacketConn, ref any) net.PacketConn { + return &refPacketConn{pc: NewEnhancePacketConn(pc), ref: ref} } diff --git a/listener/sing/sing.go b/listener/sing/sing.go index fe806e0f..2a2d8474 100644 --- a/listener/sing/sing.go +++ b/listener/sing/sing.go @@ -58,6 +58,14 @@ func (c *waitCloseConn) Upstream() any { return c.ExtendedConn } +func (c *waitCloseConn) ReaderReplaceable() bool { + return true +} + +func (c *waitCloseConn) WriterReplaceable() bool { + return true +} + func UpstreamMetadata(metadata M.Metadata) M.Metadata { return M.Metadata{ Source: metadata.Source, @@ -116,7 +124,7 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. defer mutex.Unlock() conn2 = nil }() - readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn) + connReader := network.UnwrapPacketReader(conn) // decrease runtime cost for bufio.CreatePacketReadWaiter for { var ( buff *buf.Buffer @@ -127,7 +135,9 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. buff = buf.NewPacket() // do not use stack buffer return buff } - if isReadWaiter { + // syscallPacketReadWaiter.WaitReadPacket will cache newBuffer function + // so create new PacketReadWaiter in each loop + if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(connReader); isReadWaiter { dest, err = readWaiter.WaitReadPacket(newBuffer) } else { dest, err = conn.ReadPacket(newBuffer()) diff --git a/listener/sing_shadowsocks/server.go b/listener/sing_shadowsocks/server.go index 31b342e8..b35e1238 100644 --- a/listener/sing_shadowsocks/server.go +++ b/listener/sing_shadowsocks/server.go @@ -22,6 +22,7 @@ import ( "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/network" ) type Listener struct { @@ -92,7 +93,7 @@ func New(config LC.ShadowsocksServer, tcpIn chan<- C.ConnContext, udpIn chan<- C go func() { conn := bufio.NewPacketConn(ul) - readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn) + connReader := network.UnwrapPacketReader(conn) // decrease runtime cost for bufio.CreatePacketReadWaiter for { var ( buff *buf.Buffer @@ -103,7 +104,9 @@ func New(config LC.ShadowsocksServer, tcpIn chan<- C.ConnContext, udpIn chan<- C buff = buf.NewPacket() // do not use stack buffer return buff } - if isReadWaiter { + // syscallPacketReadWaiter.WaitReadPacket will cache newBuffer function + // so create new PacketReadWaiter in each loop + if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(connReader); isReadWaiter { dest, err = readWaiter.WaitReadPacket(newBuffer) } else { dest, err = conn.ReadPacket(newBuffer()) diff --git a/listener/sing_tun/dns.go b/listener/sing_tun/dns.go index fcf0cc9c..5ec6d96b 100644 --- a/listener/sing_tun/dns.go +++ b/listener/sing_tun/dns.go @@ -109,7 +109,8 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. defer mutex.Unlock() conn2 = nil }() - readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn) + + connReader := network.UnwrapPacketReader(conn) // decrease runtime cost for bufio.CreatePacketReadWaiter for { var ( buff *buf.Buffer @@ -123,7 +124,9 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. return buff } _ = conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) - if isReadWaiter { + // syscallPacketReadWaiter.WaitReadPacket will cache newBuffer function + // so create new PacketReadWaiter in each loop + if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(connReader); isReadWaiter { dest, err = readWaiter.WaitReadPacket(newBuffer) } else { dest, err = conn.ReadPacket(newBuffer()) diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 685b5e90..a2a921ac 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -211,6 +211,10 @@ func (ut *udpTracker) Close() error { return ut.PacketConn.Close() } +func (ut *udpTracker) Upstream() any { + return ut.PacketConn +} + func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule, uploadTotal int64, downloadTotal int64, pushToManager bool) *udpTracker { metadata.RemoteDst = parseRemoteDestination(nil, conn)