Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo_retriever/src/nemo_retriever/graph/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,4 @@ def ingest(self, data: Any, **kwargs: Any) -> Any:
**overrides,
)

return ds.to_pandas()
return ds.materialize()
18 changes: 18 additions & 0 deletions nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,24 @@ def _force_cpu_only(node_name: str) -> None:
store_override["concurrency"] = (1, store_workers, 1) if store_workers > 1 else 1
store_override["num_cpus"] = DEFAULT_STORE_CPUS_PER_ACTOR

# IngestVdbOperator opens the LanceDB table with ``mode="overwrite"`` on
# its first batch and ``table.add(...)`` on every subsequent batch. With
# concurrency > 1, multiple actors would each treat their own first batch
# as the overwrite and clobber each other's writes.
overrides.setdefault(IngestVdbOperator.__name__, {})["concurrency"] = 1

# Coalesce Ray Data blocks into ~1000-row LanceDB writes. Every
# ``table.add(rows)`` commits a new entry to Lance's table-version
# manifest, and Lance reads the manifest linearly on every commit *and*
# every table reopen — so commit latency grows with the count of prior
# commits. At the default ~32-row block size coming out of
# ``StreamingRepartition`` this would be ~80k commits for a bo767-sized
# corpus (~80k rows); batching to ~1000 cuts that ~30× to ~3k commits and
# keeps per-commit time roughly flat. Higher values further reduce commit
# count but trade off per-batch memory in the writer actor — callers can
# raise the value via ``node_overrides``; we only set a default here.
overrides[IngestVdbOperator.__name__].setdefault("batch_size", 1000)

return overrides


Expand Down
16 changes: 16 additions & 0 deletions nemo_retriever/src/nemo_retriever/graph_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from nemo_retriever.graph import InprocessExecutor, RayDataExecutor
from nemo_retriever.graph.ingestor_runtime import batch_tuning_to_node_overrides, build_graph
from nemo_retriever.ingestor import ingestor
from nemo_retriever.vdb.operators import _construct_vdb
from nemo_retriever.params import (
ASRParams,
AudioChunkParams,
Expand Down Expand Up @@ -557,8 +558,23 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any:
result = executor.ingest(self._documents)

self._raise_for_stage_errors(result)
self._finalize_vdb_upload()
return result

def _finalize_vdb_upload(self) -> None:
"""Build the VDB search index once after the graph has finished
writing.

``IngestVdbOperator`` streams per-batch writes during the run; the
index build is a one-shot operation that doesn't need any of the
graph's row data.
"""
params = self._vdb_upload_params
if params is None:
return
vdb = _construct_vdb(vdb_op=params.vdb_op, vdb_kwargs=params.vdb_kwargs)
vdb.build_index()

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
Expand Down
59 changes: 28 additions & 31 deletions nemo_retriever/src/nemo_retriever/pipeline/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
VideoFrameTextDedupParams,
)
from nemo_retriever.params.models import BatchTuningParams
from nemo_retriever.pipeline.ingest_result import IngestResult
from nemo_retriever.utils.input_files import resolve_input_patterns
from nemo_retriever.utils.remote_auth import resolve_remote_api_key

Expand Down Expand Up @@ -577,31 +578,26 @@ def _build_ingestor(
return ingestor


def _collect_results(run_mode: str, result: Any) -> tuple[list[dict[str, Any]], Any, float, int]:
"""Materialize the graph result into a list of records + DataFrame.
def _collect_results(run_mode: str, result: Any) -> tuple["IngestResult", float, int]:
"""Wrap the graph result into a streaming-friendly IngestResult.

Ingest may return a ``pandas.DataFrame`` (in-process or after
``ray.data.Dataset.to_pandas()`` in the executor), a ``ray.data.Dataset``,
or a :class:`~nemo_retriever.service_ingestor.ServiceIngestResult` (service
mode); normalize to a consistent ``(records, DataFrame, secs, units)`` tuple.
Ingest may return a ``pandas.DataFrame`` (in-process), a ``ray.data.Dataset``
(batch mode), or a ``nemo_retriever.service_ingestor.ServiceIngestResult``
(service mode); each is mapped to the matching ``IngestResult`` constructor without
ever pulling the full corpus onto the driver as a single pandas DataFrame.

Returns ``(records, result_df, ray_download_secs, num_input_units)``.
Returns ``(result_handle, ray_download_secs, num_input_units)``.
"""

if run_mode == "service":
records = list(result)
result_df = pd.DataFrame(records) if records else pd.DataFrame()
num_units = getattr(result, "total_pages", 0) or len(records)
return records, result_df, 0.0, num_units
handle = IngestResult.from_service(result)
return handle, 0.0, handle.unique_source_count()

if isinstance(result, pd.DataFrame):
result_df = result
handle = IngestResult.from_dataframe(result)
else:
result_df = result.to_pandas()
records = result_df.to_dict("records")
ray_download_time = 0.0

return records, result_df, float(ray_download_time), _count_input_units(result_df)
handle = IngestResult.from_dataset(result)
return handle, 0.0, handle.unique_source_count()


def _count_uploadable_vdb_records(records: list[dict[str, Any]]) -> int:
Expand Down Expand Up @@ -1411,7 +1407,8 @@ def run(
ingest_start = time.perf_counter()
raw_result = ingestor.ingest()
ingestion_only_total_time = time.perf_counter() - ingest_start
ingest_local_results, result_df, ray_download_time, num_rows = _collect_results(run_mode, raw_result)
ingest_result, ray_download_time, num_rows = _collect_results(run_mode, raw_result)
total_row_count = ingest_result.row_count()

if run_mode == "service":
# The service writes embeddings to LanceDB server-side during
Expand All @@ -1421,14 +1418,14 @@ def run(
logger.info(
"Service-mode ingestion complete (%d results from %d input(s), %.1fs). "
"VDB writes are handled server-side.",
len(ingest_local_results),
total_row_count,
num_rows,
ingestion_only_total_time,
)
uploadable_vdb_records = len(ingest_local_results)
uploadable_vdb_records = total_row_count
vdb_upload_time = 0.0
else:
uploadable_vdb_records = _count_uploadable_vdb_records(ingest_local_results)
uploadable_vdb_records = ingest_result.count_uploadable_vdb_records()
vdb_upload_time = 0.0
if uploadable_vdb_records == 0:
logger.warning(
Expand All @@ -1440,26 +1437,26 @@ def run(
"Prepared %s uploadable VDB records (%s graph rows) for in-graph upload to %s "
"(row conversion count, not backend-confirmed writes; see VDB/operator logs for persistence).",
uploadable_vdb_records,
len(ingest_local_results),
total_row_count,
resolved_vdb_op,
)

if save_intermediate is not None:
out_dir = Path(save_intermediate).expanduser().resolve()
out_dir.mkdir(parents=True, exist_ok=True)
# Streaming write: Ray Data produces a directory of per-block
# parquet files (one file in inprocess/service mode). Existing
# readers in nemo_retriever already handle the directory form.
out_path = out_dir / "extraction.parquet"
result_df.to_parquet(out_path, index=False)
ingest_result.write_parquet_dir(out_path)
logger.info("Wrote extraction Parquet for intermediate use: %s", out_path)

if detection_summary_file is not None:
from nemo_retriever.utils.detection_summary import (
collect_detection_summary_from_df,
write_detection_summary,
)
from nemo_retriever.utils.detection_summary import write_detection_summary

write_detection_summary(
Path(detection_summary_file),
collect_detection_summary_from_df(result_df),
ingest_result.detection_summary(),
)

if uploadable_vdb_records == 0 and run_mode != "service":
Expand Down Expand Up @@ -1494,7 +1491,7 @@ def run(
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"num_rows": int(total_row_count),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"vdb_upload_secs": float(vdb_upload_time),
Expand Down Expand Up @@ -1572,7 +1569,7 @@ def run(
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"num_rows": int(total_row_count),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"vdb_upload_secs": float(vdb_upload_time),
Expand Down Expand Up @@ -1600,7 +1597,7 @@ def run(
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"num_rows": int(total_row_count),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"vdb_upload_secs": float(vdb_upload_time),
Expand Down
Loading
Loading