chore: make all net.Conn wrapper can pass through N.ExtendedConn

This commit is contained in:
wwqgtxx 2023-04-02 22:24:46 +08:00
parent 2ff0f94262
commit 99f84b8a66
10 changed files with 45 additions and 65 deletions

View file

@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"strconv" "strconv"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/dialer"
tlsC "github.com/Dreamacro/clash/component/tls" tlsC "github.com/Dreamacro/clash/component/tls"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
@ -105,7 +104,7 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error)
return c, err return c, err
} }
err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)) err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata))
return N.NewExtendedConn(c), err return c, err
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter

View file

@ -60,7 +60,7 @@ type wgSingDialer struct {
dialer dialer.Dialer dialer dialer.Dialer
} }
var _ N.Dialer = &wgSingDialer{} var _ N.Dialer = (*wgSingDialer)(nil)
func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return d.dialer.DialContext(ctx, network, destination.String()) return d.dialer.DialContext(ctx, network, destination.String())
@ -74,7 +74,7 @@ type wgNetDialer struct {
tunDevice wireguard.Device tunDevice wireguard.Device
} }
var _ dialer.NetDialer = &wgNetDialer{} var _ dialer.NetDialer = (*wgNetDialer)(nil)
func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address).Unwrap()) return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address).Unwrap())

View file

@ -22,53 +22,23 @@ func (c *firstWriteCallBackConn) Write(b []byte) (n int, err error) {
return c.Conn.Write(b) return c.Conn.Write(b)
} }
func (c *firstWriteCallBackConn) WriteBuffer(buffer *buf.Buffer) (err error) {
defer func() {
if !c.written {
c.written = true
c.callback(err)
}
}()
return c.Conn.WriteBuffer(buffer)
}
func (c *firstWriteCallBackConn) Upstream() any { func (c *firstWriteCallBackConn) Upstream() any {
return c.Conn return c.Conn
} }
type extendedConn interface { var _ N.ExtendedConn = (*firstWriteCallBackConn)(nil)
C.Conn
N.ExtendedConn
}
type firstWriteCallBackExtendedConn struct {
extendedConn
callback func(error)
written bool
}
func (c *firstWriteCallBackExtendedConn) Write(b []byte) (n int, err error) {
defer func() {
if !c.written {
c.written = true
c.callback(err)
}
}()
return c.extendedConn.Write(b)
}
func (c *firstWriteCallBackExtendedConn) WriteBuffer(buffer *buf.Buffer) (err error) {
defer func() {
if !c.written {
c.written = true
c.callback(err)
}
}()
return c.extendedConn.WriteBuffer(buffer)
}
func (c *firstWriteCallBackExtendedConn) Upstream() any {
return c.extendedConn
}
func NewFirstWriteCallBackConn(c C.Conn, callback func(error)) C.Conn { func NewFirstWriteCallBackConn(c C.Conn, callback func(error)) C.Conn {
if c, ok := c.(extendedConn); ok {
return &firstWriteCallBackExtendedConn{
extendedConn: c,
callback: callback,
written: false,
}
}
return &firstWriteCallBackConn{ return &firstWriteCallBackConn{
Conn: c, Conn: c,
callback: callback, callback: callback,

View file

@ -4,10 +4,12 @@ import (
"net" "net"
"runtime" "runtime"
"time" "time"
"github.com/Dreamacro/clash/common/buf"
) )
type refConn struct { type refConn struct {
conn net.Conn conn ExtendedConn
ref any ref any
} }
@ -55,8 +57,20 @@ func (c *refConn) Upstream() any {
return c.conn return c.conn
} }
func (c *refConn) ReadBuffer(buffer *buf.Buffer) error {
defer runtime.KeepAlive(c.ref)
return c.conn.ReadBuffer(buffer)
}
func (c *refConn) WriteBuffer(buffer *buf.Buffer) error {
defer runtime.KeepAlive(c.ref)
return c.conn.WriteBuffer(buffer)
}
var _ ExtendedConn = (*refConn)(nil)
func NewRefConn(conn net.Conn, ref any) net.Conn { func NewRefConn(conn net.Conn, ref any) net.Conn {
return &refConn{conn: conn, ref: ref} return &refConn{conn: NewExtendedConn(conn), ref: ref}
} }
type refPacketConn struct { type refPacketConn struct {

View file

@ -8,6 +8,7 @@ import (
"sync" "sync"
"time" "time"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/dialer"
) )
@ -72,7 +73,7 @@ func (c Chain) Last() string {
} }
type Conn interface { type Conn interface {
net.Conn N.ExtendedConn
Connection Connection
} }

View file

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/adapter/inbound"
N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
@ -33,7 +34,7 @@ type ListenerHandler struct {
} }
type waitCloseConn struct { type waitCloseConn struct {
net.Conn N.ExtendedConn
wg *sync.WaitGroup wg *sync.WaitGroup
close sync.Once close sync.Once
rAddr net.Addr rAddr net.Addr
@ -43,7 +44,7 @@ func (c *waitCloseConn) Close() error { // call from handleTCPConn(connCtx C.Con
c.close.Do(func() { c.close.Do(func() {
c.wg.Done() c.wg.Done()
}) })
return c.Conn.Close() return c.ExtendedConn.Close()
} }
func (c *waitCloseConn) RemoteAddr() net.Addr { func (c *waitCloseConn) RemoteAddr() net.Addr {
@ -51,7 +52,7 @@ func (c *waitCloseConn) RemoteAddr() net.Addr {
} }
func (c *waitCloseConn) Upstream() any { func (c *waitCloseConn) Upstream() any {
return c.Conn return c.ExtendedConn
} }
func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
@ -79,7 +80,7 @@ func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, meta
defer wg.Wait() // this goroutine must exit after conn.Close() defer wg.Wait() // this goroutine must exit after conn.Close()
wg.Add(1) wg.Add(1)
h.TcpIn <- inbound.NewSocket(target, &waitCloseConn{Conn: conn, wg: wg, rAddr: metadata.Source.TCPAddr()}, h.Type, additions...) h.TcpIn <- inbound.NewSocket(target, &waitCloseConn{ExtendedConn: N.NewExtendedConn(conn), wg: wg, rAddr: metadata.Source.TCPAddr()}, h.Type, additions...)
return nil return nil
} }

View file

@ -955,7 +955,7 @@ func (b *bbrSender) CalculateRecoveryWindow(ackedBytes, lostBytes congestion.Byt
b.recoveryWindow = maxByteCount(b.recoveryWindow, b.minCongestionWindow()) b.recoveryWindow = maxByteCount(b.recoveryWindow, b.minCongestionWindow())
} }
var _ congestion.CongestionControl = &bbrSender{} var _ congestion.CongestionControl = (*bbrSender)(nil)
func (b *bbrSender) GetMinRtt() time.Duration { func (b *bbrSender) GetMinRtt() time.Duration {
if b.minRtt > 0 { if b.minRtt > 0 {

View file

@ -103,7 +103,7 @@ func (q *quicStreamConn) RemoteAddr() net.Addr {
return q.rAddr return q.rAddr
} }
var _ net.Conn = &quicStreamConn{} var _ net.Conn = (*quicStreamConn)(nil)
type quicStreamPacketConn struct { type quicStreamPacketConn struct {
connId uint32 connId uint32
@ -252,4 +252,4 @@ func (q *quicStreamPacketConn) LocalAddr() net.Addr {
return q.quicConn.LocalAddr() return q.quicConn.LocalAddr()
} }
var _ net.PacketConn = &quicStreamPacketConn{} var _ net.PacketConn = (*quicStreamPacketConn)(nil)

View file

@ -294,5 +294,5 @@ func (s *serverUDPPacket) Drop() {
s.packet.DATA = nil s.packet.DATA = nil
} }
var _ C.UDPPacket = &serverUDPPacket{} var _ C.UDPPacket = (*serverUDPPacket)(nil)
var _ C.UDPPacketInAddr = &serverUDPPacket{} var _ C.UDPPacketInAddr = (*serverUDPPacket)(nil)

View file

@ -5,7 +5,6 @@ import (
"time" "time"
"github.com/Dreamacro/clash/common/buf" "github.com/Dreamacro/clash/common/buf"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/utils" "github.com/Dreamacro/clash/common/utils"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
@ -33,8 +32,6 @@ type tcpTracker struct {
C.Conn `json:"-"` C.Conn `json:"-"`
*trackerInfo *trackerInfo
manager *Manager manager *Manager
extendedReader N.ExtendedReader
extendedWriter N.ExtendedWriter
} }
func (tt *tcpTracker) ID() string { func (tt *tcpTracker) ID() string {
@ -50,7 +47,7 @@ func (tt *tcpTracker) Read(b []byte) (int, error) {
} }
func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) { func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) {
err = tt.extendedReader.ReadBuffer(buffer) err = tt.Conn.ReadBuffer(buffer)
download := int64(buffer.Len()) download := int64(buffer.Len())
tt.manager.PushDownloaded(download) tt.manager.PushDownloaded(download)
tt.DownloadTotal.Add(download) tt.DownloadTotal.Add(download)
@ -67,7 +64,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) { func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) {
upload := int64(buffer.Len()) upload := int64(buffer.Len())
err = tt.extendedWriter.WriteBuffer(buffer) err = tt.Conn.WriteBuffer(buffer)
tt.manager.PushUploaded(upload) tt.manager.PushUploaded(upload)
tt.UploadTotal.Add(upload) tt.UploadTotal.Add(upload)
return return
@ -103,8 +100,6 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R
UploadTotal: atomic.NewInt64(uploadTotal), UploadTotal: atomic.NewInt64(uploadTotal),
DownloadTotal: atomic.NewInt64(downloadTotal), DownloadTotal: atomic.NewInt64(downloadTotal),
}, },
extendedReader: N.NewExtendedReader(conn),
extendedWriter: N.NewExtendedWriter(conn),
} }
if rule != nil { if rule != nil {