Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions nemo_retriever/src/nemo_retriever/graph/abstract_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,54 @@

from abc import ABC, abstractmethod
import inspect
import os
import resource
import time
from typing import Any, TYPE_CHECKING

from nemo_retriever.utils import stage_timing

if TYPE_CHECKING:
from nemo_retriever.graph.pipeline_graph import Graph, Node


# psutil ships with the retriever runtime, but this module must still import
# without it: on any failure both module-level handles are left as None so the
# memory helpers can detect the situation and report zeros instead.
try:
    import psutil as _psutil

    _PROC = _psutil.Process()
except Exception:  # pragma: no cover
    _psutil = _PROC = None


def _safe_len(data: Any) -> int:
    """Return ``len(data)``, or ``-1`` when taking the length raises.

    The ``-1`` sentinel lets callers record a row count for arbitrary batch
    objects without having to guard against unsized inputs themselves.
    """
    try:
        size = len(data)
    except Exception:
        size = -1
    return size


def _mem_snapshot() -> tuple[float, float]:
    """Sample current memory usage as ``(process_rss_mb, host_available_mb)``.

    Both values are bytes divided by 1e6. ``(0.0, 0.0)`` is returned when the
    module-level psutil handles are ``None`` or when either psutil query
    raises, so callers never have to handle a failure here.
    """
    unavailable = (0.0, 0.0)
    if _psutil is None or _PROC is None:
        return unavailable
    try:
        process_rss_mb = _PROC.memory_info().rss / 1e6
        host_available_mb = _psutil.virtual_memory().available / 1e6
    except Exception:
        return unavailable
    return process_rss_mb, host_available_mb


def _process_peak_rss_mb() -> float:
    """Return this process's peak resident set size in MB (bytes / 1e6).

    The value is the ``ru_maxrss`` high-water mark since process start. It is
    reported in the same decimal-megabyte unit as ``_mem_snapshot`` so that
    the peak can be compared directly with the before/after RSS readings
    stored in the same timing record.

    Returns:
        Peak RSS in MB, or ``0.0`` if the value cannot be read.
    """
    import sys  # local import: only this helper needs the platform check

    try:
        peak_raw = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    except Exception:
        # Best-effort metric: never let instrumentation break the caller.
        return 0.0
    # ru_maxrss is reported in bytes on macOS and in KiB on Linux, so
    # normalize to bytes before converting to MB.
    peak_bytes = peak_raw if sys.platform == "darwin" else peak_raw * 1024
    return peak_bytes / 1e6


class AbstractOperator(ABC):
"""Base class for all pipeline operators."""

Expand All @@ -30,9 +72,39 @@ def process(self, data: Any, **kwargs: Any) -> Any: ...
def postprocess(self, data: Any, **kwargs: Any) -> Any: ...

def run(self, data: Any, **kwargs: Any) -> Any:
    """Run ``preprocess`` -> ``process`` -> ``postprocess`` on ``data``.

    When ``stage_timing.is_enabled()`` is false the three phases are simply
    chained. Otherwise each phase is timed with ``time.perf_counter()``,
    memory is sampled before and after the chain, and one record is handed
    to ``stage_timing.record_timing``.

    Args:
        data: The batch passed to ``preprocess``.
        **kwargs: Forwarded unchanged to all three phases.

    Returns:
        Whatever ``postprocess`` returns.
    """
    if not stage_timing.is_enabled():
        # Fast path: no instrumentation overhead at all.
        prepared = self.preprocess(data, **kwargs)
        processed = self.process(prepared, **kwargs)
        return self.postprocess(processed, **kwargs)

    # Prefer an explicitly stamped stage name; fall back to the class name.
    stage_label = getattr(self, "_nr_stage_name", None) or type(self).__name__
    rows_in = _safe_len(data)
    rss_before, avail_before = _mem_snapshot()

    started = time.perf_counter()
    data = self.preprocess(data, **kwargs)
    pre_done = time.perf_counter()
    data = self.process(data, **kwargs)
    proc_done = time.perf_counter()
    data = self.postprocess(data, **kwargs)
    finished = time.perf_counter()

    rss_after, avail_after = _mem_snapshot()

    # NOTE(review): wallclock_start receives a perf_counter() reading, whose
    # reference point is undefined; confirm the report consumer does not
    # treat it as epoch time.
    timing_fields = dict(
        stage=stage_label,
        n_rows_in=rows_in,
        n_rows_out=_safe_len(data),
        preprocess_ms=(pre_done - started) * 1000.0,
        process_ms=(proc_done - pre_done) * 1000.0,
        postprocess_ms=(finished - proc_done) * 1000.0,
        total_ms=(finished - started) * 1000.0,
        worker_pid=os.getpid(),
        wallclock_start=started,
        rss_before_mb=rss_before,
        rss_after_mb=rss_after,
        rss_peak_mb=_process_peak_rss_mb(),
        avail_before_mb=avail_before,
        avail_after_mb=avail_after,
    )
    stage_timing.record_timing(**timing_fields)
    return data

def __call__(self, data: Any, **kwargs: Any) -> Any:
Expand Down
78 changes: 73 additions & 5 deletions nemo_retriever/src/nemo_retriever/graph/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
VLLM_GPUS_PER_ACTOR,
OCR_GPUS_PER_ACTOR,
)
from nemo_retriever.utils import stage_timing

import logging

Expand Down Expand Up @@ -249,6 +250,10 @@ def ingest(self, data: Any, **kwargs: Any) -> Any:
}
ray_env_vars.update(collect_hf_runtime_env())
ray_env_vars.update(collect_remote_auth_runtime_env())
for _name in (stage_timing.ENABLED_ENV, stage_timing.REPORT_PATH_ENV):
_val = os.environ.get(_name)
if _val:
ray_env_vars[_name] = _val
os.environ["HF_HUB_OFFLINE"] = ray_env_vars["HF_HUB_OFFLINE"]
runtime_env = {"env_vars": ray_env_vars}
ray.init(
Expand All @@ -257,6 +262,10 @@ def ingest(self, data: Any, **kwargs: Any) -> Any:
runtime_env=runtime_env,
)

timing_enabled = stage_timing.is_enabled()
timing_collector = stage_timing.start_collector() if timing_enabled else None
timing_mem_sampler = stage_timing.start_memory_sampler(timing_collector) if timing_collector else None

ctx = rd.DataContext.get_current()
ctx.enable_rich_progress_bars = True
ctx.use_ray_tqdm = False
Expand All @@ -273,6 +282,24 @@ def ingest(self, data: Any, **kwargs: Any) -> Any:
except FileNotFoundError as exc:
raise_input_path_not_found(input_paths or [], exc)
nodes = self._linearize(resolved_graph)
timing_call_index: Optional[int] = None
timing_node_names: List[str] = [n.name for n in nodes]
timing_graph_label: Optional[str] = None
if timing_enabled:
timing_call_index = stage_timing.next_call_index()
timing_graph_label = stage_timing.slugify_graph_label(timing_node_names)
report_path = stage_timing.resolve_report_path(timing_call_index, timing_graph_label)
logger.info(
"RayDataExecutor.ingest() #%02d | %d nodes: [%s] | label=%s | report -> %s",
timing_call_index,
len(timing_node_names),
", ".join(timing_node_names),
timing_graph_label,
report_path if report_path is not None else "<stdout only>",
)
# Per-node operator class stamping happens inside the map_batches
# loop below via stage_timing.make_named_operator_class so two
# nodes sharing the same operator class get distinct stage names.
for node in nodes:
overrides = dict(self._node_overrides.get(node.name, {}))
target_num_rows_per_block = overrides.pop("target_num_rows_per_block", None)
Expand Down Expand Up @@ -360,18 +387,24 @@ def ingest(self, data: Any, **kwargs: Any) -> Any:
group_keys = list(getattr(node.operator_class, "GLOBAL_BATCH_GROUP_KEYS", None) or ())
n_blocks = max(1, int(overrides.get("concurrency") or 1)) if group_keys else 1
if n_blocks > 1:
ds = ds.repartition(num_blocks=n_blocks, keys=group_keys, shuffle=True)
# ds = ds.repartition(num_blocks=n_blocks, keys=group_keys, shuffle=True)
pass
else:
ds = ds.repartition(num_blocks=1)
# ds = ds.repartition(num_blocks=1)
pass
elif target_num_rows_per_block is not None and int(target_num_rows_per_block) > 0:
ds = ds.repartition(target_num_rows_per_block=int(target_num_rows_per_block))
# ds = ds.repartition(target_num_rows_per_block=int(target_num_rows_per_block))
pass
Comment on lines 389 to +397
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Repartition calls silently disabled

The three ds.repartition(...) calls that control block layout for GLOBAL_BATCH_GROUP_KEYS operators (hash-partitioning so co-keyed rows stay co-located) and target_num_rows_per_block tuning have been commented out and replaced with pass. Any operator that sets GLOBAL_BATCH_GROUP_KEYS will now silently receive arbitrarily-partitioned batches, breaking the guarantee that related rows land on the same actor. Similarly, target_num_rows_per_block overrides are completely ignored. These look like debugging leftovers that should be restored before merging.

Prompt To Fix With AI
This is a comment left during a code review.
Path: nemo_retriever/src/nemo_retriever/graph/executor.py
Line: 389-397

Comment:
**Repartition calls silently disabled**

The three `ds.repartition(...)` calls that control block layout for `GLOBAL_BATCH_GROUP_KEYS` operators (hash-partitioning so co-keyed rows stay co-located) and `target_num_rows_per_block` tuning have been commented out and replaced with `pass`. Any operator that sets `GLOBAL_BATCH_GROUP_KEYS` will now silently receive arbitrarily-partitioned batches, breaking the guarantee that related rows land on the same actor. Similarly, `target_num_rows_per_block` overrides are completely ignored. These look like debugging leftovers that should be restored before merging.

How can I resolve this? If you propose a fix, please make it concise.


# Pass the operator class directly to map_batches with
# fn_constructor_kwargs for deferred construction on workers.
# AbstractOperator.__call__ delegates to run(), so each stage
# executes the full preprocess -> process -> postprocess chain.
operator_cls = node.operator_class
if timing_enabled:
operator_cls = stage_timing.make_named_operator_class(operator_cls, node.name)
ds = ds.map_batches(
node.operator_class,
operator_cls,
batch_size=batch_size,
batch_format=batch_format,
num_cpus=num_cpus,
Expand All @@ -380,4 +413,39 @@ def ingest(self, data: Any, **kwargs: Any) -> Any:
**overrides,
)

return ds.to_pandas()
result = ds.to_pandas()
if timing_collector is not None:
try:
# Stop the memory sampler before dumping so no late samples
# arrive after we've read the collector.
stage_timing.stop_memory_sampler(timing_mem_sampler)
ray_stats_text = None
try:
ray_stats_text = ds.stats()
except Exception as exc:
logger.warning("Failed to collect ds.stats(): %s", exc)
try:
records = ray.get(timing_collector.dump.remote())
except Exception as exc:
logger.warning("Failed to retrieve stage timing records: %s", exc)
records = []
try:
memory_samples = ray.get(timing_collector.dump_samples.remote())
except Exception as exc:
logger.warning("Failed to retrieve memory samples: %s", exc)
memory_samples = []
baseline_used_mb = (
getattr(timing_mem_sampler, "baseline_sys_used_mb", 0.0) if timing_mem_sampler else None
)
stage_timing.write_report(
records,
ray_stats_text=ray_stats_text,
call_index=timing_call_index,
graph_label=timing_graph_label,
node_names=timing_node_names,
memory_samples=memory_samples,
baseline_sys_used_mb=baseline_used_mb,
)
finally:
stage_timing.stop_collector(timing_collector)
return result
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,14 @@ def _resolve_delegate(self, resources: ClusterResources | Resources | None = Non
if operator_class is type(self):
raise RuntimeError(f"{type(self).__name__} could not resolve a concrete hardware-specific operator.")

delegate = operator_class(**operator_kwargs)
resolved_kwargs = type(self).variant_operator_kwargs(operator_class, operator_kwargs)
delegate = operator_class(**resolved_kwargs)
# Propagate the pipeline-level stage name (set by RayDataExecutor on the
# archetype class) so the delegate's run() can label timing records
# consistently with the node name rather than the variant class name.
stage_name = getattr(self, "_nr_stage_name", None)
if stage_name is not None:
delegate._nr_stage_name = stage_name
self._resolved_delegate = delegate
self._resolved_delegate_key = cache_key
return delegate
Expand Down
Loading