diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index 38c220de..5afaa895 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -56,7 +56,7 @@ type TuicOption struct { // DialContext implements C.ProxyAdapter func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { opts = t.Base.DialOptions(opts...) - conn, err := t.client.DialContext(ctx, metadata, opts...) + conn, err := t.client.DialContext(ctx, metadata, t.dial, opts...) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di // ListenPacketContext implements C.ProxyAdapter func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { opts = t.Base.DialOptions(opts...) - pc, err := t.client.ListenPacketContext(ctx, metadata, opts...) + pc, err := t.client.ListenPacketContext(ctx, metadata, t.dial, opts...) if err != nil { return nil, err } @@ -205,7 +205,6 @@ func NewTuic(option TuicOption) (*Tuic, error) { clientMaxOpenStreams = 1 } clientOption := &tuic.ClientOption{ - DialFn: t.dial, TlsConfig: tlsConfig, QuicConfig: quicConfig, Host: host, @@ -219,7 +218,7 @@ func NewTuic(option TuicOption) (*Tuic, error) { MaxOpenStreams: clientMaxOpenStreams, } - t.client = tuic.NewClientPool(clientOption) + t.client = tuic.NewPoolClient(clientOption) return t, nil } diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index d08c0325..eb35ba03 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -12,8 +12,8 @@ import ( "strconv" "strings" "sync" - "time" + CN "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" @@ -220,7 +220,7 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts if conn == nil { return nil, E.New("conn is nil") } - return NewConn(&wgConn{conn, w}, w), nil + return NewConn(CN.NewRefConn(conn, w), w), nil } func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { @@ -249,90 +249,5 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat if pc == nil { return nil, E.New("packetConn is nil") } - return newPacketConn(&wgPacketConn{pc, w}, w), nil -} - -type wgConn struct { - conn net.Conn - wg *WireGuard -} - -func (c *wgConn) Read(b []byte) (n int, err error) { - defer runtime.KeepAlive(c.wg) - return c.conn.Read(b) -} - -func (c *wgConn) Write(b []byte) (n int, err error) { - defer runtime.KeepAlive(c.wg) - return c.conn.Write(b) -} - -func (c *wgConn) Close() error { - defer runtime.KeepAlive(c.wg) - return c.conn.Close() -} - -func (c *wgConn) LocalAddr() net.Addr { - defer runtime.KeepAlive(c.wg) - return c.conn.LocalAddr() -} - -func (c *wgConn) RemoteAddr() net.Addr { - defer runtime.KeepAlive(c.wg) - return c.conn.RemoteAddr() -} - -func (c *wgConn) SetDeadline(t time.Time) error { - defer runtime.KeepAlive(c.wg) - return c.conn.SetDeadline(t) -} - -func (c *wgConn) SetReadDeadline(t time.Time) error { - defer runtime.KeepAlive(c.wg) - return c.conn.SetReadDeadline(t) -} - -func (c *wgConn) SetWriteDeadline(t time.Time) error { - defer runtime.KeepAlive(c.wg) - return c.conn.SetWriteDeadline(t) -} - -type wgPacketConn struct { - pc net.PacketConn - wg *WireGuard -} - -func (pc *wgPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - defer runtime.KeepAlive(pc.wg) - return pc.pc.ReadFrom(p) -} - -func (pc *wgPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - defer runtime.KeepAlive(pc.wg) - return pc.pc.WriteTo(p, addr) -} - -func (pc *wgPacketConn) Close() error { - defer runtime.KeepAlive(pc.wg) - return pc.pc.Close() -} - -func (pc *wgPacketConn) LocalAddr() net.Addr { - defer runtime.KeepAlive(pc.wg) - return pc.pc.LocalAddr() -} - -func (pc *wgPacketConn) SetDeadline(t time.Time) error { - defer runtime.KeepAlive(pc.wg) - return pc.pc.SetDeadline(t) -} - -func (pc *wgPacketConn) SetReadDeadline(t time.Time) error { - defer runtime.KeepAlive(pc.wg) - return pc.pc.SetReadDeadline(t) -} - -func (pc *wgPacketConn) SetWriteDeadline(t time.Time) error { - defer runtime.KeepAlive(pc.wg) - return pc.pc.SetWriteDeadline(t) + return newPacketConn(CN.NewRefPacketConn(pc, w), w), nil } diff --git a/common/net/refconn.go b/common/net/refconn.go new file mode 100644 index 00000000..6d28a2bf --- /dev/null +++ b/common/net/refconn.go @@ -0,0 +1,100 @@ +package net + +import ( + "net" + "runtime" + "time" +) + +type refConn struct { + conn net.Conn + ref any +} + +func (c *refConn) Read(b []byte) (n int, err error) { + defer runtime.KeepAlive(c.ref) + return c.conn.Read(b) +} + +func (c *refConn) Write(b []byte) (n int, err error) { + defer runtime.KeepAlive(c.ref) + return c.conn.Write(b) +} + +func (c *refConn) Close() error { + defer runtime.KeepAlive(c.ref) + return c.conn.Close() +} + +func (c *refConn) LocalAddr() net.Addr { + defer runtime.KeepAlive(c.ref) + return c.conn.LocalAddr() +} + +func (c *refConn) RemoteAddr() net.Addr { + defer runtime.KeepAlive(c.ref) + return c.conn.RemoteAddr() +} + +func (c *refConn) SetDeadline(t time.Time) error { + defer runtime.KeepAlive(c.ref) + return c.conn.SetDeadline(t) +} + +func (c *refConn) SetReadDeadline(t time.Time) error { + defer runtime.KeepAlive(c.ref) + return c.conn.SetReadDeadline(t) +} + +func (c *refConn) SetWriteDeadline(t time.Time) error { + defer runtime.KeepAlive(c.ref) + return c.conn.SetWriteDeadline(t) +} + +func NewRefConn(conn net.Conn, ref any) net.Conn { + return &refConn{conn: conn, ref: ref} +} + +type refPacketConn struct { + pc net.PacketConn + ref any +} + +func (pc *refPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + defer runtime.KeepAlive(pc.ref) + return pc.pc.ReadFrom(p) +} + +func (pc *refPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + defer runtime.KeepAlive(pc.ref) + return pc.pc.WriteTo(p, addr) +} + +func (pc *refPacketConn) Close() error { + defer runtime.KeepAlive(pc.ref) + return pc.pc.Close() +} + +func (pc *refPacketConn) LocalAddr() net.Addr { + defer runtime.KeepAlive(pc.ref) + return pc.pc.LocalAddr() +} + +func (pc *refPacketConn) SetDeadline(t time.Time) error { + defer runtime.KeepAlive(pc.ref) + return pc.pc.SetDeadline(t) +} + +func (pc *refPacketConn) SetReadDeadline(t time.Time) error { + defer runtime.KeepAlive(pc.ref) + return pc.pc.SetReadDeadline(t) +} + +func (pc *refPacketConn) SetWriteDeadline(t time.Time) error { + defer runtime.KeepAlive(pc.ref) + return pc.pc.SetWriteDeadline(t) +} + +func NewRefPacketConn(pc net.PacketConn, ref any) net.PacketConn { + return &refPacketConn{pc: pc, ref: ref} +} diff --git a/transport/tuic/client.go b/transport/tuic/client.go index a1dfcc30..dcb4e3aa 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -19,6 +19,7 @@ import ( "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" ) var ( @@ -26,9 +27,9 @@ var ( TooManyOpenStreams = errors.New("tuic: too many open streams") ) -type ClientOption struct { - DialFn func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) +type DialFunc func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) +type ClientOption struct { TlsConfig *tls.Config QuicConfig *quic.Config Host string @@ -42,7 +43,7 @@ type ClientOption struct { MaxOpenStreams int64 } -type Client struct { +type clientImpl struct { *ClientOption udp bool @@ -55,18 +56,17 @@ type Client struct { udpInputMap sync.Map // only ready for PoolClient - poolRef *PoolClient optionRef any lastVisited time.Time } -func (t *Client) getQuicConn(ctx context.Context) (quic.Connection, error) { +func (t *clientImpl) getQuicConn(ctx context.Context, dialFn DialFunc, opts ...dialer.Option) (quic.Connection, error) { t.connMutex.Lock() defer t.connMutex.Unlock() if t.quicConn != nil { return t.quicConn, nil } - pc, addr, err := t.DialFn(ctx) + pc, addr, err := dialFn(ctx, opts...) if err != nil { return nil, err } @@ -97,7 +97,7 @@ func (t *Client) getQuicConn(ctx context.Context) (quic.Connection, error) { return quicConn, nil } -func (t *Client) sendAuthentication(quicConn quic.Connection) (err error) { +func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) { defer func() { t.deferQuicConn(quicConn, err) }() @@ -122,7 +122,7 @@ func (t *Client) sendAuthentication(quicConn quic.Connection) (err error) { return nil } -func (t *Client) parseUDP(quicConn quic.Connection) (err error) { +func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) { defer func() { t.deferQuicConn(quicConn, err) }() @@ -199,45 +199,50 @@ func (t *Client) parseUDP(quicConn quic.Connection) (err error) { } } -func (t *Client) deferQuicConn(quicConn quic.Connection, err error) { +func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) { var netError net.Error if err != nil && errors.As(err, &netError) { - t.connMutex.Lock() - defer t.connMutex.Unlock() - if t.quicConn == quicConn { - t.forceClose(err, true) + t.forceClose(quicConn, err) + } +} + +func (t *clientImpl) forceClose(quicConn quic.Connection, err error) { + t.connMutex.Lock() + defer t.connMutex.Unlock() + if quicConn == nil { + quicConn = t.quicConn + } + if quicConn != nil { + if quicConn == t.quicConn { + t.quicConn = nil } } -} - -func (t *Client) forceClose(err error, locked bool) { - if !locked { - t.connMutex.Lock() - defer t.connMutex.Unlock() + errStr := "" + if err != nil { + errStr = err.Error() } - quicConn := t.quicConn if quicConn != nil { - _ = quicConn.CloseWithError(ProtocolError, err.Error()) - t.udpInputMap.Range(func(key, value any) bool { - if conn, ok := value.(net.Conn); ok { - _ = conn.Close() - } - t.udpInputMap.Delete(key) - return true - }) - t.quicConn = nil + _ = quicConn.CloseWithError(ProtocolError, errStr) } + udpInputMap := &t.udpInputMap + udpInputMap.Range(func(key, value any) bool { + if conn, ok := value.(net.Conn); ok { + _ = conn.Close() + } + udpInputMap.Delete(key) + return true + }) } -func (t *Client) Close() { +func (t *clientImpl) Close() { t.closed.Store(true) if t.openStreams.Load() == 0 { - t.forceClose(ClientClosed, false) + t.forceClose(nil, ClientClosed) } } -func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) { - quicConn, err := t.getQuicConn(ctx) +func (t *clientImpl) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) { + quicConn, err := t.getQuicConn(ctx, dialFn, opts...) if err != nil { return nil, err } @@ -264,12 +269,11 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata) (net.Con Stream: quicStream, lAddr: quicConn.LocalAddr(), rAddr: quicConn.RemoteAddr(), - ref: t, closeDeferFn: func() { time.AfterFunc(C.DefaultTCPTimeout, func() { openStreams := t.openStreams.Add(-1) if openStreams == 0 && t.closed.Load() { - t.forceClose(ClientClosed, false) + t.forceClose(quicConn, ClientClosed) } }) }, @@ -335,8 +339,8 @@ func (conn *earlyConn) Read(b []byte) (n int, err error) { return conn.BufferedConn.Read(b) } -func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error) { - quicConn, err := t.getQuicConn(ctx) +func (t *clientImpl) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) { + quicConn, err := t.getQuicConn(ctx, dialFn, opts...) if err != nil { return nil, err } @@ -362,14 +366,13 @@ func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata) inputConn: N.NewBufferedConn(pipe2), udpRelayMode: t.UdpRelayMode, maxUdpRelayPacketSize: t.MaxUdpRelayPacketSize, - ref: t, deferQuicConnFn: t.deferQuicConn, closeDeferFn: func() { t.udpInputMap.Delete(connId) time.AfterFunc(C.DefaultUDPTimeout, func() { openStreams := t.openStreams.Add(-1) if openStreams == 0 && t.closed.Load() { - t.forceClose(ClientClosed, false) + t.forceClose(quicConn, ClientClosed) } }) }, @@ -377,15 +380,42 @@ func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata) return pc, nil } +type Client struct { + *clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner +} + +func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) { + conn, err := t.clientImpl.DialContext(ctx, metadata, dialFn, opts...) + if err != nil { + return nil, err + } + return N.NewRefConn(conn, t), err +} + +func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) { + pc, err := t.clientImpl.ListenPacketContext(ctx, metadata, dialFn, opts...) + if err != nil { + return nil, err + } + return N.NewRefPacketConn(pc, t), nil +} + +func (t *Client) forceClose() { + t.clientImpl.forceClose(nil, ClientClosed) +} + func NewClient(clientOption *ClientOption, udp bool) *Client { - c := &Client{ + ci := &clientImpl{ ClientOption: clientOption, udp: udp, } + c := &Client{ci} runtime.SetFinalizer(c, closeClient) + log.Debugln("New Tuic Client at %p", c) return c } func closeClient(client *Client) { - client.forceClose(ClientClosed, false) + log.Debugln("Close Tuic Client at %p", client) + client.forceClose() } diff --git a/transport/tuic/conn.go b/transport/tuic/conn.go index 3e759f25..81e8c40e 100644 --- a/transport/tuic/conn.go +++ b/transport/tuic/conn.go @@ -58,8 +58,6 @@ type quicStreamConn struct { lAddr net.Addr rAddr net.Addr - ref any - closeDeferFn func() closeOnce sync.Once @@ -117,8 +115,6 @@ type quicStreamPacketConn struct { udpRelayMode string maxUdpRelayPacketSize int - ref any - deferQuicConnFn func(quicConn quic.Connection, err error) closeDeferFn func() writeClosed *atomic.Bool diff --git a/transport/tuic/pool_client.go b/transport/tuic/pool_client.go index 97e12053..9753da0d 100644 --- a/transport/tuic/pool_client.go +++ b/transport/tuic/pool_client.go @@ -9,8 +9,10 @@ import ( "time" "github.com/Dreamacro/clash/common/generics/list" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" ) type dialResult struct { @@ -31,29 +33,35 @@ type PoolClient struct { udpClientsMutex *sync.Mutex } -func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (net.Conn, error) { - conn, err := t.getClient(false, opts...).DialContext(ctx, metadata) +func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) { + newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { + return t.dial(ctx, dialFn, opts...) + } + conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, newDialFn, opts...) if errors.Is(err, TooManyOpenStreams) { - conn, err = t.newClient(false, opts...).DialContext(ctx, metadata) + conn, err = t.newClient(false, opts...).DialContext(ctx, metadata, newDialFn, opts...) } if err != nil { return nil, err } - return conn, err + return N.NewRefConn(conn, t), err } -func (t *PoolClient) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (net.PacketConn, error) { - pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata) +func (t *PoolClient) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) { + newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { + return t.dial(ctx, dialFn, opts...) + } + pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, newDialFn, opts...) if errors.Is(err, TooManyOpenStreams) { - pc, err = t.newClient(false, opts...).ListenPacketContext(ctx, metadata) + pc, err = t.newClient(false, opts...).ListenPacketContext(ctx, metadata, newDialFn, opts...) } if err != nil { return nil, err } - return pc, nil + return N.NewRefPacketConn(pc, t), nil } -func (t *PoolClient) dial(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { +func (t *PoolClient) dial(ctx context.Context, dialFn DialFunc, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { var o any = *dialer.ApplyOptions(opts...) t.dialResultMutex.Lock() @@ -63,7 +71,7 @@ func (t *PoolClient) dial(ctx context.Context, opts ...dialer.Option) (pc net.Pa return dr.pc, dr.addr, dr.err } - pc, addr, err = t.DialFn(ctx, opts...) + pc, addr, err = dialFn(ctx, opts...) if err != nil { return nil, nil, err } @@ -102,7 +110,6 @@ func (t *PoolClient) newClient(udp bool, opts ...dialer.Option) *Client { defer clientsMutex.Unlock() client := NewClient(t.newClientOption, udp) - client.poolRef = t // make sure pool has a reference client.optionRef = o client.lastVisited = time.Now() @@ -141,16 +148,20 @@ func (t *PoolClient) getClient(udp bool, opts ...dialer.Option) *Client { } } } - if client.openStreams.Load() == 0 && time.Now().Sub(client.lastVisited) > 30*time.Minute { - client.Close() - next := it.Next() - clients.Remove(it) - it = next - continue - } it = it.Next() } }() + for it := clients.Front(); it != nil; { + client := it.Value + if client != bestClient && client.openStreams.Load() == 0 && time.Now().Sub(client.lastVisited) > 30*time.Minute { + client.Close() + next := it.Next() + clients.Remove(it) + it = next + continue + } + it = it.Next() + } if bestClient == nil { return t.newClient(udp, opts...) @@ -160,7 +171,7 @@ func (t *PoolClient) getClient(udp bool, opts ...dialer.Option) *Client { } } -func NewClientPool(clientOption *ClientOption) *PoolClient { +func NewPoolClient(clientOption *ClientOption) *PoolClient { p := &PoolClient{ ClientOption: clientOption, dialResultMap: make(map[any]dialResult), @@ -171,12 +182,13 @@ func NewClientPool(clientOption *ClientOption) *PoolClient { udpClientsMutex: &sync.Mutex{}, } newClientOption := *clientOption - newClientOption.DialFn = p.dial p.newClientOption = &newClientOption runtime.SetFinalizer(p, closeClientPool) + log.Debugln("New Tuic PoolClient at %p", p) return p } func closeClientPool(client *PoolClient) { + log.Debugln("Close Tuic PoolClient at %p", client) client.forceClose() } diff --git a/transport/tuic/server.go b/transport/tuic/server.go index c1213d68..3f459fd6 100644 --- a/transport/tuic/server.go +++ b/transport/tuic/server.go @@ -143,7 +143,6 @@ func (s *serverHandler) parsePacket(packet Packet, udpRelayMode string) (err err inputConn: nil, udpRelayMode: udpRelayMode, maxUdpRelayPacketSize: s.MaxUdpRelayPacketSize, - ref: s, deferQuicConnFn: nil, closeDeferFn: nil, writeClosed: writeClosed, @@ -173,7 +172,6 @@ func (s *serverHandler) handleStream() (err error) { Stream: quicStream, lAddr: s.quicConn.LocalAddr(), rAddr: s.quicConn.RemoteAddr(), - ref: s, } conn := N.NewBufferedConn(stream) connect, err := ReadConnect(conn)