chore: make all net.Conn wrapper can pass through N.ExtendedConn
This commit is contained in:
parent
d7cd598e53
commit
ae5fafa885
10 changed files with 45 additions and 65 deletions
|
@ -8,7 +8,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
N "github.com/Dreamacro/clash/common/net"
|
|
||||||
"github.com/Dreamacro/clash/component/dialer"
|
"github.com/Dreamacro/clash/component/dialer"
|
||||||
tlsC "github.com/Dreamacro/clash/component/tls"
|
tlsC "github.com/Dreamacro/clash/component/tls"
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
|
@ -105,7 +104,7 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error)
|
||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata))
|
err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata))
|
||||||
return N.NewExtendedConn(c), err
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext implements C.ProxyAdapter
|
// DialContext implements C.ProxyAdapter
|
||||||
|
|
|
@ -60,7 +60,7 @@ type wgSingDialer struct {
|
||||||
dialer dialer.Dialer
|
dialer dialer.Dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ N.Dialer = &wgSingDialer{}
|
var _ N.Dialer = (*wgSingDialer)(nil)
|
||||||
|
|
||||||
func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
return d.dialer.DialContext(ctx, network, destination.String())
|
return d.dialer.DialContext(ctx, network, destination.String())
|
||||||
|
@ -74,7 +74,7 @@ type wgNetDialer struct {
|
||||||
tunDevice wireguard.Device
|
tunDevice wireguard.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ dialer.NetDialer = &wgNetDialer{}
|
var _ dialer.NetDialer = (*wgNetDialer)(nil)
|
||||||
|
|
||||||
func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address).Unwrap())
|
return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address).Unwrap())
|
||||||
|
|
|
@ -22,53 +22,23 @@ func (c *firstWriteCallBackConn) Write(b []byte) (n int, err error) {
|
||||||
return c.Conn.Write(b)
|
return c.Conn.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *firstWriteCallBackConn) WriteBuffer(buffer *buf.Buffer) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if !c.written {
|
||||||
|
c.written = true
|
||||||
|
c.callback(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return c.Conn.WriteBuffer(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *firstWriteCallBackConn) Upstream() any {
|
func (c *firstWriteCallBackConn) Upstream() any {
|
||||||
return c.Conn
|
return c.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
type extendedConn interface {
|
var _ N.ExtendedConn = (*firstWriteCallBackConn)(nil)
|
||||||
C.Conn
|
|
||||||
N.ExtendedConn
|
|
||||||
}
|
|
||||||
|
|
||||||
type firstWriteCallBackExtendedConn struct {
|
|
||||||
extendedConn
|
|
||||||
callback func(error)
|
|
||||||
written bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *firstWriteCallBackExtendedConn) Write(b []byte) (n int, err error) {
|
|
||||||
defer func() {
|
|
||||||
if !c.written {
|
|
||||||
c.written = true
|
|
||||||
c.callback(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return c.extendedConn.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *firstWriteCallBackExtendedConn) WriteBuffer(buffer *buf.Buffer) (err error) {
|
|
||||||
defer func() {
|
|
||||||
if !c.written {
|
|
||||||
c.written = true
|
|
||||||
c.callback(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return c.extendedConn.WriteBuffer(buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *firstWriteCallBackExtendedConn) Upstream() any {
|
|
||||||
return c.extendedConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewFirstWriteCallBackConn(c C.Conn, callback func(error)) C.Conn {
|
func NewFirstWriteCallBackConn(c C.Conn, callback func(error)) C.Conn {
|
||||||
if c, ok := c.(extendedConn); ok {
|
|
||||||
return &firstWriteCallBackExtendedConn{
|
|
||||||
extendedConn: c,
|
|
||||||
callback: callback,
|
|
||||||
written: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &firstWriteCallBackConn{
|
return &firstWriteCallBackConn{
|
||||||
Conn: c,
|
Conn: c,
|
||||||
callback: callback,
|
callback: callback,
|
||||||
|
|
|
@ -4,10 +4,12 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Dreamacro/clash/common/buf"
|
||||||
)
|
)
|
||||||
|
|
||||||
type refConn struct {
|
type refConn struct {
|
||||||
conn net.Conn
|
conn ExtendedConn
|
||||||
ref any
|
ref any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,8 +57,20 @@ func (c *refConn) Upstream() any {
|
||||||
return c.conn
|
return c.conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *refConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||||
|
defer runtime.KeepAlive(c.ref)
|
||||||
|
return c.conn.ReadBuffer(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *refConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||||
|
defer runtime.KeepAlive(c.ref)
|
||||||
|
return c.conn.WriteBuffer(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ExtendedConn = (*refConn)(nil)
|
||||||
|
|
||||||
func NewRefConn(conn net.Conn, ref any) net.Conn {
|
func NewRefConn(conn net.Conn, ref any) net.Conn {
|
||||||
return &refConn{conn: conn, ref: ref}
|
return &refConn{conn: NewExtendedConn(conn), ref: ref}
|
||||||
}
|
}
|
||||||
|
|
||||||
type refPacketConn struct {
|
type refPacketConn struct {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
N "github.com/Dreamacro/clash/common/net"
|
||||||
"github.com/Dreamacro/clash/component/dialer"
|
"github.com/Dreamacro/clash/component/dialer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -72,7 +73,7 @@ func (c Chain) Last() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
net.Conn
|
N.ExtendedConn
|
||||||
Connection
|
Connection
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/adapter/inbound"
|
"github.com/Dreamacro/clash/adapter/inbound"
|
||||||
|
N "github.com/Dreamacro/clash/common/net"
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
"github.com/Dreamacro/clash/log"
|
"github.com/Dreamacro/clash/log"
|
||||||
"github.com/Dreamacro/clash/transport/socks5"
|
"github.com/Dreamacro/clash/transport/socks5"
|
||||||
|
@ -33,7 +34,7 @@ type ListenerHandler struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type waitCloseConn struct {
|
type waitCloseConn struct {
|
||||||
net.Conn
|
N.ExtendedConn
|
||||||
wg *sync.WaitGroup
|
wg *sync.WaitGroup
|
||||||
close sync.Once
|
close sync.Once
|
||||||
rAddr net.Addr
|
rAddr net.Addr
|
||||||
|
@ -43,7 +44,7 @@ func (c *waitCloseConn) Close() error { // call from handleTCPConn(connCtx C.Con
|
||||||
c.close.Do(func() {
|
c.close.Do(func() {
|
||||||
c.wg.Done()
|
c.wg.Done()
|
||||||
})
|
})
|
||||||
return c.Conn.Close()
|
return c.ExtendedConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *waitCloseConn) RemoteAddr() net.Addr {
|
func (c *waitCloseConn) RemoteAddr() net.Addr {
|
||||||
|
@ -51,7 +52,7 @@ func (c *waitCloseConn) RemoteAddr() net.Addr {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *waitCloseConn) Upstream() any {
|
func (c *waitCloseConn) Upstream() any {
|
||||||
return c.Conn
|
return c.ExtendedConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||||
|
@ -79,7 +80,7 @@ func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, meta
|
||||||
defer wg.Wait() // this goroutine must exit after conn.Close()
|
defer wg.Wait() // this goroutine must exit after conn.Close()
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
h.TcpIn <- inbound.NewSocket(target, &waitCloseConn{Conn: conn, wg: wg, rAddr: metadata.Source.TCPAddr()}, h.Type, additions...)
|
h.TcpIn <- inbound.NewSocket(target, &waitCloseConn{ExtendedConn: N.NewExtendedConn(conn), wg: wg, rAddr: metadata.Source.TCPAddr()}, h.Type, additions...)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -955,7 +955,7 @@ func (b *bbrSender) CalculateRecoveryWindow(ackedBytes, lostBytes congestion.Byt
|
||||||
b.recoveryWindow = maxByteCount(b.recoveryWindow, b.minCongestionWindow())
|
b.recoveryWindow = maxByteCount(b.recoveryWindow, b.minCongestionWindow())
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ congestion.CongestionControl = &bbrSender{}
|
var _ congestion.CongestionControl = (*bbrSender)(nil)
|
||||||
|
|
||||||
func (b *bbrSender) GetMinRtt() time.Duration {
|
func (b *bbrSender) GetMinRtt() time.Duration {
|
||||||
if b.minRtt > 0 {
|
if b.minRtt > 0 {
|
||||||
|
|
|
@ -103,7 +103,7 @@ func (q *quicStreamConn) RemoteAddr() net.Addr {
|
||||||
return q.rAddr
|
return q.rAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ net.Conn = &quicStreamConn{}
|
var _ net.Conn = (*quicStreamConn)(nil)
|
||||||
|
|
||||||
type quicStreamPacketConn struct {
|
type quicStreamPacketConn struct {
|
||||||
connId uint32
|
connId uint32
|
||||||
|
@ -252,4 +252,4 @@ func (q *quicStreamPacketConn) LocalAddr() net.Addr {
|
||||||
return q.quicConn.LocalAddr()
|
return q.quicConn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ net.PacketConn = &quicStreamPacketConn{}
|
var _ net.PacketConn = (*quicStreamPacketConn)(nil)
|
||||||
|
|
|
@ -294,5 +294,5 @@ func (s *serverUDPPacket) Drop() {
|
||||||
s.packet.DATA = nil
|
s.packet.DATA = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ C.UDPPacket = &serverUDPPacket{}
|
var _ C.UDPPacket = (*serverUDPPacket)(nil)
|
||||||
var _ C.UDPPacketInAddr = &serverUDPPacket{}
|
var _ C.UDPPacketInAddr = (*serverUDPPacket)(nil)
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/common/buf"
|
"github.com/Dreamacro/clash/common/buf"
|
||||||
N "github.com/Dreamacro/clash/common/net"
|
|
||||||
"github.com/Dreamacro/clash/common/utils"
|
"github.com/Dreamacro/clash/common/utils"
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
|
|
||||||
|
@ -33,8 +32,6 @@ type tcpTracker struct {
|
||||||
C.Conn `json:"-"`
|
C.Conn `json:"-"`
|
||||||
*trackerInfo
|
*trackerInfo
|
||||||
manager *Manager
|
manager *Manager
|
||||||
extendedReader N.ExtendedReader
|
|
||||||
extendedWriter N.ExtendedWriter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tt *tcpTracker) ID() string {
|
func (tt *tcpTracker) ID() string {
|
||||||
|
@ -50,7 +47,7 @@ func (tt *tcpTracker) Read(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) {
|
func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) {
|
||||||
err = tt.extendedReader.ReadBuffer(buffer)
|
err = tt.Conn.ReadBuffer(buffer)
|
||||||
download := int64(buffer.Len())
|
download := int64(buffer.Len())
|
||||||
tt.manager.PushDownloaded(download)
|
tt.manager.PushDownloaded(download)
|
||||||
tt.DownloadTotal.Add(download)
|
tt.DownloadTotal.Add(download)
|
||||||
|
@ -67,7 +64,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
|
||||||
|
|
||||||
func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) {
|
func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) {
|
||||||
upload := int64(buffer.Len())
|
upload := int64(buffer.Len())
|
||||||
err = tt.extendedWriter.WriteBuffer(buffer)
|
err = tt.Conn.WriteBuffer(buffer)
|
||||||
tt.manager.PushUploaded(upload)
|
tt.manager.PushUploaded(upload)
|
||||||
tt.UploadTotal.Add(upload)
|
tt.UploadTotal.Add(upload)
|
||||||
return
|
return
|
||||||
|
@ -103,8 +100,6 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R
|
||||||
UploadTotal: atomic.NewInt64(uploadTotal),
|
UploadTotal: atomic.NewInt64(uploadTotal),
|
||||||
DownloadTotal: atomic.NewInt64(downloadTotal),
|
DownloadTotal: atomic.NewInt64(downloadTotal),
|
||||||
},
|
},
|
||||||
extendedReader: N.NewExtendedReader(conn),
|
|
||||||
extendedWriter: N.NewExtendedWriter(conn),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule != nil {
|
if rule != nil {
|
||||||
|
|
Loading…
Reference in a new issue