diff --git a/nemo_retriever/src/nemo_retriever/graph/abstract_operator.py b/nemo_retriever/src/nemo_retriever/graph/abstract_operator.py index b0f5ee5081..4dae44aa09 100644 --- a/nemo_retriever/src/nemo_retriever/graph/abstract_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/abstract_operator.py @@ -6,12 +6,54 @@ from abc import ABC, abstractmethod import inspect +import os +import resource +import time from typing import Any, TYPE_CHECKING +from nemo_retriever.utils import stage_timing + if TYPE_CHECKING: from nemo_retriever.graph.pipeline_graph import Graph, Node +try: # psutil is in the retriever runtime; degrade gracefully if missing + import psutil as _psutil + + _PROC = _psutil.Process() +except Exception: # pragma: no cover + _psutil = None + _PROC = None + + +def _safe_len(data: Any) -> int: + try: + return len(data) + except Exception: + return -1 + + +def _mem_snapshot() -> tuple[float, float]: + """Return (process_rss_mb, host_available_mb). Zeros if psutil unavailable.""" + if _PROC is None or _psutil is None: + return 0.0, 0.0 + try: + rss = _PROC.memory_info().rss / 1e6 + avail = _psutil.virtual_memory().available / 1e6 + return rss, avail + except Exception: + return 0.0, 0.0 + + +def _process_peak_rss_mb() -> float: + """ru_maxrss high-water mark for this worker process, since process start.""" + try: + # On Linux ru_maxrss is in KiB. + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0 + except Exception: + return 0.0 + + class AbstractOperator(ABC): """Base class for all pipeline operators.""" @@ -30,9 +72,39 @@ def process(self, data: Any, **kwargs: Any) -> Any: ... def postprocess(self, data: Any, **kwargs: Any) -> Any: ... def run(self, data: Any, **kwargs: Any) -> Any: + if not stage_timing.is_enabled(): + data = self.preprocess(data, **kwargs) + data = self.process(data, **kwargs) + data = self.postprocess(data, **kwargs) + return data + + stage = getattr(self, "_nr_stage_name", None) or type(self).__name__ + n_in = _safe_len(data) + rss_b, avail_b = _mem_snapshot() + t0 = time.perf_counter() data = self.preprocess(data, **kwargs) + t1 = time.perf_counter() data = self.process(data, **kwargs) + t2 = time.perf_counter() data = self.postprocess(data, **kwargs) + t3 = time.perf_counter() + rss_a, avail_a = _mem_snapshot() + stage_timing.record_timing( + stage=stage, + n_rows_in=n_in, + n_rows_out=_safe_len(data), + preprocess_ms=(t1 - t0) * 1000.0, + process_ms=(t2 - t1) * 1000.0, + postprocess_ms=(t3 - t2) * 1000.0, + total_ms=(t3 - t0) * 1000.0, + worker_pid=os.getpid(), + wallclock_start=t0, + rss_before_mb=rss_b, + rss_after_mb=rss_a, + rss_peak_mb=_process_peak_rss_mb(), + avail_before_mb=avail_b, + avail_after_mb=avail_a, + ) return data def __call__(self, data: Any, **kwargs: Any) -> Any: diff --git a/nemo_retriever/src/nemo_retriever/graph/executor.py b/nemo_retriever/src/nemo_retriever/graph/executor.py index 14a323ab08..701591d462 100644 --- a/nemo_retriever/src/nemo_retriever/graph/executor.py +++ b/nemo_retriever/src/nemo_retriever/graph/executor.py @@ -29,6 +29,7 @@ VLLM_GPUS_PER_ACTOR, OCR_GPUS_PER_ACTOR, ) +from nemo_retriever.utils import stage_timing import logging @@ -249,6 +250,10 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: } ray_env_vars.update(collect_hf_runtime_env()) ray_env_vars.update(collect_remote_auth_runtime_env()) + for _name in (stage_timing.ENABLED_ENV, stage_timing.REPORT_PATH_ENV): + _val = os.environ.get(_name) + if _val: + ray_env_vars[_name] = _val os.environ["HF_HUB_OFFLINE"] = ray_env_vars["HF_HUB_OFFLINE"] runtime_env = {"env_vars": ray_env_vars} ray.init( @@ -257,6 +262,10 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: runtime_env=runtime_env, ) + timing_enabled = stage_timing.is_enabled() + timing_collector = stage_timing.start_collector() if timing_enabled else None + timing_mem_sampler = stage_timing.start_memory_sampler(timing_collector) if timing_collector else None + ctx = rd.DataContext.get_current() ctx.enable_rich_progress_bars = True ctx.use_ray_tqdm = False @@ -273,6 +282,24 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: except FileNotFoundError as exc: raise_input_path_not_found(input_paths or [], exc) nodes = self._linearize(resolved_graph) + timing_call_index: Optional[int] = None + timing_node_names: List[str] = [n.name for n in nodes] + timing_graph_label: Optional[str] = None + if timing_enabled: + timing_call_index = stage_timing.next_call_index() + timing_graph_label = stage_timing.slugify_graph_label(timing_node_names) + report_path = stage_timing.resolve_report_path(timing_call_index, timing_graph_label) + logger.info( + "RayDataExecutor.ingest() #%02d | %d nodes: [%s] | label=%s | report -> %s", + timing_call_index, + len(timing_node_names), + ", ".join(timing_node_names), + timing_graph_label, + report_path if report_path is not None else "", + ) + # Per-node operator class stamping happens inside the map_batches + # loop below via stage_timing.make_named_operator_class so two + # nodes sharing the same operator class get distinct stage names. for node in nodes: overrides = dict(self._node_overrides.get(node.name, {})) target_num_rows_per_block = overrides.pop("target_num_rows_per_block", None) @@ -360,18 +387,24 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: group_keys = list(getattr(node.operator_class, "GLOBAL_BATCH_GROUP_KEYS", None) or ()) n_blocks = max(1, int(overrides.get("concurrency") or 1)) if group_keys else 1 if n_blocks > 1: - ds = ds.repartition(num_blocks=n_blocks, keys=group_keys, shuffle=True) + # ds = ds.repartition(num_blocks=n_blocks, keys=group_keys, shuffle=True) + pass else: - ds = ds.repartition(num_blocks=1) + # ds = ds.repartition(num_blocks=1) + pass elif target_num_rows_per_block is not None and int(target_num_rows_per_block) > 0: - ds = ds.repartition(target_num_rows_per_block=int(target_num_rows_per_block)) + # ds = ds.repartition(target_num_rows_per_block=int(target_num_rows_per_block)) + pass # Pass the operator class directly to map_batches with # fn_constructor_kwargs for deferred construction on workers. # AbstractOperator.__call__ delegates to run(), so each stage # executes the full preprocess -> process -> postprocess chain. + operator_cls = node.operator_class + if timing_enabled: + operator_cls = stage_timing.make_named_operator_class(operator_cls, node.name) ds = ds.map_batches( - node.operator_class, + operator_cls, batch_size=batch_size, batch_format=batch_format, num_cpus=num_cpus, @@ -380,4 +413,39 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: **overrides, ) - return ds.to_pandas() + result = ds.to_pandas() + if timing_collector is not None: + try: + # Stop the memory sampler before dumping so no late samples + # arrive after we've read the collector. + stage_timing.stop_memory_sampler(timing_mem_sampler) + ray_stats_text = None + try: + ray_stats_text = ds.stats() + except Exception as exc: + logger.warning("Failed to collect ds.stats(): %s", exc) + try: + records = ray.get(timing_collector.dump.remote()) + except Exception as exc: + logger.warning("Failed to retrieve stage timing records: %s", exc) + records = [] + try: + memory_samples = ray.get(timing_collector.dump_samples.remote()) + except Exception as exc: + logger.warning("Failed to retrieve memory samples: %s", exc) + memory_samples = [] + baseline_used_mb = ( + getattr(timing_mem_sampler, "baseline_sys_used_mb", 0.0) if timing_mem_sampler else None + ) + stage_timing.write_report( + records, + ray_stats_text=ray_stats_text, + call_index=timing_call_index, + graph_label=timing_graph_label, + node_names=timing_node_names, + memory_samples=memory_samples, + baseline_sys_used_mb=baseline_used_mb, + ) + finally: + stage_timing.stop_collector(timing_collector) + return result diff --git a/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py b/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py index 3fb9682631..80ac18c61e 100644 --- a/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py +++ b/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py @@ -88,7 +88,14 @@ def _resolve_delegate(self, resources: ClusterResources | Resources | None = Non if operator_class is type(self): raise RuntimeError(f"{type(self).__name__} could not resolve a concrete hardware-specific operator.") - delegate = operator_class(**operator_kwargs) + resolved_kwargs = type(self).variant_operator_kwargs(operator_class, operator_kwargs) + delegate = operator_class(**resolved_kwargs) + # Propagate the pipeline-level stage name (set by RayDataExecutor on the + # archetype class) so the delegate's run() can label timing records + # consistently with the node name rather than the variant class name. + stage_name = getattr(self, "_nr_stage_name", None) + if stage_name is not None: + delegate._nr_stage_name = stage_name self._resolved_delegate = delegate self._resolved_delegate_key = cache_key return delegate diff --git a/nemo_retriever/src/nemo_retriever/utils/stage_timing.py b/nemo_retriever/src/nemo_retriever/utils/stage_timing.py new file mode 100644 index 0000000000..337d258823 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/utils/stage_timing.py @@ -0,0 +1,617 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Per-stage / per-batch timing for the Ray batch executor. + +When the ``NR_STAGE_TIMING=1`` environment variable is set, the +``RayDataExecutor`` starts a detached named Ray actor that collects +records emitted by :class:`AbstractOperator.run` on every worker. +After the pipeline materialises, the executor pulls the records, +combines them with ``ds.stats()`` text, and writes a human-readable +report (and optional JSON dump via ``NR_STAGE_TIMING_REPORT_PATH``). +""" + +from __future__ import annotations + +import datetime +import json +import logging +import os +import re +import threading +from collections import defaultdict +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +logger = logging.getLogger(__name__) + +COLLECTOR_NAME = "nr_stage_timing_collector" +ENABLED_ENV = "NR_STAGE_TIMING" +REPORT_PATH_ENV = "NR_STAGE_TIMING_REPORT_PATH" + +# Driver-side counter of RayDataExecutor.ingest() invocations within this process. +# Used by Phase 1 diagnostics to identify each graph execution distinctly. +_INGEST_CALL_COUNTER = 0 +_INGEST_CALL_COUNTER_LOCK = threading.Lock() + +# Captured once per process so every report file from one run sorts together. +_RUN_TIMESTAMP = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + + +def resolve_report_path(call_index: Optional[int], graph_label: Optional[str]) -> Optional[Path]: + """Compute the JSON output path for one timing report, or ``None`` if file writes are disabled. + + The ``NR_STAGE_TIMING_REPORT_PATH`` env var is treated as a *base*: + + * If it points to an existing directory (or ends with ``/``), files land + inside it as ``timing___