From 66c50ba349c961c0e75fc8338a46547ae5d43c4e Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Thu, 14 May 2026 15:57:28 -0700 Subject: [PATCH 1/3] adding localized profiling support --- .../nemo_retriever/graph/abstract_operator.py | 34 ++ .../src/nemo_retriever/graph/executor.py | 55 ++- .../graph/operator_archetype.py | 11 + .../src/nemo_retriever/utils/stage_timing.py | 334 +++++++++++++++++ .../nemo_retriever/utils/stage_timing_viz.py | 336 ++++++++++++++++++ 5 files changed, 768 insertions(+), 2 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/utils/stage_timing.py create mode 100644 nemo_retriever/src/nemo_retriever/utils/stage_timing_viz.py diff --git a/nemo_retriever/src/nemo_retriever/graph/abstract_operator.py b/nemo_retriever/src/nemo_retriever/graph/abstract_operator.py index b0f5ee5081..3c82f83217 100644 --- a/nemo_retriever/src/nemo_retriever/graph/abstract_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/abstract_operator.py @@ -6,12 +6,23 @@ from abc import ABC, abstractmethod import inspect +import os +import time from typing import Any, TYPE_CHECKING +from nemo_retriever.utils import stage_timing + if TYPE_CHECKING: from nemo_retriever.graph.pipeline_graph import Graph, Node +def _safe_len(data: Any) -> int: + try: + return len(data) + except Exception: + return -1 + + class AbstractOperator(ABC): """Base class for all pipeline operators.""" @@ -30,9 +41,32 @@ def process(self, data: Any, **kwargs: Any) -> Any: ... def postprocess(self, data: Any, **kwargs: Any) -> Any: ... def run(self, data: Any, **kwargs: Any) -> Any: + if not stage_timing.is_enabled(): + data = self.preprocess(data, **kwargs) + data = self.process(data, **kwargs) + data = self.postprocess(data, **kwargs) + return data + + stage = getattr(self, "_nr_stage_name", None) or type(self).__name__ + n_in = _safe_len(data) + t0 = time.perf_counter() data = self.preprocess(data, **kwargs) + t1 = time.perf_counter() data = self.process(data, **kwargs) + t2 = time.perf_counter() data = self.postprocess(data, **kwargs) + t3 = time.perf_counter() + stage_timing.record_timing( + stage=stage, + n_rows_in=n_in, + n_rows_out=_safe_len(data), + preprocess_ms=(t1 - t0) * 1000.0, + process_ms=(t2 - t1) * 1000.0, + postprocess_ms=(t3 - t2) * 1000.0, + total_ms=(t3 - t0) * 1000.0, + worker_pid=os.getpid(), + wallclock_start=t0, + ) return data def __call__(self, data: Any, **kwargs: Any) -> Any: diff --git a/nemo_retriever/src/nemo_retriever/graph/executor.py b/nemo_retriever/src/nemo_retriever/graph/executor.py index 14a323ab08..b44262b59c 100644 --- a/nemo_retriever/src/nemo_retriever/graph/executor.py +++ b/nemo_retriever/src/nemo_retriever/graph/executor.py @@ -29,6 +29,7 @@ VLLM_GPUS_PER_ACTOR, OCR_GPUS_PER_ACTOR, ) +from nemo_retriever.utils import stage_timing import logging @@ -249,6 +250,10 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: } ray_env_vars.update(collect_hf_runtime_env()) ray_env_vars.update(collect_remote_auth_runtime_env()) + for _name in (stage_timing.ENABLED_ENV, stage_timing.REPORT_PATH_ENV): + _val = os.environ.get(_name) + if _val: + ray_env_vars[_name] = _val os.environ["HF_HUB_OFFLINE"] = ray_env_vars["HF_HUB_OFFLINE"] runtime_env = {"env_vars": ray_env_vars} ray.init( @@ -257,6 +262,9 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: runtime_env=runtime_env, ) + timing_enabled = stage_timing.is_enabled() + timing_collector = stage_timing.start_collector() if timing_enabled else None + ctx = rd.DataContext.get_current() ctx.enable_rich_progress_bars = True ctx.use_ray_tqdm = False @@ -273,6 +281,24 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: except FileNotFoundError as exc: raise_input_path_not_found(input_paths or [], exc) nodes = self._linearize(resolved_graph) + timing_call_index: Optional[int] = None + timing_node_names: List[str] = [n.name for n in nodes] + timing_graph_label: Optional[str] = None + if timing_enabled: + timing_call_index = stage_timing.next_call_index() + timing_graph_label = stage_timing.slugify_graph_label(timing_node_names) + report_path = stage_timing.resolve_report_path(timing_call_index, timing_graph_label) + logger.info( + "RayDataExecutor.ingest() #%02d | %d nodes: [%s] | label=%s | report -> %s", + timing_call_index, + len(timing_node_names), + ", ".join(timing_node_names), + timing_graph_label, + report_path if report_path is not None else "", + ) + # Per-node operator class stamping happens inside the map_batches + # loop below via stage_timing.make_named_operator_class so two + # nodes sharing the same operator class get distinct stage names. for node in nodes: overrides = dict(self._node_overrides.get(node.name, {})) target_num_rows_per_block = overrides.pop("target_num_rows_per_block", None) @@ -370,8 +396,11 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: # fn_constructor_kwargs for deferred construction on workers. # AbstractOperator.__call__ delegates to run(), so each stage # executes the full preprocess -> process -> postprocess chain. + operator_cls = node.operator_class + if timing_enabled: + operator_cls = stage_timing.make_named_operator_class(operator_cls, node.name) ds = ds.map_batches( - node.operator_class, + operator_cls, batch_size=batch_size, batch_format=batch_format, num_cpus=num_cpus, @@ -380,4 +409,26 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: **overrides, ) - return ds.to_pandas() + result = ds.to_pandas() + if timing_collector is not None: + try: + ray_stats_text = None + try: + ray_stats_text = ds.stats() + except Exception as exc: + logger.warning("Failed to collect ds.stats(): %s", exc) + try: + records = ray.get(timing_collector.dump.remote()) + except Exception as exc: + logger.warning("Failed to retrieve stage timing records: %s", exc) + records = [] + stage_timing.write_report( + records, + ray_stats_text=ray_stats_text, + call_index=timing_call_index, + graph_label=timing_graph_label, + node_names=timing_node_names, + ) + finally: + stage_timing.stop_collector(timing_collector) + return result diff --git a/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py b/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py index 3fb9682631..04458e2b50 100644 --- a/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py +++ b/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py @@ -88,7 +88,18 @@ def _resolve_delegate(self, resources: ClusterResources | Resources | None = Non if operator_class is type(self): raise RuntimeError(f"{type(self).__name__} could not resolve a concrete hardware-specific operator.") +<<<<<<< Updated upstream delegate = operator_class(**operator_kwargs) +======= + resolved_kwargs = type(self).variant_operator_kwargs(operator_class, operator_kwargs) + delegate = operator_class(**resolved_kwargs) + # Propagate the pipeline-level stage name (set by RayDataExecutor on the + # archetype class) so the delegate's run() can label timing records + # consistently with the node name rather than the variant class name. + stage_name = getattr(self, "_nr_stage_name", None) + if stage_name is not None: + delegate._nr_stage_name = stage_name +>>>>>>> Stashed changes self._resolved_delegate = delegate self._resolved_delegate_key = cache_key return delegate diff --git a/nemo_retriever/src/nemo_retriever/utils/stage_timing.py b/nemo_retriever/src/nemo_retriever/utils/stage_timing.py new file mode 100644 index 0000000000..a8652881c8 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/utils/stage_timing.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Per-stage / per-batch timing for the Ray batch executor. + +When the ``NR_STAGE_TIMING=1`` environment variable is set, the +``RayDataExecutor`` starts a detached named Ray actor that collects +records emitted by :class:`AbstractOperator.run` on every worker. +After the pipeline materialises, the executor pulls the records, +combines them with ``ds.stats()`` text, and writes a human-readable +report (and optional JSON dump via ``NR_STAGE_TIMING_REPORT_PATH``). +""" + +from __future__ import annotations + +import datetime +import json +import logging +import os +import re +import threading +from collections import defaultdict +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +logger = logging.getLogger(__name__) + +COLLECTOR_NAME = "nr_stage_timing_collector" +ENABLED_ENV = "NR_STAGE_TIMING" +REPORT_PATH_ENV = "NR_STAGE_TIMING_REPORT_PATH" + +# Driver-side counter of RayDataExecutor.ingest() invocations within this process. +# Used by Phase 1 diagnostics to identify each graph execution distinctly. +_INGEST_CALL_COUNTER = 0 +_INGEST_CALL_COUNTER_LOCK = threading.Lock() + +# Captured once per process so every report file from one run sorts together. +_RUN_TIMESTAMP = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + + +def resolve_report_path(call_index: Optional[int], graph_label: Optional[str]) -> Optional[Path]: + """Compute the JSON output path for one timing report, or ``None`` if file writes are disabled. + + The ``NR_STAGE_TIMING_REPORT_PATH`` env var is treated as a *base*: + + * If it points to an existing directory (or ends with ``/``), files land + inside it as ``timing___