diff --git a/adapter/outboundgroup/relay.go b/adapter/outboundgroup/relay.go index c3fe8bd6..a596454f 100644 --- a/adapter/outboundgroup/relay.go +++ b/adapter/outboundgroup/relay.go @@ -27,7 +27,7 @@ func (r *Relay) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d var d C.Dialer d = dialer.NewDialer(r.Base.DialOptions(opts...)...) for _, proxy := range proxies[:len(proxies)-1] { - d = proxydialer.New(proxy, d) + d = proxydialer.New(proxy, d, false) } last := proxies[len(proxies)-1] conn, err := last.DialContextWithDialer(ctx, d, metadata) @@ -58,7 +58,7 @@ func (r *Relay) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o var d C.Dialer d = dialer.NewDialer(r.Base.DialOptions(opts...)...) for _, proxy := range proxies[:len(proxies)-1] { - d = proxydialer.New(proxy, d) + d = proxydialer.New(proxy, d, false) } last := proxies[len(proxies)-1] pc, err := last.ListenPacketWithDialer(ctx, d, metadata) diff --git a/component/proxydialer/proxydialer.go b/component/proxydialer/proxydialer.go index fad3835d..6c1b3cf2 100644 --- a/component/proxydialer/proxydialer.go +++ b/component/proxydialer/proxydialer.go @@ -11,21 +11,23 @@ import ( "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/tunnel" + "github.com/Dreamacro/clash/tunnel/statistic" ) type proxyDialer struct { - proxy C.ProxyAdapter - dialer C.Dialer + proxy C.ProxyAdapter + dialer C.Dialer + statistic bool } -func New(proxy C.ProxyAdapter, dialer C.Dialer) C.Dialer { - return proxyDialer{proxy: proxy, dialer: dialer} +func New(proxy C.ProxyAdapter, dialer C.Dialer, statistic bool) C.Dialer { + return proxyDialer{proxy: proxy, dialer: dialer, statistic: statistic} } func NewByName(proxyName string, dialer C.Dialer) (C.Dialer, error) { proxies := tunnel.Proxies() if proxy, ok := proxies[proxyName]; ok { - return New(proxy, dialer), nil + return New(proxy, dialer, true), nil } return nil, fmt.Errorf("proxyName[%s] not found", proxyName) } @@ -42,17 +44,29 @@ func (p proxyDialer) DialContext(ctx context.Context, network, address string) ( } return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil } + var conn C.Conn switch p.proxy.SupportWithDialer() { case C.ALLNet: fallthrough case C.TCP: - return p.proxy.DialContextWithDialer(ctx, p.dialer, currentMeta) + conn, err = p.proxy.DialContextWithDialer(ctx, p.dialer, currentMeta) + if err != nil { + return nil, err + } default: // fallback to old function if d, ok := p.dialer.(dialer.Dialer); ok { // fallback to old function - return p.proxy.DialContext(ctx, currentMeta, dialer.WithOption(d.Opt)) + conn, err = p.proxy.DialContext(ctx, currentMeta, dialer.WithOption(d.Opt)) + if err != nil { + return nil, err + } + } else { + return nil, C.ErrNotSupport } - return nil, C.ErrNotSupport } + if p.statistic { + conn = statistic.NewTCPTracker(conn, statistic.DefaultManager, currentMeta, nil, 0, 0) + } + return conn, err } func (p proxyDialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { @@ -63,19 +77,32 @@ func (p proxyDialer) ListenPacket(ctx context.Context, network, address string, return p.listenPacket(ctx, currentMeta) } -func (p proxyDialer) listenPacket(ctx context.Context, currentMeta *C.Metadata) (net.PacketConn, error) { +func (p proxyDialer) listenPacket(ctx context.Context, currentMeta *C.Metadata) (C.PacketConn, error) { + var pc C.PacketConn + var err error currentMeta.NetWork = C.UDP switch p.proxy.SupportWithDialer() { case C.ALLNet: fallthrough case C.UDP: - return p.proxy.ListenPacketWithDialer(ctx, p.dialer, currentMeta) + pc, err = p.proxy.ListenPacketWithDialer(ctx, p.dialer, currentMeta) + if err != nil { + return nil, err + } default: // fallback to old function if d, ok := p.dialer.(dialer.Dialer); ok { // fallback to old function - return p.proxy.ListenPacketContext(ctx, currentMeta, dialer.WithOption(d.Opt)) + pc, err = p.proxy.ListenPacketContext(ctx, currentMeta, dialer.WithOption(d.Opt)) + if err != nil { + return nil, err + } + } else { + return nil, C.ErrNotSupport } - return nil, C.ErrNotSupport } + if p.statistic { + pc = statistic.NewUDPTracker(pc, statistic.DefaultManager, currentMeta, nil, 0, 0) + } + return pc, nil } func addrToMetadata(rawAddress string) (addr *C.Metadata, err error) { @@ -97,6 +124,7 @@ func addrToMetadata(rawAddress string) (addr *C.Metadata, err error) { DstPort: port, } } + addr.Type = C.INNER return }