diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index 6ffc0095..fa24ae39 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -55,17 +55,12 @@ type TuicOption struct { // DialContext implements C.ProxyAdapter func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { - opts = t.Base.DialOptions(opts...) - conn, err := t.client.DialContext(ctx, metadata, t.dial, opts...) - if err != nil { - return nil, err - } - return NewConn(conn, t), err + return t.DialContextWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata) } // DialContextWithDialer implements C.ProxyAdapter func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) { - conn, err := t.client.DialContextWithDialer(ctx, dialer, metadata, t.dialWithDialer) + conn, err := t.client.DialContextWithDialer(ctx, metadata, dialer, t.dialWithDialer) if err != nil { return nil, err } @@ -74,17 +69,12 @@ func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metad // ListenPacketContext implements C.ProxyAdapter func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { - opts = t.Base.DialOptions(opts...) - pc, err := t.client.ListenPacketContext(ctx, metadata, t.dial, opts...) - if err != nil { - return nil, err - } - return newPacketConn(pc, t), nil + return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata) } // ListenPacketWithDialer implements C.ProxyAdapter func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) { - pc, err := t.client.ListenPacketWithDialer(ctx, dialer, metadata, t.dialWithDialer) + pc, err := t.client.ListenPacketWithDialer(ctx, metadata, dialer, t.dialWithDialer) if err != nil { return nil, err } diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 027b25b9..d7c0f3db 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -36,7 +36,7 @@ func ParseNetwork(network string, addr netip.Addr) string { return network } -func ApplyOptions(options ...Option) *option { +func applyOptions(options ...Option) *option { opt := &option{ interfaceName: DefaultInterface.Load(), routingMark: int(DefaultRoutingMark.Load()), @@ -54,7 +54,7 @@ func ApplyOptions(options ...Option) *option { } func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) { - opt := ApplyOptions(options...) + opt := applyOptions(options...) if opt.network == 4 || opt.network == 6 { if strings.Contains(network, "tcp") { @@ -445,19 +445,19 @@ func concurrentIPv6DialContext(ctx context.Context, network, address string, opt return concurrentDialContext(ctx, network, ips, port, opt) } -type dialer struct { - opt option +type Dialer struct { + Opt option } -func (d dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return DialContext(ctx, network, address, withOption(d.opt)) +func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return DialContext(ctx, network, address, WithOption(d.Opt)) } -func (d dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { - return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, withOption(d.opt)) +func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { + return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, WithOption(d.Opt)) } -func NewDialer(options ...Option) dialer { - opt := ApplyOptions(options...) - return dialer{opt: *opt} +func NewDialer(options ...Option) Dialer { + opt := applyOptions(options...) + return Dialer{Opt: *opt} } diff --git a/component/dialer/options.go b/component/dialer/options.go index 8cd6fd39..27adc845 100644 --- a/component/dialer/options.go +++ b/component/dialer/options.go @@ -69,7 +69,7 @@ func WithOnlySingleStack(isIPv4 bool) Option { } } -func withOption(o option) Option { +func WithOption(o option) Option { return func(opt *option) { *opt = o } diff --git a/transport/tuic/client.go b/transport/tuic/client.go index 418c4133..1d1c3d15 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -17,7 +17,6 @@ import ( N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" - "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" ) @@ -27,8 +26,7 @@ var ( TooManyOpenStreams = errors.New("tuic: too many open streams") ) -type DialFunc func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) -type DialWithDialerFunc func(ctx context.Context, dialer C.Dialer) (pc net.PacketConn, addr net.Addr, err error) +type DialFunc func(ctx context.Context, dialer C.Dialer) (pc net.PacketConn, addr net.Addr, err error) type ClientOption struct { TlsConfig *tls.Config @@ -57,17 +55,17 @@ type clientImpl struct { udpInputMap sync.Map // only ready for PoolClient - optionRef any + dialerRef C.Dialer lastVisited time.Time } -func (t *clientImpl) getQuicConn(ctx context.Context, dialFn DialFunc, opts ...dialer.Option) (quic.Connection, error) { +func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn DialFunc) (quic.Connection, error) { t.connMutex.Lock() defer t.connMutex.Unlock() if t.quicConn != nil { return t.quicConn, nil } - pc, addr, err := dialFn(ctx, opts...) + pc, addr, err := dialFn(ctx, dialer) if err != nil { return nil, err } @@ -242,8 +240,8 @@ func (t *clientImpl) Close() { } } -func (t *clientImpl) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) { - quicConn, err := t.getQuicConn(ctx, dialFn, opts...) +func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) { + quicConn, err := t.getQuicConn(ctx, dialer, dialFn) if err != nil { return nil, err } @@ -340,8 +338,8 @@ func (conn *earlyConn) Read(b []byte) (n int, err error) { return conn.BufferedConn.Read(b) } -func (t *clientImpl) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) { - quicConn, err := t.getQuicConn(ctx, dialFn, opts...) +func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) { + quicConn, err := t.getQuicConn(ctx, dialer, dialFn) if err != nil { return nil, err } @@ -385,16 +383,16 @@ type Client struct { *clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner } -func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) { - conn, err := t.clientImpl.DialContext(ctx, metadata, dialFn, opts...) +func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) { + conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn) if err != nil { return nil, err } return N.NewRefConn(conn, t), err } -func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) { - pc, err := t.clientImpl.ListenPacketContext(ctx, metadata, dialFn, opts...) +func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) { + pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn) if err != nil { return nil, err } diff --git a/transport/tuic/pool_client.go b/transport/tuic/pool_client.go index 304cc92d..fe06c2f3 100644 --- a/transport/tuic/pool_client.go +++ b/transport/tuic/pool_client.go @@ -10,7 +10,6 @@ import ( "github.com/Dreamacro/clash/common/generics/list" N "github.com/Dreamacro/clash/common/net" - "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" ) @@ -25,7 +24,7 @@ type PoolClient struct { *ClientOption newClientOption *ClientOption - dialResultMap map[any]dialResult + dialResultMap map[C.Dialer]dialResult dialResultMutex *sync.Mutex tcpClients *list.List[*Client] tcpClientsMutex *sync.Mutex @@ -33,14 +32,10 @@ type PoolClient struct { udpClientsMutex *sync.Mutex } -func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) { - newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { - return t.dial(ctx, dialFn, opts...) - } - var o any = *dialer.ApplyOptions(opts...) - conn, err := t.getClient(false, o).DialContext(ctx, metadata, newDialFn, opts...) +func (t *PoolClient) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) { + conn, err := t.getClient(false, dialer).DialContextWithDialer(ctx, metadata, dialer, dialFn) if errors.Is(err, TooManyOpenStreams) { - conn, err = t.newClient(false, o).DialContext(ctx, metadata, newDialFn, opts...) + conn, err = t.newClient(false, dialer).DialContextWithDialer(ctx, metadata, dialer, dialFn) } if err != nil { return nil, err @@ -48,29 +43,10 @@ func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata, dial return N.NewRefConn(conn, t), err } -func (t *PoolClient) DialContextWithDialer(ctx context.Context, d C.Dialer, metadata *C.Metadata, dialFn DialWithDialerFunc) (net.Conn, error) { - newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { - return dialFn(ctx, d) - } - var o any = d - conn, err := t.getClient(false, o).DialContext(ctx, metadata, newDialFn) +func (t *PoolClient) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) { + pc, err := t.getClient(true, dialer).ListenPacketWithDialer(ctx, metadata, dialer, dialFn) if errors.Is(err, TooManyOpenStreams) { - conn, err = t.newClient(false, o).DialContext(ctx, metadata, newDialFn) - } - if err != nil { - return nil, err - } - return N.NewRefConn(conn, t), err -} - -func (t *PoolClient) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) { - newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { - return t.dial(ctx, dialFn, opts...) - } - var o any = *dialer.ApplyOptions(opts...) - pc, err := t.getClient(true, o).ListenPacketContext(ctx, metadata, newDialFn, opts...) - if errors.Is(err, TooManyOpenStreams) { - pc, err = t.newClient(true, o).ListenPacketContext(ctx, metadata, newDialFn, opts...) + pc, err = t.newClient(true, dialer).ListenPacketWithDialer(ctx, metadata, dialer, dialFn) } if err != nil { return nil, err @@ -78,32 +54,15 @@ func (t *PoolClient) ListenPacketContext(ctx context.Context, metadata *C.Metada return N.NewRefPacketConn(pc, t), nil } -func (t *PoolClient) ListenPacketWithDialer(ctx context.Context, d C.Dialer, metadata *C.Metadata, dialFn DialWithDialerFunc) (net.PacketConn, error) { - newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { - return dialFn(ctx, d) - } - var o any = d - pc, err := t.getClient(true, o).ListenPacketContext(ctx, metadata, newDialFn) - if errors.Is(err, TooManyOpenStreams) { - pc, err = t.newClient(true, o).ListenPacketContext(ctx, metadata, newDialFn) - } - if err != nil { - return nil, err - } - return N.NewRefPacketConn(pc, t), nil -} - -func (t *PoolClient) dial(ctx context.Context, dialFn DialFunc, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { - var o any = *dialer.ApplyOptions(opts...) - +func (t *PoolClient) dial(ctx context.Context, dialer C.Dialer, dialFn DialFunc) (pc net.PacketConn, addr net.Addr, err error) { t.dialResultMutex.Lock() - dr, ok := t.dialResultMap[o] + dr, ok := t.dialResultMap[dialer] t.dialResultMutex.Unlock() if ok { return dr.pc, dr.addr, dr.err } - pc, addr, err = dialFn(ctx, opts...) + pc, addr, err = dialFn(ctx, dialer) if err != nil { return nil, nil, err } @@ -111,7 +70,7 @@ func (t *PoolClient) dial(ctx context.Context, dialFn DialFunc, opts ...dialer.O dr.pc, dr.addr, dr.err = pc, addr, err t.dialResultMutex.Lock() - t.dialResultMap[o] = dr + t.dialResultMap[dialer] = dr t.dialResultMutex.Unlock() return pc, addr, err } @@ -128,7 +87,7 @@ func (t *PoolClient) forceClose() { } } -func (t *PoolClient) newClient(udp bool, o any) *Client { +func (t *PoolClient) newClient(udp bool, dialer C.Dialer) *Client { clients := t.tcpClients clientsMutex := t.tcpClientsMutex if udp { @@ -140,14 +99,14 @@ func (t *PoolClient) newClient(udp bool, o any) *Client { defer clientsMutex.Unlock() client := NewClient(t.newClientOption, udp) - client.optionRef = o + client.dialerRef = dialer client.lastVisited = time.Now() clients.PushFront(client) return client } -func (t *PoolClient) getClient(udp bool, o any) *Client { +func (t *PoolClient) getClient(udp bool, dialer C.Dialer) *Client { clients := t.tcpClients clientsMutex := t.tcpClientsMutex if udp { @@ -167,7 +126,7 @@ func (t *PoolClient) getClient(udp bool, o any) *Client { it = next continue } - if client.optionRef == o { + if client.dialerRef == dialer { if bestClient == nil { bestClient = client } else { @@ -192,7 +151,7 @@ func (t *PoolClient) getClient(udp bool, o any) *Client { } if bestClient == nil { - return t.newClient(udp, o) + return t.newClient(udp, dialer) } else { bestClient.lastVisited = time.Now() return bestClient @@ -202,7 +161,7 @@ func (t *PoolClient) getClient(udp bool, o any) *Client { func NewPoolClient(clientOption *ClientOption) *PoolClient { p := &PoolClient{ ClientOption: clientOption, - dialResultMap: make(map[any]dialResult), + dialResultMap: make(map[C.Dialer]dialResult), dialResultMutex: &sync.Mutex{}, tcpClients: list.New[*Client](), tcpClientsMutex: &sync.Mutex{},