fix: sing-based listener panic
This commit is contained in:
parent
5070a21626
commit
4c38b2f0bf
8 changed files with 57 additions and 16 deletions
|
@ -193,7 +193,7 @@ func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dial
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pc = ss.method.DialPacketConn(N.NewBindPacketConn(N.NewEnhancePacketConn(pc), addr))
|
pc = ss.method.DialPacketConn(N.NewBindPacketConn(pc, addr))
|
||||||
return newPacketConn(pc, ss), nil
|
return newPacketConn(pc, ss), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,9 +37,9 @@ func (c *bindPacketConn) Upstream() any {
|
||||||
return c.EnhancePacketConn
|
return c.EnhancePacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBindPacketConn(pc EnhancePacketConn, rAddr net.Addr) net.Conn {
|
func NewBindPacketConn(pc net.PacketConn, rAddr net.Addr) net.Conn {
|
||||||
return &bindPacketConn{
|
return &bindPacketConn{
|
||||||
EnhancePacketConn: pc,
|
EnhancePacketConn: NewEnhancePacketConn(pc),
|
||||||
rAddr: rAddr,
|
rAddr: rAddr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,20 +15,24 @@ var NewDeadlinePacketConn = deadline.NewPacketConn
|
||||||
var NewDeadlineEnhancePacketConn = deadline.NewEnhancePacketConn
|
var NewDeadlineEnhancePacketConn = deadline.NewEnhancePacketConn
|
||||||
|
|
||||||
type threadSafePacketConn struct {
|
type threadSafePacketConn struct {
|
||||||
net.PacketConn
|
EnhancePacketConn
|
||||||
access sync.Mutex
|
access sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *threadSafePacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
func (c *threadSafePacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||||
c.access.Lock()
|
c.access.Lock()
|
||||||
defer c.access.Unlock()
|
defer c.access.Unlock()
|
||||||
return c.PacketConn.WriteTo(b, addr)
|
return c.EnhancePacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *threadSafePacketConn) Upstream() any {
|
func (c *threadSafePacketConn) Upstream() any {
|
||||||
return c.PacketConn
|
return c.EnhancePacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *threadSafePacketConn) ReaderReplaceable() bool {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewThreadSafePacketConn(pc net.PacketConn) net.PacketConn {
|
func NewThreadSafePacketConn(pc net.PacketConn) net.PacketConn {
|
||||||
return &threadSafePacketConn{PacketConn: pc}
|
return &threadSafePacketConn{EnhancePacketConn: NewEnhancePacketConn(pc)}
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,10 +82,15 @@ func NewRefConn(conn net.Conn, ref any) net.Conn {
|
||||||
}
|
}
|
||||||
|
|
||||||
type refPacketConn struct {
|
type refPacketConn struct {
|
||||||
pc net.PacketConn
|
pc EnhancePacketConn
|
||||||
ref any
|
ref any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pc *refPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
|
||||||
|
defer runtime.KeepAlive(pc.ref)
|
||||||
|
return pc.pc.WaitReadFrom()
|
||||||
|
}
|
||||||
|
|
||||||
func (pc *refPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
func (pc *refPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
defer runtime.KeepAlive(pc.ref)
|
defer runtime.KeepAlive(pc.ref)
|
||||||
return pc.pc.ReadFrom(p)
|
return pc.pc.ReadFrom(p)
|
||||||
|
@ -121,6 +126,18 @@ func (pc *refPacketConn) SetWriteDeadline(t time.Time) error {
|
||||||
return pc.pc.SetWriteDeadline(t)
|
return pc.pc.SetWriteDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRefPacketConn(pc net.PacketConn, ref any) net.PacketConn {
|
func (pc *refPacketConn) Upstream() any {
|
||||||
return &refPacketConn{pc: pc, ref: ref}
|
return pc.pc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *refPacketConn) ReaderReplaceable() bool { // Relay() will handle reference
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *refPacketConn) WriterReplaceable() bool { // Relay() will handle reference
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRefPacketConn(pc net.PacketConn, ref any) net.PacketConn {
|
||||||
|
return &refPacketConn{pc: NewEnhancePacketConn(pc), ref: ref}
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,6 +58,14 @@ func (c *waitCloseConn) Upstream() any {
|
||||||
return c.ExtendedConn
|
return c.ExtendedConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *waitCloseConn) ReaderReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *waitCloseConn) WriterReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func UpstreamMetadata(metadata M.Metadata) M.Metadata {
|
func UpstreamMetadata(metadata M.Metadata) M.Metadata {
|
||||||
return M.Metadata{
|
return M.Metadata{
|
||||||
Source: metadata.Source,
|
Source: metadata.Source,
|
||||||
|
@ -116,7 +124,7 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
|
||||||
defer mutex.Unlock()
|
defer mutex.Unlock()
|
||||||
conn2 = nil
|
conn2 = nil
|
||||||
}()
|
}()
|
||||||
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn)
|
connReader := network.UnwrapPacketReader(conn) // decrease runtime cost for bufio.CreatePacketReadWaiter
|
||||||
for {
|
for {
|
||||||
var (
|
var (
|
||||||
buff *buf.Buffer
|
buff *buf.Buffer
|
||||||
|
@ -127,7 +135,9 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
|
||||||
buff = buf.NewPacket() // do not use stack buffer
|
buff = buf.NewPacket() // do not use stack buffer
|
||||||
return buff
|
return buff
|
||||||
}
|
}
|
||||||
if isReadWaiter {
|
// syscallPacketReadWaiter.WaitReadPacket will cache newBuffer function
|
||||||
|
// so create new PacketReadWaiter in each loop
|
||||||
|
if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(connReader); isReadWaiter {
|
||||||
dest, err = readWaiter.WaitReadPacket(newBuffer)
|
dest, err = readWaiter.WaitReadPacket(newBuffer)
|
||||||
} else {
|
} else {
|
||||||
dest, err = conn.ReadPacket(newBuffer())
|
dest, err = conn.ReadPacket(newBuffer())
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
"github.com/sagernet/sing/common/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
|
@ -92,7 +93,7 @@ func New(config LC.ShadowsocksServer, tcpIn chan<- C.ConnContext, udpIn chan<- C
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn := bufio.NewPacketConn(ul)
|
conn := bufio.NewPacketConn(ul)
|
||||||
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn)
|
connReader := network.UnwrapPacketReader(conn) // decrease runtime cost for bufio.CreatePacketReadWaiter
|
||||||
for {
|
for {
|
||||||
var (
|
var (
|
||||||
buff *buf.Buffer
|
buff *buf.Buffer
|
||||||
|
@ -103,7 +104,9 @@ func New(config LC.ShadowsocksServer, tcpIn chan<- C.ConnContext, udpIn chan<- C
|
||||||
buff = buf.NewPacket() // do not use stack buffer
|
buff = buf.NewPacket() // do not use stack buffer
|
||||||
return buff
|
return buff
|
||||||
}
|
}
|
||||||
if isReadWaiter {
|
// syscallPacketReadWaiter.WaitReadPacket will cache newBuffer function
|
||||||
|
// so create new PacketReadWaiter in each loop
|
||||||
|
if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(connReader); isReadWaiter {
|
||||||
dest, err = readWaiter.WaitReadPacket(newBuffer)
|
dest, err = readWaiter.WaitReadPacket(newBuffer)
|
||||||
} else {
|
} else {
|
||||||
dest, err = conn.ReadPacket(newBuffer())
|
dest, err = conn.ReadPacket(newBuffer())
|
||||||
|
|
|
@ -109,7 +109,8 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
|
||||||
defer mutex.Unlock()
|
defer mutex.Unlock()
|
||||||
conn2 = nil
|
conn2 = nil
|
||||||
}()
|
}()
|
||||||
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn)
|
|
||||||
|
connReader := network.UnwrapPacketReader(conn) // decrease runtime cost for bufio.CreatePacketReadWaiter
|
||||||
for {
|
for {
|
||||||
var (
|
var (
|
||||||
buff *buf.Buffer
|
buff *buf.Buffer
|
||||||
|
@ -123,7 +124,9 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
|
||||||
return buff
|
return buff
|
||||||
}
|
}
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout))
|
_ = conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout))
|
||||||
if isReadWaiter {
|
// syscallPacketReadWaiter.WaitReadPacket will cache newBuffer function
|
||||||
|
// so create new PacketReadWaiter in each loop
|
||||||
|
if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(connReader); isReadWaiter {
|
||||||
dest, err = readWaiter.WaitReadPacket(newBuffer)
|
dest, err = readWaiter.WaitReadPacket(newBuffer)
|
||||||
} else {
|
} else {
|
||||||
dest, err = conn.ReadPacket(newBuffer())
|
dest, err = conn.ReadPacket(newBuffer())
|
||||||
|
|
|
@ -211,6 +211,10 @@ func (ut *udpTracker) Close() error {
|
||||||
return ut.PacketConn.Close()
|
return ut.PacketConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ut *udpTracker) Upstream() any {
|
||||||
|
return ut.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule, uploadTotal int64, downloadTotal int64, pushToManager bool) *udpTracker {
|
func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule, uploadTotal int64, downloadTotal int64, pushToManager bool) *udpTracker {
|
||||||
metadata.RemoteDst = parseRemoteDestination(nil, conn)
|
metadata.RemoteDst = parseRemoteDestination(nil, conn)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue