From 99f84b8a66de6bf32b6c3d6d6e725633dece7047 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sun, 2 Apr 2023 22:24:46 +0800 Subject: [PATCH] chore: make all net.Conn wrapper can pass through N.ExtendedConn --- adapter/outbound/trojan.go | 3 +- adapter/outbound/wireguard.go | 4 +- common/callback/callback.go | 52 ++++++------------------- common/net/refconn.go | 18 ++++++++- constant/adapters.go | 3 +- listener/sing/sing.go | 9 +++-- transport/tuic/congestion/bbr_sender.go | 2 +- transport/tuic/conn.go | 4 +- transport/tuic/server.go | 4 +- tunnel/statistic/tracker.go | 11 ++---- 10 files changed, 45 insertions(+), 65 deletions(-) diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index 030e74a9..4a31538b 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -8,7 +8,6 @@ import ( "net/http" "strconv" - N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/component/dialer" tlsC "github.com/Dreamacro/clash/component/tls" C "github.com/Dreamacro/clash/constant" @@ -105,7 +104,7 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) return c, err } err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)) - return N.NewExtendedConn(c), err + return c, err } // DialContext implements C.ProxyAdapter diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index 881bdf99..f82d120e 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -60,7 +60,7 @@ type wgSingDialer struct { dialer dialer.Dialer } -var _ N.Dialer = &wgSingDialer{} +var _ N.Dialer = (*wgSingDialer)(nil) func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { return d.dialer.DialContext(ctx, network, destination.String()) @@ -74,7 +74,7 @@ type wgNetDialer struct { tunDevice wireguard.Device } -var _ dialer.NetDialer = &wgNetDialer{} +var _ dialer.NetDialer = (*wgNetDialer)(nil) func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address).Unwrap()) diff --git a/common/callback/callback.go b/common/callback/callback.go index 9d64bb92..0bf720f4 100644 --- a/common/callback/callback.go +++ b/common/callback/callback.go @@ -22,53 +22,23 @@ func (c *firstWriteCallBackConn) Write(b []byte) (n int, err error) { return c.Conn.Write(b) } +func (c *firstWriteCallBackConn) WriteBuffer(buffer *buf.Buffer) (err error) { + defer func() { + if !c.written { + c.written = true + c.callback(err) + } + }() + return c.Conn.WriteBuffer(buffer) +} + func (c *firstWriteCallBackConn) Upstream() any { return c.Conn } -type extendedConn interface { - C.Conn - N.ExtendedConn -} - -type firstWriteCallBackExtendedConn struct { - extendedConn - callback func(error) - written bool -} - -func (c *firstWriteCallBackExtendedConn) Write(b []byte) (n int, err error) { - defer func() { - if !c.written { - c.written = true - c.callback(err) - } - }() - return c.extendedConn.Write(b) -} - -func (c *firstWriteCallBackExtendedConn) WriteBuffer(buffer *buf.Buffer) (err error) { - defer func() { - if !c.written { - c.written = true - c.callback(err) - } - }() - return c.extendedConn.WriteBuffer(buffer) -} - -func (c *firstWriteCallBackExtendedConn) Upstream() any { - return c.extendedConn -} +var _ N.ExtendedConn = (*firstWriteCallBackConn)(nil) func NewFirstWriteCallBackConn(c C.Conn, callback func(error)) C.Conn { - if c, ok := c.(extendedConn); ok { - return &firstWriteCallBackExtendedConn{ - extendedConn: c, - callback: callback, - written: false, - } - } return &firstWriteCallBackConn{ Conn: c, callback: callback, diff --git a/common/net/refconn.go b/common/net/refconn.go index 324e6474..59225db0 100644 --- a/common/net/refconn.go +++ b/common/net/refconn.go @@ -4,10 +4,12 @@ import ( "net" "runtime" "time" + + "github.com/Dreamacro/clash/common/buf" ) type refConn struct { - conn net.Conn + conn ExtendedConn ref any } @@ -55,8 +57,20 @@ func (c *refConn) Upstream() any { return c.conn } +func (c *refConn) ReadBuffer(buffer *buf.Buffer) error { + defer runtime.KeepAlive(c.ref) + return c.conn.ReadBuffer(buffer) +} + +func (c *refConn) WriteBuffer(buffer *buf.Buffer) error { + defer runtime.KeepAlive(c.ref) + return c.conn.WriteBuffer(buffer) +} + +var _ ExtendedConn = (*refConn)(nil) + func NewRefConn(conn net.Conn, ref any) net.Conn { - return &refConn{conn: conn, ref: ref} + return &refConn{conn: NewExtendedConn(conn), ref: ref} } type refPacketConn struct { diff --git a/constant/adapters.go b/constant/adapters.go index bf5f7fdb..ce9f9911 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -8,6 +8,7 @@ import ( "sync" "time" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/component/dialer" ) @@ -72,7 +73,7 @@ func (c Chain) Last() string { } type Conn interface { - net.Conn + N.ExtendedConn Connection } diff --git a/listener/sing/sing.go b/listener/sing/sing.go index 70462728..2a5a7d50 100644 --- a/listener/sing/sing.go +++ b/listener/sing/sing.go @@ -10,6 +10,7 @@ import ( "time" "github.com/Dreamacro/clash/adapter/inbound" + N "github.com/Dreamacro/clash/common/net" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/transport/socks5" @@ -33,7 +34,7 @@ type ListenerHandler struct { } type waitCloseConn struct { - net.Conn + N.ExtendedConn wg *sync.WaitGroup close sync.Once rAddr net.Addr @@ -43,7 +44,7 @@ func (c *waitCloseConn) Close() error { // call from handleTCPConn(connCtx C.Con c.close.Do(func() { c.wg.Done() }) - return c.Conn.Close() + return c.ExtendedConn.Close() } func (c *waitCloseConn) RemoteAddr() net.Addr { @@ -51,7 +52,7 @@ func (c *waitCloseConn) RemoteAddr() net.Addr { } func (c *waitCloseConn) Upstream() any { - return c.Conn + return c.ExtendedConn } func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { @@ -79,7 +80,7 @@ func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, meta defer wg.Wait() // this goroutine must exit after conn.Close() wg.Add(1) - h.TcpIn <- inbound.NewSocket(target, &waitCloseConn{Conn: conn, wg: wg, rAddr: metadata.Source.TCPAddr()}, h.Type, additions...) + h.TcpIn <- inbound.NewSocket(target, &waitCloseConn{ExtendedConn: N.NewExtendedConn(conn), wg: wg, rAddr: metadata.Source.TCPAddr()}, h.Type, additions...) return nil } diff --git a/transport/tuic/congestion/bbr_sender.go b/transport/tuic/congestion/bbr_sender.go index d848a9a8..99164362 100644 --- a/transport/tuic/congestion/bbr_sender.go +++ b/transport/tuic/congestion/bbr_sender.go @@ -955,7 +955,7 @@ func (b *bbrSender) CalculateRecoveryWindow(ackedBytes, lostBytes congestion.Byt b.recoveryWindow = maxByteCount(b.recoveryWindow, b.minCongestionWindow()) } -var _ congestion.CongestionControl = &bbrSender{} +var _ congestion.CongestionControl = (*bbrSender)(nil) func (b *bbrSender) GetMinRtt() time.Duration { if b.minRtt > 0 { diff --git a/transport/tuic/conn.go b/transport/tuic/conn.go index d5955e13..8f63da75 100644 --- a/transport/tuic/conn.go +++ b/transport/tuic/conn.go @@ -103,7 +103,7 @@ func (q *quicStreamConn) RemoteAddr() net.Addr { return q.rAddr } -var _ net.Conn = &quicStreamConn{} +var _ net.Conn = (*quicStreamConn)(nil) type quicStreamPacketConn struct { connId uint32 @@ -252,4 +252,4 @@ func (q *quicStreamPacketConn) LocalAddr() net.Addr { return q.quicConn.LocalAddr() } -var _ net.PacketConn = &quicStreamPacketConn{} +var _ net.PacketConn = (*quicStreamPacketConn)(nil) diff --git a/transport/tuic/server.go b/transport/tuic/server.go index 5eb6e611..88169aed 100644 --- a/transport/tuic/server.go +++ b/transport/tuic/server.go @@ -294,5 +294,5 @@ func (s *serverUDPPacket) Drop() { s.packet.DATA = nil } -var _ C.UDPPacket = &serverUDPPacket{} -var _ C.UDPPacketInAddr = &serverUDPPacket{} +var _ C.UDPPacket = (*serverUDPPacket)(nil) +var _ C.UDPPacketInAddr = (*serverUDPPacket)(nil) diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index f7e9d971..11b9c0cd 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -5,7 +5,6 @@ import ( "time" "github.com/Dreamacro/clash/common/buf" - N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/utils" C "github.com/Dreamacro/clash/constant" @@ -32,9 +31,7 @@ type trackerInfo struct { type tcpTracker struct { C.Conn `json:"-"` *trackerInfo - manager *Manager - extendedReader N.ExtendedReader - extendedWriter N.ExtendedWriter + manager *Manager } func (tt *tcpTracker) ID() string { @@ -50,7 +47,7 @@ func (tt *tcpTracker) Read(b []byte) (int, error) { } func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) { - err = tt.extendedReader.ReadBuffer(buffer) + err = tt.Conn.ReadBuffer(buffer) download := int64(buffer.Len()) tt.manager.PushDownloaded(download) tt.DownloadTotal.Add(download) @@ -67,7 +64,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) { upload := int64(buffer.Len()) - err = tt.extendedWriter.WriteBuffer(buffer) + err = tt.Conn.WriteBuffer(buffer) tt.manager.PushUploaded(upload) tt.UploadTotal.Add(upload) return @@ -103,8 +100,6 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R UploadTotal: atomic.NewInt64(uploadTotal), DownloadTotal: atomic.NewInt64(downloadTotal), }, - extendedReader: N.NewExtendedReader(conn), - extendedWriter: N.NewExtendedWriter(conn), } if rule != nil {