From 8e88e0b9f5c0ba04564ee70d0cd2b0c4da75d8eb Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sun, 28 May 2023 22:51:26 +0800 Subject: [PATCH] chore: add WaitReadFrom support in ssr --- adapter/outbound/hysteria.go | 2 - adapter/outbound/shadowsocks.go | 34 ---------- adapter/outbound/shadowsocksr.go | 68 +++++++++++++++++++- common/net/packet.go | 1 - common/net/packet/packet.go | 2 - go.mod | 2 +- go.sum | 4 +- listener/shadowsocks/udp.go | 21 +++--- listener/shadowsocks/utils.go | 7 +- transport/shadowsocks/core/cipher.go | 11 ++-- transport/shadowsocks/shadowaead/packet.go | 30 +++++++-- transport/shadowsocks/shadowstream/packet.go | 30 +++++++-- transport/ssr/protocol/auth_aes128_sha1.go | 5 +- transport/ssr/protocol/auth_chain_a.go | 5 +- transport/ssr/protocol/auth_sha1_v4.go | 3 +- transport/ssr/protocol/origin.go | 4 +- transport/ssr/protocol/packet.go | 24 ++++++- transport/ssr/protocol/protocol.go | 4 +- transport/trojan/trojan.go | 12 +++- transport/tuic/conn.go | 1 - 20 files changed, 182 insertions(+), 88 deletions(-) diff --git a/adapter/outbound/hysteria.go b/adapter/outbound/hysteria.go index 161a4546..7da4975d 100644 --- a/adapter/outbound/hysteria.go +++ b/adapter/outbound/hysteria.go @@ -19,7 +19,6 @@ import ( "github.com/metacubex/quic-go/congestion" M "github.com/sagernet/sing/common/metadata" - N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/proxydialer" tlsC "github.com/Dreamacro/clash/component/tls" @@ -325,7 +324,6 @@ func (c *hyPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, e return } data = b - put = N.NilPut addr = M.ParseSocksaddr(addrStr).UDPAddr() return } diff --git a/adapter/outbound/shadowsocks.go b/adapter/outbound/shadowsocks.go index 2edc7080..32558eac 100644 --- a/adapter/outbound/shadowsocks.go +++ b/adapter/outbound/shadowsocks.go @@ -16,7 +16,6 @@ import ( "github.com/Dreamacro/clash/transport/restls" obfs "github.com/Dreamacro/clash/transport/simple-obfs" shadowtls "github.com/Dreamacro/clash/transport/sing-shadowtls" - "github.com/Dreamacro/clash/transport/socks5" v2rayObfs "github.com/Dreamacro/clash/transport/v2ray-plugin" restlsC "github.com/3andne/restls-client-go" @@ -330,36 +329,3 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) { restlsConfig: restlsConfig, }, nil } - -type ssPacketConn struct { - net.PacketConn - rAddr net.Addr -} - -func (spc *ssPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - packet, err := socks5.EncodeUDPPacket(socks5.ParseAddrToSocksAddr(addr), b) - if err != nil { - return - } - return spc.PacketConn.WriteTo(packet[3:], spc.rAddr) -} - -func (spc *ssPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, _, e := spc.PacketConn.ReadFrom(b) - if e != nil { - return 0, nil, e - } - - addr := socks5.SplitAddr(b[:n]) - if addr == nil { - return 0, nil, errors.New("parse addr error") - } - - udpAddr := addr.UDPAddr() - if udpAddr == nil { - return 0, nil, errors.New("parse addr error") - } - - copy(b, b[len(addr):]) - return n - len(addr), udpAddr, e -} diff --git a/adapter/outbound/shadowsocksr.go b/adapter/outbound/shadowsocksr.go index d33d6586..07778032 100644 --- a/adapter/outbound/shadowsocksr.go +++ b/adapter/outbound/shadowsocksr.go @@ -2,16 +2,19 @@ package outbound import ( "context" + "errors" "fmt" "net" "strconv" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/proxydialer" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/shadowsocks/core" "github.com/Dreamacro/clash/transport/shadowsocks/shadowaead" "github.com/Dreamacro/clash/transport/shadowsocks/shadowstream" + "github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/ssr/obfs" "github.com/Dreamacro/clash/transport/ssr/protocol" ) @@ -110,9 +113,9 @@ func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Di return nil, err } - pc = ssr.cipher.PacketConn(pc) - pc = ssr.protocol.PacketConn(pc) - return newPacketConn(&ssPacketConn{PacketConn: pc, rAddr: addr}, ssr), nil + epc := ssr.cipher.PacketConn(N.NewEnhancePacketConn(pc)) + epc = ssr.protocol.PacketConn(epc) + return newPacketConn(&ssrPacketConn{EnhancePacketConn: epc, rAddr: addr}, ssr), nil } // SupportWithDialer implements C.ProxyAdapter @@ -188,3 +191,62 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) { protocol: protocol, }, nil } + +type ssrPacketConn struct { + N.EnhancePacketConn + rAddr net.Addr +} + +func (spc *ssrPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + packet, err := socks5.EncodeUDPPacket(socks5.ParseAddrToSocksAddr(addr), b) + if err != nil { + return + } + return spc.EnhancePacketConn.WriteTo(packet[3:], spc.rAddr) +} + +func (spc *ssrPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, _, e := spc.EnhancePacketConn.ReadFrom(b) + if e != nil { + return 0, nil, e + } + + addr := socks5.SplitAddr(b[:n]) + if addr == nil { + return 0, nil, errors.New("parse addr error") + } + + udpAddr := addr.UDPAddr() + if udpAddr == nil { + return 0, nil, errors.New("parse addr error") + } + + copy(b, b[len(addr):]) + return n - len(addr), udpAddr, e +} + +func (spc *ssrPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + data, put, _, err = spc.EnhancePacketConn.WaitReadFrom() + if err != nil { + return nil, nil, nil, err + } + + _addr := socks5.SplitAddr(data) + if _addr == nil { + if put != nil { + put() + } + return nil, nil, nil, errors.New("parse addr error") + } + + addr = _addr.UDPAddr() + if addr == nil { + if put != nil { + put() + } + return nil, nil, nil, errors.New("parse addr error") + } + + data = data[len(_addr):] + return +} diff --git a/common/net/packet.go b/common/net/packet.go index e949ecf2..fc562c42 100644 --- a/common/net/packet.go +++ b/common/net/packet.go @@ -8,7 +8,6 @@ import ( type EnhancePacketConn = packet.EnhancePacketConn type WaitReadFrom = packet.WaitReadFrom -var NilPut = packet.NilPut var NewEnhancePacketConn = packet.NewEnhancePacketConn var NewThreadSafePacketConn = packet.NewThreadSafePacketConn var NewRefPacketConn = packet.NewRefPacketConn diff --git a/common/net/packet/packet.go b/common/net/packet/packet.go index a3f1dd72..6c9542c1 100644 --- a/common/net/packet/packet.go +++ b/common/net/packet/packet.go @@ -10,8 +10,6 @@ type WaitReadFrom interface { WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) } -func NilPut() {} - type EnhancePacketConn interface { net.PacketConn WaitReadFrom diff --git a/go.mod b/go.mod index e5b4882f..b203dd40 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/mdlayher/netlink v1.7.2 github.com/metacubex/quic-go v0.33.3-0.20230510010206-687b537b6a58 github.com/metacubex/sing-shadowsocks v0.2.2-0.20230509230448-a5157cc00a1c - github.com/metacubex/sing-shadowsocks2 v0.0.0-20230519030442-556ef530768f + github.com/metacubex/sing-shadowsocks2 v0.0.0-20230528144023-05418c94ed2d github.com/metacubex/sing-tun v0.1.5-0.20230509224930-30065d4b6376 github.com/metacubex/sing-wireguard v0.0.0-20230426030325-41db09ae771a github.com/miekg/dns v1.1.54 diff --git a/go.sum b/go.sum index bee79cc9..25448530 100644 --- a/go.sum +++ b/go.sum @@ -100,8 +100,8 @@ github.com/metacubex/sing v0.0.0-20230526162852-6afe73474070 h1:AT/Qfe9MvCxyrI9u github.com/metacubex/sing v0.0.0-20230526162852-6afe73474070/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= github.com/metacubex/sing-shadowsocks v0.2.2-0.20230509230448-a5157cc00a1c h1:LpVNvlW/xE+mR8z76xJeYZlYznZXEmU4TeWeuygYdJg= github.com/metacubex/sing-shadowsocks v0.2.2-0.20230509230448-a5157cc00a1c/go.mod h1:4uQQReKMTU7KTfOykVBe/oGJ00pl38d+BYJ99+mx26s= -github.com/metacubex/sing-shadowsocks2 v0.0.0-20230519030442-556ef530768f h1:aWgVMoAm5V2Ur9key6L//mUSBrVMl/zw/4GDG4ZjyZI= -github.com/metacubex/sing-shadowsocks2 v0.0.0-20230519030442-556ef530768f/go.mod h1:jVDD4N22bDPPKA73NvB7aqdlLWiAwv8D+jx7HwhcWak= +github.com/metacubex/sing-shadowsocks2 v0.0.0-20230528144023-05418c94ed2d h1:lWbWl3pZA1x8TgYDw07jo1u5RtbBRIlxuJDV4FW0WeQ= +github.com/metacubex/sing-shadowsocks2 v0.0.0-20230528144023-05418c94ed2d/go.mod h1:jVDD4N22bDPPKA73NvB7aqdlLWiAwv8D+jx7HwhcWak= github.com/metacubex/sing-tun v0.1.5-0.20230509224930-30065d4b6376 h1:zKNsbFQyleMFAP7NJYRew9sEMJuniuODH3V0FdWnEtk= github.com/metacubex/sing-tun v0.1.5-0.20230509224930-30065d4b6376/go.mod h1:BMfG00enVf90/CzcdX9PK3Dymgl7BZqHXJfexEyB7Cc= github.com/metacubex/sing-vmess v0.1.5-0.20230520082358-78b126617899 h1:iRfcuztp7REfmOyasSlCL/pqNWfUDMTJ2CwbGpxpeks= diff --git a/listener/shadowsocks/udp.go b/listener/shadowsocks/udp.go index ef67d4e8..4efafa60 100644 --- a/listener/shadowsocks/udp.go +++ b/listener/shadowsocks/udp.go @@ -4,7 +4,7 @@ import ( "net" "github.com/Dreamacro/clash/adapter/inbound" - "github.com/Dreamacro/clash/common/pool" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/sockopt" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" @@ -29,19 +29,20 @@ func NewUDP(addr string, pickCipher core.Cipher, in chan<- C.PacketAdapter) (*UD } sl := &UDPListener{l, false} - conn := pickCipher.PacketConn(l) + conn := pickCipher.PacketConn(N.NewEnhancePacketConn(l)) go func() { for { - buf := pool.Get(pool.UDPBufferSize) - n, remoteAddr, err := conn.ReadFrom(buf) + data, put, remoteAddr, err := conn.WaitReadFrom() if err != nil { - pool.Put(buf) + if put != nil { + put() + } if sl.closed { break } continue } - handleSocksUDP(conn, in, buf[:n], remoteAddr) + handleSocksUDP(conn, in, data, put, remoteAddr) } }() @@ -57,11 +58,13 @@ func (l *UDPListener) LocalAddr() net.Addr { return l.packetConn.LocalAddr() } -func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, addr net.Addr) { +func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr) { tgtAddr := socks5.SplitAddr(buf) if tgtAddr == nil { // Unresolved UDP packet, return buffer to the pool - pool.Put(buf) + if put != nil { + put() + } return } target := socks5.ParseAddr(tgtAddr.String()) @@ -71,7 +74,7 @@ func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, ad pc: pc, rAddr: addr, payload: payload, - bufRef: buf, + put: put, } select { case in <- inbound.NewPacket(target, packet, C.SHADOWSOCKS): diff --git a/listener/shadowsocks/utils.go b/listener/shadowsocks/utils.go index 2e9fd003..c34c5cd0 100644 --- a/listener/shadowsocks/utils.go +++ b/listener/shadowsocks/utils.go @@ -6,7 +6,6 @@ import ( "net" "net/url" - "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/transport/socks5" ) @@ -14,7 +13,7 @@ type packet struct { pc net.PacketConn rAddr net.Addr payload []byte - bufRef []byte + put func() } func (c *packet) Data() []byte { @@ -37,7 +36,9 @@ func (c *packet) LocalAddr() net.Addr { } func (c *packet) Drop() { - pool.Put(c.bufRef) + if c.put != nil { + c.put() + } } func (c *packet) InAddr() net.Addr { diff --git a/transport/shadowsocks/core/cipher.go b/transport/shadowsocks/core/cipher.go index 7f4f7f71..cd30360c 100644 --- a/transport/shadowsocks/core/cipher.go +++ b/transport/shadowsocks/core/cipher.go @@ -7,6 +7,7 @@ import ( "sort" "strings" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/transport/shadowsocks/shadowaead" "github.com/Dreamacro/clash/transport/shadowsocks/shadowstream" ) @@ -21,7 +22,7 @@ type StreamConnCipher interface { } type PacketConnCipher interface { - PacketConn(net.PacketConn) net.PacketConn + PacketConn(N.EnhancePacketConn) N.EnhancePacketConn } // ErrCipherNotSupported occurs when a cipher is not supported (likely because of security concerns). @@ -128,7 +129,7 @@ type AeadCipher struct { } func (aead *AeadCipher) StreamConn(c net.Conn) net.Conn { return shadowaead.NewConn(c, aead) } -func (aead *AeadCipher) PacketConn(c net.PacketConn) net.PacketConn { +func (aead *AeadCipher) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn { return shadowaead.NewPacketConn(c, aead) } @@ -139,7 +140,7 @@ type StreamCipher struct { } func (ciph *StreamCipher) StreamConn(c net.Conn) net.Conn { return shadowstream.NewConn(c, ciph) } -func (ciph *StreamCipher) PacketConn(c net.PacketConn) net.PacketConn { +func (ciph *StreamCipher) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn { return shadowstream.NewPacketConn(c, ciph) } @@ -147,8 +148,8 @@ func (ciph *StreamCipher) PacketConn(c net.PacketConn) net.PacketConn { type dummy struct{} -func (dummy) StreamConn(c net.Conn) net.Conn { return c } -func (dummy) PacketConn(c net.PacketConn) net.PacketConn { return c } +func (dummy) StreamConn(c net.Conn) net.Conn { return c } +func (dummy) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn { return c } // key-derivation function from original Shadowsocks func Kdf(password string, keyLen int) []byte { diff --git a/transport/shadowsocks/shadowaead/packet.go b/transport/shadowsocks/shadowaead/packet.go index 7043ead7..e84ac570 100644 --- a/transport/shadowsocks/shadowaead/packet.go +++ b/transport/shadowsocks/shadowaead/packet.go @@ -6,6 +6,7 @@ import ( "io" "net" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" ) @@ -57,15 +58,15 @@ func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) { } type PacketConn struct { - net.PacketConn + N.EnhancePacketConn Cipher } const maxPacketSize = 64 * 1024 -// NewPacketConn wraps a net.PacketConn with cipher -func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn { - return &PacketConn{PacketConn: c, Cipher: ciph} +// NewPacketConn wraps an N.EnhancePacketConn with cipher +func NewPacketConn(c N.EnhancePacketConn, ciph Cipher) *PacketConn { + return &PacketConn{EnhancePacketConn: c, Cipher: ciph} } // WriteTo encrypts b and write to addr using the embedded PacketConn. @@ -76,13 +77,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { if err != nil { return 0, err } - _, err = c.PacketConn.WriteTo(buf, addr) + _, err = c.EnhancePacketConn.WriteTo(buf, addr) return len(b), err } // ReadFrom reads from the embedded PacketConn and decrypts into b. func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, addr, err := c.PacketConn.ReadFrom(b) + n, addr, err := c.EnhancePacketConn.ReadFrom(b) if err != nil { return n, addr, err } @@ -93,3 +94,20 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { copy(b, bb) return len(bb), addr, err } + +func (c *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + data, put, addr, err = c.EnhancePacketConn.WaitReadFrom() + if err != nil { + return + } + data, err = Unpack(data[c.Cipher.SaltSize():], data, c) + if err != nil { + if put != nil { + put() + } + data = nil + put = nil + return + } + return +} diff --git a/transport/shadowsocks/shadowstream/packet.go b/transport/shadowsocks/shadowstream/packet.go index 0b46dea1..f0bf43ef 100644 --- a/transport/shadowsocks/shadowstream/packet.go +++ b/transport/shadowsocks/shadowstream/packet.go @@ -6,6 +6,7 @@ import ( "io" "net" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" ) @@ -43,13 +44,13 @@ func Unpack(dst, pkt []byte, s Cipher) ([]byte, error) { } type PacketConn struct { - net.PacketConn + N.EnhancePacketConn Cipher } -// NewPacketConn wraps a net.PacketConn with stream cipher encryption/decryption. -func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn { - return &PacketConn{PacketConn: c, Cipher: ciph} +// NewPacketConn wraps an N.EnhancePacketConn with stream cipher encryption/decryption. +func NewPacketConn(c N.EnhancePacketConn, ciph Cipher) *PacketConn { + return &PacketConn{EnhancePacketConn: c, Cipher: ciph} } const maxPacketSize = 64 * 1024 @@ -61,12 +62,12 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { if err != nil { return 0, err } - _, err = c.PacketConn.WriteTo(buf, addr) + _, err = c.EnhancePacketConn.WriteTo(buf, addr) return len(b), err } func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, addr, err := c.PacketConn.ReadFrom(b) + n, addr, err := c.EnhancePacketConn.ReadFrom(b) if err != nil { return n, addr, err } @@ -77,3 +78,20 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { copy(b, bb) return len(bb), addr, err } + +func (c *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + data, put, addr, err = c.EnhancePacketConn.WaitReadFrom() + if err != nil { + return + } + data, err = Unpack(data[c.IVSize():], data, c) + if err != nil { + if put != nil { + put() + } + data = nil + put = nil + return + } + return +} diff --git a/transport/ssr/protocol/auth_aes128_sha1.go b/transport/ssr/protocol/auth_aes128_sha1.go index 4de48151..e2f0e143 100644 --- a/transport/ssr/protocol/auth_aes128_sha1.go +++ b/transport/ssr/protocol/auth_aes128_sha1.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/transport/ssr/tools" @@ -82,13 +83,13 @@ func (a *authAES128) StreamConn(c net.Conn, iv []byte) net.Conn { return &Conn{Conn: c, Protocol: p} } -func (a *authAES128) PacketConn(c net.PacketConn) net.PacketConn { +func (a *authAES128) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn { p := &authAES128{ Base: a.Base, authAES128Function: a.authAES128Function, userData: a.userData, } - return &PacketConn{PacketConn: c, Protocol: p} + return &PacketConn{EnhancePacketConn: c, Protocol: p} } func (a *authAES128) Decode(dst, src *bytes.Buffer) error { diff --git a/transport/ssr/protocol/auth_chain_a.go b/transport/ssr/protocol/auth_chain_a.go index 6b12ab9b..23efb390 100644 --- a/transport/ssr/protocol/auth_chain_a.go +++ b/transport/ssr/protocol/auth_chain_a.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/transport/shadowsocks/core" @@ -83,13 +84,13 @@ func (a *authChainA) StreamConn(c net.Conn, iv []byte) net.Conn { return &Conn{Conn: c, Protocol: p} } -func (a *authChainA) PacketConn(c net.PacketConn) net.PacketConn { +func (a *authChainA) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn { p := &authChainA{ Base: a.Base, salt: a.salt, userData: a.userData, } - return &PacketConn{PacketConn: c, Protocol: p} + return &PacketConn{EnhancePacketConn: c, Protocol: p} } func (a *authChainA) Decode(dst, src *bytes.Buffer) error { diff --git a/transport/ssr/protocol/auth_sha1_v4.go b/transport/ssr/protocol/auth_sha1_v4.go index 9e814ac2..26039181 100644 --- a/transport/ssr/protocol/auth_sha1_v4.go +++ b/transport/ssr/protocol/auth_sha1_v4.go @@ -7,6 +7,7 @@ import ( "hash/crc32" "net" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/transport/ssr/tools" @@ -35,7 +36,7 @@ func (a *authSHA1V4) StreamConn(c net.Conn, iv []byte) net.Conn { return &Conn{Conn: c, Protocol: p} } -func (a *authSHA1V4) PacketConn(c net.PacketConn) net.PacketConn { +func (a *authSHA1V4) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn { return c } diff --git a/transport/ssr/protocol/origin.go b/transport/ssr/protocol/origin.go index 80fdfa9a..52525a2f 100644 --- a/transport/ssr/protocol/origin.go +++ b/transport/ssr/protocol/origin.go @@ -3,6 +3,8 @@ package protocol import ( "bytes" "net" + + N "github.com/Dreamacro/clash/common/net" ) type origin struct{} @@ -13,7 +15,7 @@ func newOrigin(b *Base) Protocol { return &origin{} } func (o *origin) StreamConn(c net.Conn, iv []byte) net.Conn { return c } -func (o *origin) PacketConn(c net.PacketConn) net.PacketConn { return c } +func (o *origin) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn { return c } func (o *origin) Decode(dst, src *bytes.Buffer) error { dst.ReadFrom(src) diff --git a/transport/ssr/protocol/packet.go b/transport/ssr/protocol/packet.go index 249db70a..988ff75d 100644 --- a/transport/ssr/protocol/packet.go +++ b/transport/ssr/protocol/packet.go @@ -3,11 +3,12 @@ package protocol import ( "net" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" ) type PacketConn struct { - net.PacketConn + N.EnhancePacketConn Protocol } @@ -18,12 +19,12 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { if err != nil { return 0, err } - _, err = c.PacketConn.WriteTo(buf.Bytes(), addr) + _, err = c.EnhancePacketConn.WriteTo(buf.Bytes(), addr) return len(b), err } func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, addr, err := c.PacketConn.ReadFrom(b) + n, addr, err := c.EnhancePacketConn.ReadFrom(b) if err != nil { return n, addr, err } @@ -34,3 +35,20 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { copy(b, decoded) return len(decoded), addr, nil } + +func (c *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + data, put, addr, err = c.EnhancePacketConn.WaitReadFrom() + if err != nil { + return + } + data, err = c.DecodePacket(data) + if err != nil { + if put != nil { + put() + } + data = nil + put = nil + return + } + return +} diff --git a/transport/ssr/protocol/protocol.go b/transport/ssr/protocol/protocol.go index 5b86ecb9..1c27da48 100644 --- a/transport/ssr/protocol/protocol.go +++ b/transport/ssr/protocol/protocol.go @@ -6,6 +6,8 @@ import ( "fmt" "net" + N "github.com/Dreamacro/clash/common/net" + "github.com/zhangyunhao116/fastrand" ) @@ -22,7 +24,7 @@ var ( type Protocol interface { StreamConn(net.Conn, []byte) net.Conn - PacketConn(net.PacketConn) net.PacketConn + PacketConn(N.EnhancePacketConn) N.EnhancePacketConn Decode(dst, src *bytes.Buffer) error Encode(buf *bytes.Buffer, b []byte) error DecodePacket([]byte) ([]byte, error) diff --git a/transport/trojan/trojan.go b/transport/trojan/trojan.go index d37026c1..abe21f34 100644 --- a/transport/trojan/trojan.go +++ b/transport/trojan/trojan.go @@ -370,7 +370,9 @@ func (pc *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, er _, err = io.ReadFull(pc.Conn, data[:2+2]) // u16be length + CR LF if err != nil { - put() + if put != nil { + put() + } return nil, nil, nil, err } length := binary.BigEndian.Uint16(data) @@ -379,11 +381,15 @@ func (pc *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, er data = data[:length] _, err = io.ReadFull(pc.Conn, data) if err != nil { - put() + if put != nil { + put() + } return nil, nil, nil, err } } else { - put() + if put != nil { + put() + } return nil, nil, addr, nil } diff --git a/transport/tuic/conn.go b/transport/tuic/conn.go index d46a3556..f226746d 100644 --- a/transport/tuic/conn.go +++ b/transport/tuic/conn.go @@ -205,7 +205,6 @@ func (q *quicStreamPacketConn) WaitReadFrom() (data []byte, put func(), addr net return } data = packet.DATA - put = N.NilPut addr = packet.ADDR.UDPAddr() } else { err = net.ErrClosed