From ecf08c468b9be2d0c3c1f4a5e0ec06b17d7c5620 Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Thu, 21 May 2026 20:00:11 +0000 Subject: [PATCH] Replace ingest input-type routing with manifest branches (cherry picked from commit 74975d118328f5abf1bdd68ce581df5c373a0b00) --- .../nemo-retriever/references/ingest.md | 35 +- .../src/nemo_retriever/adapters/cli/main.py | 72 +-- .../adapters/cli/sdk_workflow.py | 204 +-------- .../src/nemo_retriever/branch_extraction.py | 294 ++++++++++++ .../src/nemo_retriever/graph/executor.py | 11 +- .../nemo_retriever/graph/ingestor_runtime.py | 62 ++- .../graph/multi_type_extract_operator.py | 40 +- .../src/nemo_retriever/graph_ingestor.py | 420 ++++++++++-------- .../src/nemo_retriever/ingest_manifest.py | 221 +++++++++ .../src/nemo_retriever/params/models.py | 4 - nemo_retriever/tests/test_ingest_interface.py | 18 +- nemo_retriever/tests/test_ingest_manifest.py | 266 +++++++++++ nemo_retriever/tests/test_ingest_plans.py | 28 -- nemo_retriever/tests/test_pipeline_graph.py | 112 +++-- .../tests/test_root_cli_workflow.py | 181 +------- 15 files changed, 1199 insertions(+), 769 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/branch_extraction.py create mode 100644 nemo_retriever/src/nemo_retriever/ingest_manifest.py create mode 100644 nemo_retriever/tests/test_ingest_manifest.py diff --git a/.claude/skills/nemo-retriever/references/ingest.md b/.claude/skills/nemo-retriever/references/ingest.md index b3a52788ce..bf354386ba 100644 --- a/.claude/skills/nemo-retriever/references/ingest.md +++ b/.claude/skills/nemo-retriever/references/ingest.md @@ -1,7 +1,7 @@ # retriever ingest -End-to-end ingestion of documents and media into a LanceDB table — runs the -full extract → embed → vector-DB pipeline in a single command. +End-to-end ingestion of supported documents and media into a LanceDB table — runs the full +extract → embed → vector-DB pipeline in a single command. If flags below look stale, re-check `retriever ingest --help`. @@ -9,9 +9,8 @@ If flags below look stale, re-check `retriever ingest --help`. - You have one or more supported files (or a directory/glob of files) and want them searchable via `retriever query`. -- You want the default pipeline: auto-select extraction for PDF/DOC/PPTX, - text, HTML, image, audio, or video inputs, then embed and insert into - LanceDB. No per-stage tuning needed. +- You want the default pipeline: PDF split → extraction → page-element + detection → OCRv2 → embedding → LanceDB insert. No per-stage tuning needed. **Use a different command when:** @@ -25,7 +24,7 @@ If flags below look stale, re-check `retriever ingest --help`. ## Canonical invocations -Ingest a single file into the default table (`lancedb/nv-ingest.lance`): +Ingest a single PDF into the default table (`lancedb/nemo-retriever.lance`): ```bash retriever ingest data/multimodal_test.pdf @@ -43,15 +42,6 @@ Ingest via glob: retriever ingest "data/**/*" ``` -Force a specific input family: - -```bash -retriever ingest data/slides/ --input-type doc -retriever ingest data/images/ --input-type image -retriever ingest data/audio/ --input-type audio -retriever ingest data/video/ --input-type video -``` - Write to a custom DB / table: ```bash @@ -62,11 +52,8 @@ retriever ingest data/multimodal_test.pdf \ ## Inputs -- **Positional `DOCUMENTS...`** — one or more file paths, directories, or - shell globs. Required, repeatable. -- **Supported input types** — `pdf`, `doc` (`.docx`, `.pptx`), `txt`, `html`, - `image` (`.jpg`, `.jpeg`, `.png`, `.tiff`, `.tif`, `.bmp`, `.svg`), - `audio` (`.mp3`, `.wav`, `.m4a`), and `video` (`.mp4`, `.mov`, `.mkv`). +- **Positional `DOCUMENTS...`** — one or more of: PDF file paths, directories + containing PDFs, or shell globs. Required, repeatable. ## Outputs @@ -81,13 +68,12 @@ retriever ingest data/multimodal_test.pdf \ | Flag | Default | Notes | |---|---|---| | `--lancedb-uri` | `lancedb` | Path or URI of the LanceDB database. | -| `--table-name` | `nv-ingest` | LanceDB table to write into. Must match `retriever query`'s table on read. | -| `--input-type` | `auto` | Input family to ingest. `auto` detects from file extensions and supports mixed directories. | +| `--table-name` | `nemo-retriever` | LanceDB table to write into. Must match `retriever query`'s table on read. | | `--run-mode` | `inprocess` | `inprocess` for local runs; `batch` for the SDK batch ingestor. | ## Pipeline shape -For PDF/DOC/PPTX inputs, `ingest` runs the optimized document pipeline: +The default `ingest` runs 8 stages, in order: 1. `DocToPdfConversionActor` — non-PDF inputs → PDF (no-op for PDFs). 2. `PDFSplitActor` — split into per-page tasks. @@ -98,9 +84,6 @@ For PDF/DOC/PPTX inputs, `ingest` runs the optimized document pipeline: 7. `_BatchEmbedActor` — embed primitives with `llama-nemotron-embed-1b-v2`. 8. `IngestVdbOperator` — insert rows into LanceDB. -For text, HTML, image, audio, video, or mixed `auto` inputs, `ingest` routes -through the same GraphIngestor extraction paths used by `retriever pipeline`. - ## Common failure modes - **`Clamping num_partitions from 16 to 7`** — informational, not an error. diff --git a/nemo_retriever/src/nemo_retriever/adapters/cli/main.py b/nemo_retriever/src/nemo_retriever/adapters/cli/main.py index c39d6f10c9..b9a7e63668 100644 --- a/nemo_retriever/src/nemo_retriever/adapters/cli/main.py +++ b/nemo_retriever/src/nemo_retriever/adapters/cli/main.py @@ -17,12 +17,9 @@ import typer from nemo_retriever.adapters.cli.sdk_workflow import ( - IngestInputTypeValue, IngestRunModeValue, - LocalIngestEmbedBackendValue, OcrLangValue, OcrVersionValue, - TableOutputFormatValue, ingest_documents, query_documents, ) @@ -145,12 +142,7 @@ def main() -> None: def ingest_command( documents: list[str] = typer.Argument( ..., - help="One or more file paths, directories, or globs to ingest.", - ), - input_type: IngestInputTypeValue = typer.Option( - "auto", - "--input-type", - help="Input type: auto, pdf, doc, txt, html, image, audio, or video.", + help="One or more files, directories, or globs. Supported file types are detected automatically.", ), lancedb_uri: str = typer.Option("lancedb", "--lancedb-uri", help="LanceDB database URI."), table_name: str = typer.Option("nv-ingest", "--table-name", help="LanceDB table name."), @@ -199,22 +191,12 @@ def ingest_command( "--table-structure-invoke-url", help="Table-structure NIM endpoint URL.", ), - table_output_format: TableOutputFormatValue | None = typer.Option( - None, - "--table-output-format", - help="Table text format. 'markdown' enables local table-structure extraction.", - ), embed_invoke_url: str | None = typer.Option(None, "--embed-invoke-url", help="Embedding NIM endpoint URL."), embed_model_name: str | None = typer.Option( None, "--embed-model-name", help="Optional embedding model name override.", ), - local_ingest_embed_backend: LocalIngestEmbedBackendValue | None = typer.Option( - None, - "--local-ingest-embed-backend", - help="Local ingest-time text embedder when --embed-invoke-url is unset.", - ), pdf_extract_workers: int | None = typer.Option( None, "--pdf-extract-workers", @@ -251,12 +233,6 @@ def ingest_command( min=0.0, help="CPUs reserved per page-element detection actor in batch mode.", ), - page_elements_gpus_per_actor: float | None = typer.Option( - None, - "--page-elements-gpus-per-actor", - min=0.0, - help="GPUs reserved per local page-element detection actor in batch mode.", - ), ocr_workers: int | None = typer.Option( None, "--ocr-workers", @@ -275,36 +251,6 @@ def ingest_command( min=0.0, help="CPUs reserved per OCR actor in batch mode.", ), - ocr_gpus_per_actor: float | None = typer.Option( - None, - "--ocr-gpus-per-actor", - min=0.0, - help="GPUs reserved per local OCR actor in batch mode.", - ), - table_structure_workers: int | None = typer.Option( - None, - "--table-structure-workers", - min=1, - help="Number of Ray actors for table-structure extraction in batch mode.", - ), - table_structure_batch_size: int | None = typer.Option( - None, - "--table-structure-batch-size", - min=1, - help="Table-structure extraction batch size per actor in batch mode.", - ), - table_structure_cpus_per_actor: float | None = typer.Option( - None, - "--table-structure-cpus-per-actor", - min=0.0, - help="CPUs reserved per table-structure actor in batch mode.", - ), - table_structure_gpus_per_actor: float | None = typer.Option( - None, - "--table-structure-gpus-per-actor", - min=0.0, - help="GPUs reserved per local table-structure actor in batch mode.", - ), embed_workers: int | None = typer.Option( None, "--embed-workers", @@ -323,12 +269,6 @@ def ingest_command( min=0.0, help="CPUs reserved per embedding actor in batch mode.", ), - embed_gpus_per_actor: float | None = typer.Option( - None, - "--embed-gpus-per-actor", - min=0.0, - help="GPUs reserved per local embedding actor in batch mode.", - ), quiet: bool = typer.Option( False, "--quiet", @@ -347,7 +287,6 @@ def ingest_command( with capture: summary = ingest_documents( documents, - input_type=input_type, run_mode=run_mode, ray_address=ray_address, ray_log_to_driver=ray_log_to_driver, @@ -360,29 +299,20 @@ def ingest_command( ocr_lang=ocr_lang, graphic_elements_invoke_url=graphic_elements_invoke_url, table_structure_invoke_url=table_structure_invoke_url, - table_output_format=table_output_format, embed_invoke_url=embed_invoke_url, embed_model_name=embed_model_name, - local_ingest_embed_backend=local_ingest_embed_backend, pdf_extract_workers=pdf_extract_workers, pdf_extract_batch_size=pdf_extract_batch_size, pdf_extract_cpus_per_task=pdf_extract_cpus_per_task, page_elements_workers=page_elements_workers, page_elements_batch_size=page_elements_batch_size, page_elements_cpus_per_actor=page_elements_cpus_per_actor, - page_elements_gpus_per_actor=page_elements_gpus_per_actor, ocr_workers=ocr_workers, ocr_batch_size=ocr_batch_size, ocr_cpus_per_actor=ocr_cpus_per_actor, - ocr_gpus_per_actor=ocr_gpus_per_actor, - table_structure_workers=table_structure_workers, - table_structure_batch_size=table_structure_batch_size, - table_structure_cpus_per_actor=table_structure_cpus_per_actor, - table_structure_gpus_per_actor=table_structure_gpus_per_actor, embed_workers=embed_workers, embed_batch_size=embed_batch_size, embed_cpus_per_actor=embed_cpus_per_actor, - embed_gpus_per_actor=embed_gpus_per_actor, ) except _ROOT_CLI_ERRORS as exc: typer.echo(f"Error: {exc}", err=True) diff --git a/nemo_retriever/src/nemo_retriever/adapters/cli/sdk_workflow.py b/nemo_retriever/src/nemo_retriever/adapters/cli/sdk_workflow.py index 3bb3c78d1b..5c11461827 100644 --- a/nemo_retriever/src/nemo_retriever/adapters/cli/sdk_workflow.py +++ b/nemo_retriever/src/nemo_retriever/adapters/cli/sdk_workflow.py @@ -9,50 +9,22 @@ from nemo_retriever.ingestor import create_ingestor from nemo_retriever.ocr.config import OCRLang, OCRVersion -from nemo_retriever.params import ( - AudioChunkParams, - AudioVisualFuseParams, - BatchTuningParams, - EmbedParams, - ExtractParams, - HtmlChunkParams, - TextChunkParams, - VdbUploadParams, - VideoFrameParams, - VideoFrameTextDedupParams, -) +from nemo_retriever.params import BatchTuningParams, EmbedParams, ExtractParams, VdbUploadParams from nemo_retriever.params.utils import normalize_embed_kwargs from nemo_retriever.retriever import Retriever from nemo_retriever.utils.input_files import ( AUTO_INPUT_EXTENSIONS, - INPUT_TYPE_EXTENSIONS, expand_input_file_patterns, - input_type_for_path, resolve_input_files, ) from nemo_retriever.utils.remote_auth import resolve_remote_api_key from nemo_retriever.vdb.records import RetrievalHit -IngestInputTypeValue = Literal["auto", "pdf", "doc", "txt", "html", "image", "audio", "video"] IngestRunModeValue = Literal["inprocess", "batch"] -LocalIngestEmbedBackendValue = Literal["vllm", "hf"] OcrLangValue = OCRLang OcrVersionValue = OCRVersion -TableOutputFormatValue = Literal["pseudo_markdown", "markdown"] _SUPPORTED_RUN_MODES: tuple[IngestRunModeValue, ...] = ("inprocess", "batch") -_SUPPORTED_INPUT_TYPES: tuple[IngestInputTypeValue, ...] = ( - "auto", - "pdf", - "doc", - "txt", - "html", - "image", - "audio", - "video", -) -_AUDIO_SPLIT_INTERVAL = 500000 -_VIDEO_FRAME_FPS = 0.5 def _validate_run_mode(run_mode: str) -> IngestRunModeValue: @@ -61,145 +33,43 @@ def _validate_run_mode(run_mode: str) -> IngestRunModeValue: return cast(IngestRunModeValue, run_mode) -def _validate_input_type(input_type: str) -> IngestInputTypeValue: - if input_type not in _SUPPORTED_INPUT_TYPES: - raise ValueError(f"input_type must be one of {', '.join(_SUPPORTED_INPUT_TYPES)}, got {input_type!r}.") - return cast(IngestInputTypeValue, input_type) - - -def _input_type_for_extension(path: str) -> IngestInputTypeValue | None: - return cast(IngestInputTypeValue | None, input_type_for_path(path)) - - -def _validate_ingest_document_types( - documents: Sequence[str], - *, - input_type: IngestInputTypeValue, -) -> None: - allowed = AUTO_INPUT_EXTENSIONS if input_type == "auto" else INPUT_TYPE_EXTENSIONS[input_type] +# The ingest command accepts bare dataset directories; expand those to supported +# files before passing file/glob inputs through the shared input normalizer. +def _validate_ingest_document_types(documents: Sequence[str]) -> None: unsupported = [ document for document in documents - if not any(ch in str(document) for ch in "*?[") and Path(document).suffix.lower() not in allowed + if not any(ch in str(document) for ch in "*?[") and Path(document).suffix.lower() not in AUTO_INPUT_EXTENSIONS ] if unsupported: examples = ", ".join(unsupported[:3]) - if input_type == "auto": - raise ValueError(f"Unsupported input file type(s) for retriever ingest: {examples}") - raise ValueError(f"Input file type(s) do not match --input-type={input_type!r}: {examples}") + raise ValueError(f"Unsupported input file type(s) for retriever ingest: {examples}") -# The ingest command accepts bare dataset directories; expand those to supported -# files before passing file/glob inputs through the shared input normalizer. -def _expand_ingest_documents( - documents: Sequence[str], - *, - input_type: IngestInputTypeValue, -) -> list[str]: +def _expand_ingest_documents(documents: Sequence[str]) -> list[str]: inputs: list[str] = [] for document in documents: raw_document = str(document) path = Path(raw_document).expanduser() if path.is_dir(): - directory_files = resolve_input_files(path, input_type) + directory_files = resolve_input_files(path, "auto") if not directory_files: - if input_type == "auto": - raise FileNotFoundError(f"No supported ingest files found under directory: {path}") - raise FileNotFoundError(f"No {input_type} files found under directory: {path}") + raise FileNotFoundError(f"No supported ingest files found under directory: {path}") inputs.extend(str(file) for file in directory_files) else: inputs.append(raw_document) document_list = expand_input_file_patterns(inputs) - _validate_ingest_document_types(document_list, input_type=input_type) + _validate_ingest_document_types(document_list) return document_list -def _resolve_effective_input_type( - documents: Sequence[str], - *, - input_type: IngestInputTypeValue, -) -> IngestInputTypeValue: - if input_type != "auto": - return "pdf" if input_type == "doc" else input_type - - observed = { - resolved - for document in documents - if not any(ch in str(document) for ch in "*?[") - if (resolved := _input_type_for_extension(str(document))) is not None - } - if not observed: - return "auto" - if observed <= {"pdf", "doc"}: - return "pdf" - if len(observed) == 1: - only = next(iter(observed)) - return "pdf" if only == "doc" else only - return "auto" - - -def _default_asr_params() -> Any: - from nemo_retriever.audio import asr_params_from_env - - return asr_params_from_env() - - -def _attach_extract_stage( - ingestor: Any, - *, - input_type: IngestInputTypeValue, - extract_params: ExtractParams | None, -) -> Any: - if input_type == "pdf": - params = extract_params or ExtractParams() - return ingestor.extract(params, extraction_mode="pdf") - if input_type == "txt": - return ingestor.extract_txt(TextChunkParams()) - if input_type == "html": - return ingestor.extract_html(HtmlChunkParams()) - if input_type == "image": - return ingestor.extract_image_files(extract_params or ExtractParams()) - if input_type == "audio": - asr_params = _default_asr_params().model_copy(update={"segment_audio": False}) - return ingestor.extract_audio( - params=AudioChunkParams(split_type="size", split_interval=_AUDIO_SPLIT_INTERVAL), - asr_params=asr_params, - ) - if input_type == "video": - asr_params = _default_asr_params().model_copy(update={"segment_audio": False}) - return ingestor.extract_video( - params=AudioChunkParams( - enabled=True, - split_type="size", - split_interval=_AUDIO_SPLIT_INTERVAL, - ), - asr_params=asr_params, - video_frame_params=VideoFrameParams( - enabled=True, - fps=_VIDEO_FRAME_FPS, - dedup=True, - ), - video_text_dedup_params=VideoFrameTextDedupParams(enabled=True, max_dropped_frames=2), - av_fuse_params=AudioVisualFuseParams(enabled=True), - extract_params=extract_params or ExtractParams(), - ) - return ingestor.extract( - extract_params or ExtractParams(), - extraction_mode="auto", - text_params=TextChunkParams(), - html_params=HtmlChunkParams(), - ) - - def _build_embed_kwargs( embed_invoke_url: str | None, embed_model_name: str | None, - local_ingest_embed_backend: LocalIngestEmbedBackendValue | None = None, embed_workers: int | None = None, embed_batch_size: int | None = None, embed_cpus_per_actor: float | None = None, - embed_gpus_per_actor: float | None = None, ) -> dict[str, Any]: embed_kwargs: dict[str, Any] = {} if embed_invoke_url is not None: @@ -208,13 +78,10 @@ def _build_embed_kwargs( # Remote HTTP embedding reads model_name; local/GPU paths read embed_model_name. embed_kwargs["model_name"] = embed_model_name embed_kwargs["embed_model_name"] = embed_model_name - if local_ingest_embed_backend is not None: - embed_kwargs["local_ingest_embed_backend"] = local_ingest_embed_backend embed_tuning = _build_embed_batch_tuning( embed_workers=embed_workers, embed_batch_size=embed_batch_size, embed_cpus_per_actor=embed_cpus_per_actor, - embed_gpus_per_actor=embed_gpus_per_actor, ) if embed_tuning is not None: embed_kwargs["batch_tuning"] = embed_tuning @@ -229,15 +96,9 @@ def _build_extract_batch_tuning( page_elements_workers: int | None, page_elements_batch_size: int | None, page_elements_cpus_per_actor: float | None, - page_elements_gpus_per_actor: float | None, ocr_workers: int | None, ocr_batch_size: int | None, ocr_cpus_per_actor: float | None, - ocr_gpus_per_actor: float | None, - table_structure_workers: int | None, - table_structure_batch_size: int | None, - table_structure_cpus_per_actor: float | None, - table_structure_gpus_per_actor: float | None, ) -> BatchTuningParams | None: tuning_kwargs = { key: value @@ -249,15 +110,9 @@ def _build_extract_batch_tuning( "page_elements_workers": page_elements_workers, "page_elements_batch_size": page_elements_batch_size, "page_elements_cpus_per_actor": page_elements_cpus_per_actor, - "gpu_page_elements": page_elements_gpus_per_actor, "ocr_workers": ocr_workers, "ocr_inference_batch_size": ocr_batch_size, "ocr_cpus_per_actor": ocr_cpus_per_actor, - "gpu_ocr": ocr_gpus_per_actor, - "table_structure_workers": table_structure_workers, - "table_structure_batch_size": table_structure_batch_size, - "table_structure_cpus_per_actor": table_structure_cpus_per_actor, - "gpu_table_structure": table_structure_gpus_per_actor, }.items() if value is not None } @@ -269,7 +124,6 @@ def _build_embed_batch_tuning( embed_workers: int | None, embed_batch_size: int | None, embed_cpus_per_actor: float | None, - embed_gpus_per_actor: float | None, ) -> BatchTuningParams | None: tuning_kwargs = { key: value @@ -277,7 +131,6 @@ def _build_embed_batch_tuning( "embed_workers": embed_workers, "embed_batch_size": embed_batch_size, "embed_cpus_per_actor": embed_cpus_per_actor, - "gpu_embed": embed_gpus_per_actor, }.items() if value is not None } @@ -311,7 +164,7 @@ def _build_rerank_kwargs( rerank_kwargs["api_key"] = api_key return rerank_kwargs - # Local GPU reranker - VL by default to pair with the local VL embedder. + # Local GPU reranker — VL by default to pair with the local VL embedder. # ``NemotronRerankGPUActor`` loads the model once per actor; the rerank # model is ~2 GB and coexists with the vLLM embedder (which respects # ``gpu_memory_utilization=0.45``). @@ -324,7 +177,6 @@ def _build_rerank_kwargs( def ingest_documents( documents: Sequence[str], *, - input_type: IngestInputTypeValue = "auto", run_mode: IngestRunModeValue = "inprocess", ray_address: str | None = None, ray_log_to_driver: bool | None = None, @@ -337,32 +189,27 @@ def ingest_documents( ocr_lang: OcrLangValue | None = None, graphic_elements_invoke_url: str | None = None, table_structure_invoke_url: str | None = None, - table_output_format: TableOutputFormatValue | None = None, embed_invoke_url: str | None = None, embed_model_name: str | None = None, - local_ingest_embed_backend: LocalIngestEmbedBackendValue | None = None, pdf_extract_workers: int | None = None, pdf_extract_batch_size: int | None = None, pdf_extract_cpus_per_task: float | None = None, page_elements_workers: int | None = None, page_elements_batch_size: int | None = None, page_elements_cpus_per_actor: float | None = None, - page_elements_gpus_per_actor: float | None = None, ocr_workers: int | None = None, ocr_batch_size: int | None = None, ocr_cpus_per_actor: float | None = None, - ocr_gpus_per_actor: float | None = None, - table_structure_workers: int | None = None, - table_structure_batch_size: int | None = None, - table_structure_cpus_per_actor: float | None = None, - table_structure_gpus_per_actor: float | None = None, embed_workers: int | None = None, embed_batch_size: int | None = None, embed_cpus_per_actor: float | None = None, - embed_gpus_per_actor: float | None = None, ) -> dict[str, Any]: """Run the root CLI ingestion path through the SDK adapter. + Input families are inferred from concrete file extensions and routed by + the graph ingestor manifest planner; the root CLI intentionally has no + user-facing input-type selector. + ``ray_address`` and ``ray_log_to_driver`` are forwarded only when the caller sets them, preserving the default ``create_ingestor`` behavior. Batch tuning arguments are opt-in and are translated into @@ -370,9 +217,7 @@ def ingest_documents( ``run_mode="batch"`` and ignored by callers that leave them unset. """ validated_run_mode = _validate_run_mode(run_mode) - validated_input_type = _validate_input_type(input_type) - document_list = _expand_ingest_documents(documents, input_type=validated_input_type) - effective_input_type = _resolve_effective_input_type(document_list, input_type=validated_input_type) + document_list = _expand_ingest_documents(documents) extract_kwargs = { key: value for key, value in { @@ -382,12 +227,9 @@ def ingest_documents( "ocr_lang": ocr_lang, "graphic_elements_invoke_url": graphic_elements_invoke_url, "table_structure_invoke_url": table_structure_invoke_url, - "table_output_format": table_output_format, }.items() if value is not None } - if table_output_format == "markdown": - extract_kwargs["use_table_structure"] = True extract_tuning = _build_extract_batch_tuning( pdf_extract_workers=pdf_extract_workers, pdf_extract_batch_size=pdf_extract_batch_size, @@ -395,26 +237,18 @@ def ingest_documents( page_elements_workers=page_elements_workers, page_elements_batch_size=page_elements_batch_size, page_elements_cpus_per_actor=page_elements_cpus_per_actor, - page_elements_gpus_per_actor=page_elements_gpus_per_actor, ocr_workers=ocr_workers, ocr_batch_size=ocr_batch_size, ocr_cpus_per_actor=ocr_cpus_per_actor, - ocr_gpus_per_actor=ocr_gpus_per_actor, - table_structure_workers=table_structure_workers, - table_structure_batch_size=table_structure_batch_size, - table_structure_cpus_per_actor=table_structure_cpus_per_actor, - table_structure_gpus_per_actor=table_structure_gpus_per_actor, ) if extract_tuning is not None: extract_kwargs["batch_tuning"] = extract_tuning embed_kwargs = _build_embed_kwargs( embed_invoke_url, embed_model_name, - local_ingest_embed_backend=local_ingest_embed_backend, embed_workers=embed_workers, embed_batch_size=embed_batch_size, embed_cpus_per_actor=embed_cpus_per_actor, - embed_gpus_per_actor=embed_gpus_per_actor, ) extract_params = ExtractParams(**extract_kwargs) if extract_kwargs else None embed_params = EmbedParams(**embed_kwargs) if embed_kwargs else None @@ -429,11 +263,7 @@ def ingest_documents( create_kwargs["ray_log_to_driver"] = ray_log_to_driver ingestor = create_ingestor(**create_kwargs).files(document_list) - ingestor = _attach_extract_stage( - ingestor, - input_type=effective_input_type, - extract_params=extract_params, - ) + ingestor = ingestor.extract(extract_params or ExtractParams()) ingestor = ingestor.embed(embed_params) if embed_params is not None else ingestor.embed() result = ingestor.vdb_upload(vdb_params).ingest() return { diff --git a/nemo_retriever/src/nemo_retriever/branch_extraction.py b/nemo_retriever/src/nemo_retriever/branch_extraction.py new file mode 100644 index 0000000000..4cf8506b9d --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/branch_extraction.py @@ -0,0 +1,294 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Branch execution for manifest-planned retriever ingest extraction.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from io import BytesIO +from typing import Any, Callable + +from nemo_retriever.graph import InprocessExecutor, RayDataExecutor +from nemo_retriever.graph.ingestor_runtime import batch_tuning_to_node_overrides, build_graph, build_post_extract_graph +from nemo_retriever.ingest_manifest import ( + ExtractionBranchPlan, + ResolvedExtractionInputs, + format_branch_summary, + resolve_branch_extraction_inputs, +) + + +logger = logging.getLogger(__name__) + + +def ensure_pandas_columns(batch_df: Any, *, columns: tuple[str, ...]) -> Any: + """Pad a pandas batch to a stable schema before unioning branch outputs.""" + + for column in columns: + if column not in batch_df.columns: + batch_df[column] = None + return batch_df.loc[:, list(columns)] + + +@dataclass +class ExtractionBranchExecutor: + """Run manifest extraction branches and common post-extraction stages.""" + + run_mode: str + branches: tuple[ExtractionBranchPlan, ...] + documents: list[str] + buffers: list[tuple[str, BytesIO]] + split_config: dict[str, Any] + extract_params: Any | None + text_params: Any | None + html_params: Any | None + audio_chunk_params: Any | None + asr_params: Any | None + video_frame_params: Any | None + video_text_dedup_params: Any | None + av_fuse_params: Any | None + embed_params: Any | None + caption_params: Any | None + dedup_params: Any | None + store_params: Any | None + vdb_upload_params: Any | None + webhook_params: Any | None + post_extract_order: tuple[str, ...] + ray_address: str | None + batch_size: int + num_cpus: float + num_gpus: float + node_overrides: dict[str, dict[str, Any]] + show_progress: bool + allow_no_gpu: bool + ensure_batch_runtime: Callable[[], tuple[Any, Any]] + + def execute(self) -> Any: + logger.info( + "Retriever ingest manifest planned %d extraction branches: %s", + len(self.branches), + format_branch_summary(self.branches), + ) + if self.run_mode == "batch": + return self._execute_batch() + return self._execute_inprocess() + + def _execute_batch(self) -> Any: + _ray, cluster_resources = self.ensure_batch_runtime() + effective_allow_no_gpu = self.allow_no_gpu or cluster_resources.available_gpu_count() == 0 + branch_datasets: list[Any] = [] + for branch in self.branches: + effective_extraction = self._resolve_branch(branch) + logger.info( + "Retriever ingest extraction branch family=%s files=%d graph_mode=%s", + branch.family, + len(branch.input_paths), + effective_extraction.extraction_mode, + ) + graph = self._build_extraction_only_graph(effective_extraction) + derived_overrides = batch_tuning_to_node_overrides( + effective_extraction.extract_params, + None, + store_params=None, + cluster_resources=cluster_resources, + allow_no_gpu=effective_allow_no_gpu, + caption_params=None, + video_frame_params=effective_extraction.video_frame_params, + ) + executor = self._ray_executor(graph, derived_overrides) + branch_datasets.append(executor.build_dataset(list(branch.input_paths))) + + normalized = normalize_ray_branch_datasets(branch_datasets) + combined = normalized[0] + for branch_ds in normalized[1:]: + combined = combined.union(branch_ds) + + logger.info("Retriever ingest post-extraction stages: %s", format_post_stage_summary(self.post_extract_order)) + post_graph = build_post_extract_graph( + dedup_params=self.dedup_params, + embed_params=self.embed_params, + caption_params=self.caption_params, + store_params=self.store_params, + vdb_upload_params=self.vdb_upload_params, + webhook_params=self.webhook_params, + stage_order=self.post_extract_order, + ) + post_overrides = batch_tuning_to_node_overrides( + None, + self.embed_params, + store_params=self.store_params, + cluster_resources=cluster_resources, + allow_no_gpu=effective_allow_no_gpu, + caption_params=self.caption_params, + video_frame_params=None, + ) + return self._ray_executor(post_graph, post_overrides).ingest(combined) + + def _execute_inprocess(self) -> Any: + frames = [] + for branch in self.branches: + effective_extraction = self._resolve_branch(branch) + logger.info( + "Retriever ingest extraction branch family=%s files=%d graph_mode=%s", + branch.family, + len(branch.input_paths), + effective_extraction.extraction_mode, + ) + graph = self._build_extraction_only_graph(effective_extraction) + executor = InprocessExecutor(graph, show_progress=self.show_progress) + frames.append(executor.ingest(self._inprocess_branch_input(branch))) + + combined = concat_dataframes(frames) + logger.info("Retriever ingest post-extraction stages: %s", format_post_stage_summary(self.post_extract_order)) + post_graph = build_post_extract_graph( + dedup_params=self.dedup_params, + embed_params=self.embed_params, + caption_params=self.caption_params, + store_params=self.store_params, + vdb_upload_params=self.vdb_upload_params, + webhook_params=self.webhook_params, + stage_order=self.post_extract_order, + ) + return InprocessExecutor(post_graph, show_progress=self.show_progress).ingest(combined) + + def _resolve_branch(self, branch: ExtractionBranchPlan) -> ResolvedExtractionInputs: + return resolve_branch_extraction_inputs( + branch, + extract_params=self.extract_params, + text_params=self.text_params, + html_params=self.html_params, + audio_chunk_params=self.audio_chunk_params, + asr_params=self.asr_params, + video_frame_params=self.video_frame_params, + video_text_dedup_params=self.video_text_dedup_params, + av_fuse_params=self.av_fuse_params, + ) + + def _build_extraction_only_graph(self, effective_extraction: ResolvedExtractionInputs) -> Any: + return build_graph( + extraction_mode=effective_extraction.extraction_mode, + extract_params=effective_extraction.extract_params, + text_params=effective_extraction.text_params, + html_params=effective_extraction.html_params, + audio_chunk_params=effective_extraction.audio_chunk_params, + asr_params=effective_extraction.asr_params, + video_frame_params=effective_extraction.video_frame_params, + video_text_dedup_params=effective_extraction.video_text_dedup_params, + av_fuse_params=effective_extraction.av_fuse_params, + split_config=self.split_config, + stage_order=(), + ) + + def _ray_executor(self, graph: Any, derived_overrides: dict[str, dict[str, Any]]) -> RayDataExecutor: + return RayDataExecutor( + graph, + ray_address=self.ray_address, + batch_size=self.batch_size, + num_cpus=self.num_cpus, + num_gpus=self.num_gpus, + node_overrides=merge_node_overrides(derived_overrides, self.node_overrides), + ) + + def _inprocess_branch_input(self, branch: ExtractionBranchPlan) -> Any: + if not self.buffers: + return list(branch.input_paths) + + import pandas as pd + + buffer_by_name = {name: buf for name, buf in self.buffers} + file_paths: list[str] = [] + buffer_rows: list[dict[str, Any]] = [] + for path in branch.input_paths: + if path in buffer_by_name: + buffer_rows.append({"bytes": buffer_by_name[path].getvalue(), "path": path}) + else: + file_paths.append(path) + + frames = [] + if file_paths: + frames.append(InprocessExecutor._load_files(file_paths)) + if buffer_rows: + frames.append(pd.DataFrame(buffer_rows)) + return concat_dataframes(frames) + + +def merge_node_overrides( + derived_overrides: dict[str, dict[str, Any]], + explicit_overrides: dict[str, dict[str, Any]], +) -> dict[str, dict[str, Any]]: + merged_overrides: dict[str, dict[str, Any]] = {} + for node_name in set(derived_overrides) | set(explicit_overrides): + merged_overrides[node_name] = { + **derived_overrides.get(node_name, {}), + **explicit_overrides.get(node_name, {}), + } + return merged_overrides + + +def concat_dataframes(frames: list[Any]) -> Any: + import pandas as pd + + if not frames: + return pd.DataFrame(columns=["bytes", "path"]) + columns: list[str] = [] + seen: set[str] = set() + for frame in frames: + for column in frame.columns: + if column not in seen: + columns.append(column) + seen.add(column) + normalized = [frame.reindex(columns=columns) for frame in frames] + return pd.concat(normalized, ignore_index=True, sort=False) + + +def normalize_ray_branch_datasets(branch_datasets: list[Any]) -> list[Any]: + columns: list[str] = [] + seen: set[str] = set() + for dataset in branch_datasets: + dataset_columns = ray_dataset_columns(dataset) + if not dataset_columns: + # Avoid eager schema discovery: Ray computes missing schemas by + # executing a limit=1 plan, which pre-runs extraction branches. + return branch_datasets + for column in dataset_columns: + if column not in seen: + columns.append(column) + seen.add(column) + if not columns: + return branch_datasets + stable_columns = tuple(columns) + return [ + dataset.map_batches( + ensure_pandas_columns, + batch_format="pandas", + fn_kwargs={"columns": stable_columns}, + ) + for dataset in branch_datasets + ] + + +def ray_dataset_columns(dataset: Any) -> tuple[str, ...]: + try: + schema = dataset.schema(fetch_if_missing=False) + except TypeError: + schema = dataset.schema() + if schema is None: + return () + names = getattr(schema, "names", None) + if callable(names): + names = names() + if names is None: + base_schema = getattr(schema, "base_schema", None) + names = getattr(base_schema, "names", None) if base_schema is not None else None + if callable(names): + names = names() + if names is None: + return () + return tuple(str(name) for name in names) + + +def format_post_stage_summary(post_extract_order: tuple[str, ...]) -> str: + return ", ".join(post_extract_order) if post_extract_order else "none" diff --git a/nemo_retriever/src/nemo_retriever/graph/executor.py b/nemo_retriever/src/nemo_retriever/graph/executor.py index 14a323ab08..0ce9c27cf6 100644 --- a/nemo_retriever/src/nemo_retriever/graph/executor.py +++ b/nemo_retriever/src/nemo_retriever/graph/executor.py @@ -213,7 +213,12 @@ def _linearize(graph: Graph) -> List[Node]: return ordered def ingest(self, data: Any, **kwargs: Any) -> Any: - """Build and execute a Ray Data pipeline from the graph. + """Build, execute, and materialize a Ray Data pipeline from the graph.""" + + return self.build_dataset(data, **kwargs).to_pandas() + + def build_dataset(self, data: Any, **kwargs: Any) -> Any: + """Build a lazy Ray Data pipeline from the graph. Parameters ---------- @@ -224,7 +229,7 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: Returns ------- ray.data.Dataset - The materialized result dataset. + The lazy Ray dataset with all graph stages appended. """ import ray import ray.data as rd @@ -380,4 +385,4 @@ def ingest(self, data: Any, **kwargs: Any) -> Any: **overrides, ) - return ds.to_pandas() + return ds diff --git a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py index e95482ad4d..1c0359a7be 100644 --- a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py +++ b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py @@ -243,31 +243,23 @@ def _force_cpu_only(node_name: str) -> None: # --- Table Structure --- table_structure_invoke_url = _positive(getattr(extract_params, "table_structure_invoke_url", None)) - ts_bs = _positive( - getattr(extract_tuning, "table_structure_batch_size", None) if extract_tuning is not None else None - ) or (plan.table_structure_batch_size if plan else None) + ts_bs = plan.table_structure_batch_size if plan else None _set(TableStructureActor.__name__, "batch_size", ts_bs) if ts_bs: overrides.setdefault(TableStructureActor.__name__, {})["target_num_rows_per_block"] = ts_bs ts_concurrency: int = 0 - ts_concurrency = _resolve( - getattr(extract_tuning, "table_structure_workers", None) if extract_tuning is not None else None, - plan.table_structure_initial_actors if plan else None, - ) or (2 if table_structure_invoke_url else 0) + if table_structure_invoke_url: + ts_concurrency = (plan.table_structure_initial_actors if plan else None) or 2 + else: + ts_concurrency = (plan.table_structure_initial_actors if plan else None) or 0 _set(TableStructureActor.__name__, "concurrency", ts_concurrency or None) - ts_cpus = ( - _resolve( - getattr(extract_tuning, "table_structure_cpus_per_actor", None) if extract_tuning is not None else None, - ) - or 1.0 - ) - _set(TableStructureActor.__name__, "num_cpus", ts_cpus) + _set(TableStructureActor.__name__, "num_cpus", 1) if effective_allow_no_gpu: _force_cpu_only(TableStructureActor.__name__) elif not table_structure_invoke_url: - _set_gpu( + _set( TableStructureActor.__name__, - getattr(extract_tuning, "gpu_table_structure", None) if extract_tuning is not None else None, + "num_gpus", plan.table_structure_gpus_per_actor if plan else None, ) @@ -340,7 +332,7 @@ def _force_cpu_only(node_name: str) -> None: + page_elements_concurrency * page_elements_cpus + ocr_concurrency * ocr_cpus + embed_concurrency * embed_cpus - + ts_concurrency * ts_cpus + + ts_concurrency * 1 + ge_concurrency * 1 ) pdf_extract_tasks = min( @@ -471,7 +463,6 @@ def _maybe_append_chunk_actor(graph: Graph, split_config: dict[str, Any], key: s def _append_ordered_transform_stages( graph: Graph, *, - extraction_mode: str, dedup_params: Any | None, caption_params: Any | None, store_params: Any | None, @@ -480,7 +471,7 @@ def _append_ordered_transform_stages( webhook_params: Any | None = None, stage_order: tuple[str, ...], supports_dedup: bool, - reshape_for_modal_content: bool, + reshape_content_before_embed: bool, ) -> Graph: """Append post-extraction transform stages in the exact recorded plan order.""" @@ -508,8 +499,7 @@ def _append_ordered_transform_stages( elif stage_name == "caption" and caption_params is not None: graph = graph >> CaptionActor(caption_params) elif stage_name == "embed" and embed_params is not None: - needs_content_reshape = reshape_for_modal_content and extraction_mode in {"pdf", "image", "auto"} - if needs_content_reshape: + if reshape_content_before_embed: content_columns = (_CONTENT_COLUMNS + ("images",)) if caption_params is not None else _CONTENT_COLUMNS if embed_params.embed_granularity == "page": graph = graph >> UDFOperator( @@ -546,6 +536,32 @@ def _append_ordered_transform_stages( return graph +def build_post_extract_graph( + *, + dedup_params: Any | None = None, + embed_params: Any | None = None, + caption_params: Any | None = None, + store_params: Any | None = None, + vdb_upload_params: VdbUploadParams | None = None, + webhook_params: Any | None = None, + stage_order: tuple[str, ...] = (), +) -> Graph: + """Build only the common stages that run after extraction branch union.""" + + return _append_ordered_transform_stages( + Graph(), + dedup_params=dedup_params, + caption_params=caption_params, + store_params=store_params, + embed_params=embed_params, + vdb_upload_params=vdb_upload_params, + webhook_params=webhook_params, + stage_order=stage_order, + supports_dedup=True, + reshape_content_before_embed=True, + ) + + def build_graph( *, execution_plan: IngestExecutionPlan | None = None, @@ -672,7 +688,6 @@ def build_graph( asr_params=asr_params, caption_params=caption_params, video_frame_params=video_frame_params, - video_text_dedup_params=video_text_dedup_params, av_fuse_params=av_fuse_params, split_config=split_config, ) @@ -814,7 +829,6 @@ def build_graph( return _append_ordered_transform_stages( graph, - extraction_mode=extraction_mode, dedup_params=dedup_params, caption_params=caption_params, store_params=store_params, @@ -823,7 +837,7 @@ def build_graph( webhook_params=webhook_params, stage_order=stage_order, supports_dedup=True, - reshape_for_modal_content=True, + reshape_content_before_embed=extraction_mode in {"pdf", "image", "auto"}, ) diff --git a/nemo_retriever/src/nemo_retriever/graph/multi_type_extract_operator.py b/nemo_retriever/src/nemo_retriever/graph/multi_type_extract_operator.py index f58d019cf9..591113e616 100644 --- a/nemo_retriever/src/nemo_retriever/graph/multi_type_extract_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/multi_type_extract_operator.py @@ -15,11 +15,11 @@ from nemo_retriever.audio import ASRActor from nemo_retriever.audio import MediaChunkActor -from nemo_retriever.audio import asr_params_from_env from nemo_retriever.chart.chart_detection import GraphicElementsActor from nemo_retriever.graph.abstract_operator import AbstractOperator from nemo_retriever.html.ray_data import HtmlSplitActor from nemo_retriever.image.ray_data import ImageLoadActor +from nemo_retriever.image.load import SUPPORTED_IMAGE_EXTENSIONS from nemo_retriever.graph.cpu_operator import CPUOperator from nemo_retriever.graph.gpu_operator import GPUOperator from nemo_retriever.graph.operator_archetype import ArchetypeOperator @@ -50,20 +50,17 @@ from nemo_retriever.video import dedup_video_frames from nemo_retriever.video import video_asr_audio_chunk_params from nemo_retriever.graph.designer import designer_component -from nemo_retriever.utils.input_files import INPUT_TYPE_EXTENSIONS from nemo_retriever.utils.ray_resource_hueristics import gather_local_resources logger = logging.getLogger(__name__) # Define file type mappings -PDF_EXTENSIONS = INPUT_TYPE_EXTENSIONS["pdf"] | INPUT_TYPE_EXTENSIONS["doc"] -TEXT_EXTENSIONS = INPUT_TYPE_EXTENSIONS["txt"] -HTML_EXTENSIONS = INPUT_TYPE_EXTENSIONS["html"] -AUDIO_EXTENSIONS = INPUT_TYPE_EXTENSIONS["audio"] -IMAGE_EXTENSIONS = INPUT_TYPE_EXTENSIONS["image"] -VIDEO_EXTENSIONS = INPUT_TYPE_EXTENSIONS["video"] -DEFAULT_AUDIO_SPLIT_INTERVAL = 500000 -DEFAULT_VIDEO_FRAME_FPS = 0.5 +PDF_EXTENSIONS = {".pdf", ".docx", ".pptx"} +TEXT_EXTENSIONS = {".txt"} +HTML_EXTENSIONS = {".html"} +AUDIO_EXTENSIONS = {".mp3", ".wav"} +IMAGE_EXTENSIONS = SUPPORTED_IMAGE_EXTENSIONS +VIDEO_EXTENSIONS = {".mp4", ".mov", ".mkv"} def _unsupported_extension_message(ext: str) -> str: @@ -78,18 +75,6 @@ def _has_endpoint(*values: Any) -> bool: return any(bool(str(value or "").strip()) for value in values) -def _default_asr_params() -> ASRParams: - return asr_params_from_env().model_copy(update={"segment_audio": False}) - - -def _default_audio_chunk_params() -> AudioChunkParams: - return AudioChunkParams(split_type="size", split_interval=DEFAULT_AUDIO_SPLIT_INTERVAL) - - -def _default_video_frame_params() -> VideoFrameParams: - return VideoFrameParams(enabled=True, fps=DEFAULT_VIDEO_FRAME_FPS, dedup=True) - - def _parse_mode_enabled(extract_params: ExtractParams) -> bool: tuning = getattr(extract_params, "batch_tuning", None) return extract_params.method == "nemotron_parse" or ( @@ -167,14 +152,11 @@ def __init__( self.extract_params = extract_params or ExtractParams() self.text_params = text_params or TextChunkParams() self.html_params = html_params or HtmlChunkParams() - self.audio_chunk_params = audio_chunk_params or _default_audio_chunk_params() - self.asr_params = asr_params or _default_asr_params() + self.audio_chunk_params = audio_chunk_params or AudioChunkParams() + self.asr_params = asr_params or ASRParams() self.caption_params = caption_params - self.video_frame_params = video_frame_params or _default_video_frame_params() - self.video_text_dedup_params = video_text_dedup_params or VideoFrameTextDedupParams( - enabled=True, - max_dropped_frames=2, - ) + self.video_frame_params = video_frame_params or VideoFrameParams() + self.video_text_dedup_params = video_text_dedup_params or VideoFrameTextDedupParams() self.av_fuse_params = av_fuse_params or AudioVisualFuseParams() self._split_config: dict[str, Any] = split_config if split_config is not None else resolve_split_params(None) self._resolved_resources = None diff --git a/nemo_retriever/src/nemo_retriever/graph_ingestor.py b/nemo_retriever/src/nemo_retriever/graph_ingestor.py index a5bde71647..62934d89c8 100644 --- a/nemo_retriever/src/nemo_retriever/graph_ingestor.py +++ b/nemo_retriever/src/nemo_retriever/graph_ingestor.py @@ -26,14 +26,23 @@ from __future__ import annotations +import logging import os import sys -from dataclasses import dataclass from io import BytesIO from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union from nemo_retriever.graph import InprocessExecutor, RayDataExecutor +from nemo_retriever.branch_extraction import ExtractionBranchExecutor, merge_node_overrides from nemo_retriever.graph.ingestor_runtime import batch_tuning_to_node_overrides, build_graph +from nemo_retriever.ingest_manifest import ( + ExtractionBranchPlan, + ResolvedExtractionInputs, + build_input_manifest, + format_branch_summary, + plan_extraction_branches, + resolve_branch_extraction_inputs, +) from nemo_retriever.ingestor import ingestor from nemo_retriever.params import ( ASRParams, @@ -69,6 +78,7 @@ _DEFAULT_PAGE_ELEMENTS_COLUMN = "page_elements_v3" _DEFAULT_EMBED_COLUMN = "text_embeddings_1b_v2" _ERROR_MESSAGE_LIMIT = 256 +logger = logging.getLogger(__name__) _EXPLICIT_MODE_INPUT_TYPES: dict[str, frozenset[str]] = { "pdf": PDF_DOCUMENT_INPUT_TYPES, "image": frozenset({"image"}), @@ -79,19 +89,6 @@ } -@dataclass(frozen=True) -class _EffectiveExtractionInputs: - extraction_mode: str - extract_params: Any | None - text_params: Any | None - html_params: Any | None - audio_chunk_params: Any | None - asr_params: Any | None - video_frame_params: Any | None - video_text_dedup_params: Any | None - av_fuse_params: Any | None - - class GraphIngestionError(RuntimeError): """Raised when graph ingestion stages report structured row-level errors.""" @@ -258,6 +255,7 @@ def __init__( self._show_progress = show_progress self._error_policy = error_policy self._rd_dataset: Any = None + self._buffers: list[tuple[str, BytesIO]] = [] # Pipeline configuration accumulated by fluent methods self._extraction_mode: str | None = "pdf" @@ -507,18 +505,21 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: ``run_mode='inprocess'`` A ``pandas.DataFrame``. """ - effective_extraction = self._resolve_effective_extraction_inputs() + default_branches = self._plan_default_extraction_branches() + if default_branches is None: + single_effective = self._resolve_effective_extraction_inputs() + elif len(default_branches) == 1: + single_effective = self._resolve_branch_extraction_inputs(default_branches[0]) + else: + single_effective = None + # Auto-enable dedup before captioning so that images overlapping # with table/chart/infographic detections are removed first. # Skip for image-only extraction — the image IS the content. - if ( - self._caption_params is not None - and self._dedup_params is None - and effective_extraction.extraction_mode != "image" - ): + image_only = single_effective is not None and single_effective.extraction_mode == "image" + if self._caption_params is not None and self._dedup_params is None and not image_only: self._dedup_params = DedupParams() if "dedup" not in self._stage_order: - # Insert dedup right before caption in the stage order. try: idx = self._stage_order.index("caption") except ValueError: @@ -527,111 +528,171 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: post_extract_order = tuple(s for s in self._stage_order if s != "extract") - if self._run_mode == "batch": - import ray - - if self._ray_address or not ray.is_initialized(): - venv = os.path.dirname(os.path.dirname(sys.executable)) - venv_bin = os.path.join(venv, "bin") - pypath = os.pathsep.join(p for p in sys.path if p) - ray_env_vars: dict[str, str] = { - "VIRTUAL_ENV": venv, - "PATH": venv_bin + os.pathsep + os.environ.get("PATH", ""), - "PYTHONPATH": pypath, - } - ray_env_vars.update(collect_hf_runtime_env()) - ray_env_vars.update(collect_remote_auth_runtime_env()) - os.environ["HF_HUB_OFFLINE"] = ray_env_vars["HF_HUB_OFFLINE"] - runtime_env = {"env_vars": ray_env_vars} - ray.init( - address=self._ray_address, - ignore_reinit_error=True, - runtime_env=runtime_env, - log_to_driver=self._ray_log_to_driver, - ) - cluster_resources = gather_cluster_resources(ray) - - graph = build_graph( - extraction_mode=effective_extraction.extraction_mode, - extract_params=effective_extraction.extract_params, - text_params=effective_extraction.text_params, - html_params=effective_extraction.html_params, - audio_chunk_params=effective_extraction.audio_chunk_params, - asr_params=effective_extraction.asr_params, - video_frame_params=effective_extraction.video_frame_params, - video_text_dedup_params=effective_extraction.video_text_dedup_params, - av_fuse_params=effective_extraction.av_fuse_params, - embed_params=self._embed_params, - split_config=self._split_config, - caption_params=self._caption_params, - dedup_params=self._dedup_params, - store_params=self._store_params, - vdb_upload_params=self._vdb_upload_params, - webhook_params=self._webhook_params, - stage_order=post_extract_order, - ) - # Derive per-node Ray scheduling config from BatchTuningParams plus - # cluster-scaled heuristic defaults, then let any explicit - # node_overrides passed to __init__ take precedence. - effective_allow_no_gpu = self._allow_no_gpu or cluster_resources.available_gpu_count() == 0 - derived_overrides = batch_tuning_to_node_overrides( - effective_extraction.extract_params, - self._embed_params, - store_params=self._store_params, - cluster_resources=cluster_resources, - allow_no_gpu=effective_allow_no_gpu, - caption_params=self._caption_params, - video_frame_params=effective_extraction.video_frame_params, - ) - merged_overrides: Dict[str, Dict[str, Any]] = {} - for node_name in set(derived_overrides) | set(self._node_overrides): - merged_overrides[node_name] = { - **derived_overrides.get(node_name, {}), - **self._node_overrides.get(node_name, {}), - } - executor = RayDataExecutor( - graph, - ray_address=self._ray_address, - batch_size=self._batch_size, - num_cpus=self._num_cpus, - num_gpus=self._num_gpus, - node_overrides=merged_overrides, - ) - result = executor.ingest(self._documents) - self._rd_dataset = result + if default_branches is not None and len(default_branches) > 1: + result = self._execute_extraction_branches(default_branches, post_extract_order=post_extract_order) else: - graph = build_graph( - extraction_mode=effective_extraction.extraction_mode, - extract_params=effective_extraction.extract_params, - text_params=effective_extraction.text_params, - html_params=effective_extraction.html_params, - audio_chunk_params=effective_extraction.audio_chunk_params, - asr_params=effective_extraction.asr_params, - video_frame_params=effective_extraction.video_frame_params, - video_text_dedup_params=effective_extraction.video_text_dedup_params, - av_fuse_params=effective_extraction.av_fuse_params, - embed_params=self._embed_params, - split_config=self._split_config, - caption_params=self._caption_params, - dedup_params=self._dedup_params, - store_params=self._store_params, - vdb_upload_params=self._vdb_upload_params, - webhook_params=self._webhook_params, - stage_order=post_extract_order, - ) - executor = InprocessExecutor(graph, show_progress=self._show_progress) - self._rd_dataset = None - if self._buffers: - import pandas as pd - - df = pd.DataFrame([{"bytes": buf.read(), "path": name} for name, buf in self._buffers]) - result = executor.ingest(df) - else: - result = executor.ingest(self._documents) + if single_effective is None: + raise RuntimeError("Internal error: extraction inputs were not resolved.") + result = self._execute_single_graph(single_effective, post_extract_order=post_extract_order) self._raise_for_stage_errors(result) return result + def _execute_single_graph( + self, + effective_extraction: ResolvedExtractionInputs, + *, + post_extract_order: tuple[str, ...], + ) -> Any: + if self._run_mode == "batch": + return self._execute_single_graph_batch(effective_extraction, post_extract_order=post_extract_order) + return self._execute_single_graph_inprocess(effective_extraction, post_extract_order=post_extract_order) + + def _execute_single_graph_batch( + self, + effective_extraction: ResolvedExtractionInputs, + *, + post_extract_order: tuple[str, ...], + ) -> Any: + _ray, cluster_resources = self._ensure_batch_runtime() + graph = build_graph( + extraction_mode=effective_extraction.extraction_mode, + extract_params=effective_extraction.extract_params, + text_params=effective_extraction.text_params, + html_params=effective_extraction.html_params, + audio_chunk_params=effective_extraction.audio_chunk_params, + asr_params=effective_extraction.asr_params, + video_frame_params=effective_extraction.video_frame_params, + video_text_dedup_params=effective_extraction.video_text_dedup_params, + av_fuse_params=effective_extraction.av_fuse_params, + embed_params=self._embed_params, + split_config=self._split_config, + caption_params=self._caption_params, + dedup_params=self._dedup_params, + store_params=self._store_params, + vdb_upload_params=self._vdb_upload_params, + webhook_params=self._webhook_params, + stage_order=post_extract_order, + ) + effective_allow_no_gpu = self._allow_no_gpu or cluster_resources.available_gpu_count() == 0 + derived_overrides = batch_tuning_to_node_overrides( + effective_extraction.extract_params, + self._embed_params, + store_params=self._store_params, + cluster_resources=cluster_resources, + allow_no_gpu=effective_allow_no_gpu, + caption_params=self._caption_params, + video_frame_params=effective_extraction.video_frame_params, + ) + executor = RayDataExecutor( + graph, + ray_address=self._ray_address, + batch_size=self._batch_size, + num_cpus=self._num_cpus, + num_gpus=self._num_gpus, + node_overrides=merge_node_overrides(derived_overrides, self._node_overrides), + ) + result = executor.ingest(self._documents) + self._rd_dataset = result + return result + + def _execute_single_graph_inprocess( + self, + effective_extraction: ResolvedExtractionInputs, + *, + post_extract_order: tuple[str, ...], + ) -> Any: + graph = build_graph( + extraction_mode=effective_extraction.extraction_mode, + extract_params=effective_extraction.extract_params, + text_params=effective_extraction.text_params, + html_params=effective_extraction.html_params, + audio_chunk_params=effective_extraction.audio_chunk_params, + asr_params=effective_extraction.asr_params, + video_frame_params=effective_extraction.video_frame_params, + video_text_dedup_params=effective_extraction.video_text_dedup_params, + av_fuse_params=effective_extraction.av_fuse_params, + embed_params=self._embed_params, + split_config=self._split_config, + caption_params=self._caption_params, + dedup_params=self._dedup_params, + store_params=self._store_params, + vdb_upload_params=self._vdb_upload_params, + webhook_params=self._webhook_params, + stage_order=post_extract_order, + ) + executor = InprocessExecutor(graph, show_progress=self._show_progress) + self._rd_dataset = None + if self._buffers: + import pandas as pd + + df = pd.DataFrame([{"bytes": buf.getvalue(), "path": name} for name, buf in self._buffers]) + return executor.ingest(df) + return executor.ingest(self._documents) + + def _execute_extraction_branches( + self, + branches: tuple[ExtractionBranchPlan, ...], + *, + post_extract_order: tuple[str, ...], + ) -> Any: + result = ExtractionBranchExecutor( + run_mode=self._run_mode, + branches=branches, + documents=self._documents, + buffers=self._buffers, + split_config=self._split_config, + extract_params=self._extract_params, + text_params=self._text_params, + html_params=self._html_params, + audio_chunk_params=self._audio_chunk_params, + asr_params=self._asr_params, + video_frame_params=self._video_frame_params, + video_text_dedup_params=self._video_text_dedup_params, + av_fuse_params=self._av_fuse_params, + embed_params=self._embed_params, + caption_params=self._caption_params, + dedup_params=self._dedup_params, + store_params=self._store_params, + vdb_upload_params=self._vdb_upload_params, + webhook_params=self._webhook_params, + post_extract_order=post_extract_order, + ray_address=self._ray_address, + batch_size=self._batch_size, + num_cpus=self._num_cpus, + num_gpus=self._num_gpus, + node_overrides=self._node_overrides, + show_progress=self._show_progress, + allow_no_gpu=self._allow_no_gpu, + ensure_batch_runtime=self._ensure_batch_runtime, + ).execute() + self._rd_dataset = result if self._run_mode == "batch" else None + return result + + def _ensure_batch_runtime(self) -> tuple[Any, Any]: + import ray + + if self._ray_address or not ray.is_initialized(): + venv = os.path.dirname(os.path.dirname(sys.executable)) + venv_bin = os.path.join(venv, "bin") + pypath = os.pathsep.join(p for p in sys.path if p) + ray_env_vars: dict[str, str] = { + "VIRTUAL_ENV": venv, + "PATH": venv_bin + os.pathsep + os.environ.get("PATH", ""), + "PYTHONPATH": pypath, + } + ray_env_vars.update(collect_hf_runtime_env()) + ray_env_vars.update(collect_remote_auth_runtime_env()) + os.environ["HF_HUB_OFFLINE"] = ray_env_vars["HF_HUB_OFFLINE"] + runtime_env = {"env_vars": ray_env_vars} + ray.init( + address=self._ray_address, + ignore_reinit_error=True, + runtime_env=runtime_env, + log_to_driver=self._ray_log_to_driver, + ) + return ray, gather_cluster_resources(ray) + # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ @@ -672,75 +733,68 @@ def _validate_explicit_extraction_mode_inputs( examples = self._input_type_examples(mismatched) raise ValueError(f"Input file type(s) do not match extraction_mode={extraction_mode!r}: {examples}") - def _resolve_effective_extraction_inputs(self) -> _EffectiveExtractionInputs: - extraction_mode = self._extraction_mode - extract_params = self._extract_params - text_params = self._text_params - html_params = self._html_params - audio_chunk_params = self._audio_chunk_params - asr_params = self._asr_params - video_frame_params = self._video_frame_params - video_text_dedup_params = self._video_text_dedup_params - av_fuse_params = self._av_fuse_params + def _plan_default_extraction_branches(self) -> tuple[ExtractionBranchPlan, ...] | None: + if self._extraction_mode is not None: + return None + manifest = build_input_manifest(self._configured_input_paths()) + branches = plan_extraction_branches(manifest) + if self._debug: + logger.info( + "Retriever ingest manifest planned %d extraction branches: %s", + len(branches), + format_branch_summary(branches), + ) + return branches + + def _resolve_branch_extraction_inputs(self, branch: ExtractionBranchPlan) -> ResolvedExtractionInputs: + return resolve_branch_extraction_inputs( + branch, + extract_params=self._extract_params, + text_params=self._text_params, + html_params=self._html_params, + audio_chunk_params=self._audio_chunk_params, + asr_params=self._asr_params, + video_frame_params=self._video_frame_params, + video_text_dedup_params=self._video_text_dedup_params, + av_fuse_params=self._av_fuse_params, + ) + def _resolve_effective_extraction_inputs(self) -> ResolvedExtractionInputs: + extraction_mode = self._extraction_mode classified = self._classified_input_paths() if extraction_mode is not None: self._validate_explicit_extraction_mode_inputs(extraction_mode, classified) - return _EffectiveExtractionInputs( + return ResolvedExtractionInputs( extraction_mode=extraction_mode, - extract_params=extract_params, - text_params=text_params, - html_params=html_params, - audio_chunk_params=audio_chunk_params, - asr_params=asr_params, - video_frame_params=video_frame_params, - video_text_dedup_params=video_text_dedup_params, - av_fuse_params=av_fuse_params, + extract_params=self._extract_params, + text_params=self._text_params, + html_params=self._html_params, + audio_chunk_params=self._audio_chunk_params, + asr_params=self._asr_params, + video_frame_params=self._video_frame_params, + video_text_dedup_params=self._video_text_dedup_params, + av_fuse_params=self._av_fuse_params, ) - unsupported = [ - path for path, input_type in classified if input_type is None and not _is_explicit_glob_path(path) - ] - if unsupported: - examples = self._input_type_examples(unsupported) - raise ValueError(f"Unsupported input file type(s) for default GraphIngestor.extract(): {examples}") - - observed_input_types = {input_type for _, input_type in classified if input_type is not None} - if not observed_input_types or observed_input_types <= PDF_DOCUMENT_INPUT_TYPES: - extraction_mode = "pdf" - elif observed_input_types == {"image"}: - extraction_mode = "image" - elif observed_input_types == {"txt"}: - extraction_mode = "text" - text_params = text_params or TextChunkParams() - elif observed_input_types == {"html"}: - extraction_mode = "html" - html_params = html_params or HtmlChunkParams() - elif observed_input_types == {"audio"}: - extraction_mode = "audio" - audio_chunk_params = audio_chunk_params or AudioChunkParams() - asr_params = asr_params or ASRParams() - elif observed_input_types == {"video"}: - extraction_mode = "auto" - audio_chunk_params = audio_chunk_params or AudioChunkParams() - asr_params = asr_params or ASRParams() - video_frame_params = video_frame_params or VideoFrameParams() - video_text_dedup_params = video_text_dedup_params or VideoFrameTextDedupParams() - av_fuse_params = av_fuse_params or AudioVisualFuseParams() - extract_params = extract_params or ExtractParams() - else: - extraction_mode = "auto" - - return _EffectiveExtractionInputs( - extraction_mode=extraction_mode, - extract_params=extract_params, - text_params=text_params, - html_params=html_params, - audio_chunk_params=audio_chunk_params, - asr_params=asr_params, - video_frame_params=video_frame_params, - video_text_dedup_params=video_text_dedup_params, - av_fuse_params=av_fuse_params, + branches = self._plan_default_extraction_branches() + if branches is None: + raise RuntimeError("Internal error: default extraction planning did not return branches.") + if len(branches) == 1: + return self._resolve_branch_extraction_inputs(branches[0]) + + # Compatibility fallback for private callers that still ask for a + # scalar effective mode directly. The public ingest path executes the + # branches instead of using this MultiType fallback. + return ResolvedExtractionInputs( + extraction_mode="auto", + extract_params=self._extract_params or ExtractParams(), + text_params=self._text_params or TextChunkParams(), + html_params=self._html_params or HtmlChunkParams(), + audio_chunk_params=self._audio_chunk_params, + asr_params=self._asr_params, + video_frame_params=self._video_frame_params, + video_text_dedup_params=self._video_text_dedup_params, + av_fuse_params=self._av_fuse_params, ) @staticmethod diff --git a/nemo_retriever/src/nemo_retriever/ingest_manifest.py b/nemo_retriever/src/nemo_retriever/ingest_manifest.py new file mode 100644 index 0000000000..abb66328ad --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/ingest_manifest.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Manifest planning for input-aware retriever ingest extraction.""" + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Iterable + +from nemo_retriever.params import ( + ASRParams, + AudioChunkParams, + AudioVisualFuseParams, + ExtractParams, + HtmlChunkParams, + TextChunkParams, + VideoFrameParams, + VideoFrameTextDedupParams, +) +from nemo_retriever.utils.input_files import _is_explicit_glob_path, input_type_for_path + + +_AUDIO_SPLIT_INTERVAL = 500000 +_VIDEO_FRAME_FPS = 0.5 + + +@dataclass(frozen=True) +class ExtractionBranchSpec: + """Canonical policy for one manifest-planned extraction branch.""" + + family: str + input_types: tuple[str, ...] + extraction_mode: str + + +_BRANCH_SPECS: tuple[ExtractionBranchSpec, ...] = ( + ExtractionBranchSpec(family="pdf", input_types=("pdf", "doc"), extraction_mode="pdf"), + ExtractionBranchSpec(family="image", input_types=("image",), extraction_mode="image"), + ExtractionBranchSpec(family="txt", input_types=("txt",), extraction_mode="text"), + ExtractionBranchSpec(family="html", input_types=("html",), extraction_mode="html"), + ExtractionBranchSpec(family="audio", input_types=("audio",), extraction_mode="audio"), + # Video keeps extraction_mode="auto" because build_graph uses the presence + # of video params to construct the dedicated video extraction chain. + ExtractionBranchSpec(family="video", input_types=("video",), extraction_mode="auto"), +) +_BRANCH_SPECS_BY_FAMILY = {spec.family: spec for spec in _BRANCH_SPECS} +_BRANCH_SPECS_BY_INPUT_TYPE = {input_type: spec for spec in _BRANCH_SPECS for input_type in spec.input_types} + + +@dataclass(frozen=True) +class ManifestEntry: + """One concrete or optional ingest input in a manifest.""" + + path: str + input_type: str | None + is_explicit_glob: bool = False + + +@dataclass(frozen=True) +class InputManifest: + """Classified input files for planner-backed ingest.""" + + entries: tuple[ManifestEntry, ...] + unsupported_files: tuple[str, ...] + + @property + def files_by_family(self) -> dict[str, tuple[str, ...]]: + grouped: defaultdict[str, list[str]] = defaultdict(list) + for entry in self.entries: + if entry.input_type is None: + continue + grouped[_BRANCH_SPECS_BY_INPUT_TYPE[entry.input_type].family].append(entry.path) + return {family: tuple(paths) for family, paths in grouped.items()} + + @property + def optional_globs(self) -> tuple[str, ...]: + return tuple(entry.path for entry in self.entries if entry.is_explicit_glob) + + +@dataclass(frozen=True) +class ExtractionBranchPlan: + """A single typed extraction branch to execute before common stages.""" + + spec: ExtractionBranchSpec + input_paths: tuple[str, ...] + + @property + def family(self) -> str: + return self.spec.family + + @property + def extraction_mode(self) -> str: + return self.spec.extraction_mode + + +@dataclass(frozen=True) +class ResolvedExtractionInputs: + """Concrete graph-builder inputs for one extraction branch or explicit mode.""" + + extraction_mode: str + extract_params: Any | None + text_params: Any | None + html_params: Any | None + audio_chunk_params: Any | None + asr_params: Any | None + video_frame_params: Any | None + video_text_dedup_params: Any | None + av_fuse_params: Any | None + + +def build_input_manifest(input_paths: Iterable[str]) -> InputManifest: + """Classify concrete input paths without loading modality dependencies.""" + + entries: list[ManifestEntry] = [] + unsupported: list[str] = [] + for path in input_paths: + is_glob = _is_explicit_glob_path(path) + input_type = None if is_glob else input_type_for_path(path) + entries.append(ManifestEntry(path=path, input_type=input_type, is_explicit_glob=is_glob)) + if input_type is None and not is_glob: + unsupported.append(path) + return InputManifest(entries=tuple(entries), unsupported_files=tuple(unsupported)) + + +def plan_extraction_branches(manifest: InputManifest) -> tuple[ExtractionBranchPlan, ...]: + """Emit deterministic extraction branches for a validated manifest.""" + + if manifest.unsupported_files: + examples = ", ".join(manifest.unsupported_files[:3]) + raise ValueError(f"Unsupported input file type(s) for default GraphIngestor.extract(): {examples}") + + files_by_family = manifest.files_by_family + if not files_by_family: + # Empty optional globs should preserve the old empty-input behavior + # without inventing modality branches that require extra dependencies. + return ( + ExtractionBranchPlan( + spec=_BRANCH_SPECS_BY_FAMILY["pdf"], + input_paths=manifest.optional_globs, + ), + ) + + branches: list[ExtractionBranchPlan] = [] + for spec in _BRANCH_SPECS: + paths = files_by_family.get(spec.family) + if not paths: + continue + branches.append(ExtractionBranchPlan(spec=spec, input_paths=paths)) + return tuple(branches) + + +def format_branch_summary(branches: tuple[ExtractionBranchPlan, ...]) -> str: + return ", ".join(f"{branch.family}:{len(branch.input_paths)}" for branch in branches) + + +def resolve_branch_extraction_inputs( + branch: ExtractionBranchPlan, + *, + extract_params: Any | None, + text_params: Any | None, + html_params: Any | None, + audio_chunk_params: Any | None, + asr_params: Any | None, + video_frame_params: Any | None, + video_text_dedup_params: Any | None, + av_fuse_params: Any | None, +) -> ResolvedExtractionInputs: + """Apply the canonical branch defaults to graph-builder inputs.""" + + family = branch.family + if family in {"pdf", "image"}: + extract_params = extract_params or ExtractParams() + elif family == "txt": + text_params = text_params or TextChunkParams() + elif family == "html": + html_params = html_params or HtmlChunkParams() + elif family == "audio": + audio_chunk_params = audio_chunk_params or AudioChunkParams( + split_type="size", + split_interval=_AUDIO_SPLIT_INTERVAL, + ) + asr_params = asr_params or _default_asr_params() + elif family == "video": + extract_params = extract_params or ExtractParams() + audio_chunk_params = audio_chunk_params or AudioChunkParams( + enabled=True, + split_type="size", + split_interval=_AUDIO_SPLIT_INTERVAL, + ) + asr_params = asr_params or _default_asr_params() + video_frame_params = video_frame_params or VideoFrameParams( + enabled=True, + fps=_VIDEO_FRAME_FPS, + dedup=True, + ) + video_text_dedup_params = video_text_dedup_params or VideoFrameTextDedupParams( + enabled=True, + max_dropped_frames=2, + ) + av_fuse_params = av_fuse_params or AudioVisualFuseParams(enabled=True) + + return ResolvedExtractionInputs( + extraction_mode=branch.extraction_mode, + extract_params=extract_params, + text_params=text_params, + html_params=html_params, + audio_chunk_params=audio_chunk_params, + asr_params=asr_params, + video_frame_params=video_frame_params, + video_text_dedup_params=video_text_dedup_params, + av_fuse_params=av_fuse_params, + ) + + +def _default_asr_params() -> ASRParams: + from nemo_retriever.audio import asr_params_from_env + + return asr_params_from_env().model_copy(update={"segment_audio": False}) diff --git a/nemo_retriever/src/nemo_retriever/params/models.py b/nemo_retriever/src/nemo_retriever/params/models.py index 8d24509d06..f0b46fe8b7 100644 --- a/nemo_retriever/src/nemo_retriever/params/models.py +++ b/nemo_retriever/src/nemo_retriever/params/models.py @@ -252,15 +252,11 @@ class BatchTuningParams(_ParamsModel): detect_workers: Optional[int] = None page_elements_cpus_per_actor: float = 1 ocr_cpus_per_actor: float = 1 - table_structure_workers: Optional[int] = None - table_structure_batch_size: Optional[int] = None - table_structure_cpus_per_actor: float = 1 embed_workers: Optional[int] = None embed_batch_size: int = 32 embed_cpus_per_actor: float = 1 gpu_page_elements: Optional[float] = None gpu_ocr: Optional[float] = None - gpu_table_structure: Optional[float] = None gpu_embed: Optional[float] = None nemotron_parse_workers: Optional[int] = None gpu_nemotron_parse: Optional[float] = None diff --git a/nemo_retriever/tests/test_ingest_interface.py b/nemo_retriever/tests/test_ingest_interface.py index e3a1a3abce..bdf758f878 100644 --- a/nemo_retriever/tests/test_ingest_interface.py +++ b/nemo_retriever/tests/test_ingest_interface.py @@ -150,16 +150,20 @@ def test_extract_default_direct_images_materialize_page_image(monkeypatch, tmp_p def passthrough_detection(self, batch_df): return batch_df + def fail_pdf_split(self, batch_df): + raise AssertionError("direct image extraction routed through PDFSplitActor") + monkeypatch.setattr( "nemo_retriever.graph.multi_type_extract_operator._MultiTypeExtractBase._run_detection_pipeline", passthrough_detection, ) + monkeypatch.setattr("nemo_retriever.pdf.split.PDFSplitActor.run", fail_pdf_split) result = ( - GraphIngestor(run_mode="inprocess", show_progress=False) + create_ingestor(run_mode="inprocess") .files([str(image_path)]) .extract( - ExtractParams( + params=ExtractParams( extract_text=True, extract_images=True, extract_tables=False, @@ -177,15 +181,19 @@ def passthrough_detection(self, batch_df): assert result.iloc[0]["metadata"]["source_path"] == str(image_path.resolve()) -def test_extract_default_mixed_pdf_and_image_uses_multitype_graph(tmp_path) -> None: +def test_extract_default_mixed_pdf_and_image_plans_ordered_branches(tmp_path) -> None: pdf = tmp_path / "manual.pdf" image = tmp_path / "scan.bmp" pdf.write_bytes(b"%PDF-1.4\n") image.write_bytes(b"bmp") - ingestor = GraphIngestor(run_mode="inprocess").files([str(pdf), str(image)]).extract() + ingestor = GraphIngestor(run_mode="inprocess").files([str(image), str(pdf)]).extract() - assert _effective_graph_node_names(ingestor) == ["MultiTypeExtractOperator"] + branches = ingestor._plan_default_extraction_branches() + assert [(branch.family, branch.extraction_mode, branch.input_paths) for branch in branches] == [ + ("pdf", "pdf", (str(pdf),)), + ("image", "image", (str(image),)), + ] def test_extract_explicit_pdf_rejects_image_input(tmp_path) -> None: diff --git a/nemo_retriever/tests/test_ingest_manifest.py b/nemo_retriever/tests/test_ingest_manifest.py new file mode 100644 index 0000000000..14276bb852 --- /dev/null +++ b/nemo_retriever/tests/test_ingest_manifest.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pandas as pd +import pytest + +from nemo_retriever.graph import Graph +from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.branch_extraction import normalize_ray_branch_datasets +from nemo_retriever.graph_ingestor import GraphIngestor +from nemo_retriever.ingest_manifest import ( + build_input_manifest, + plan_extraction_branches, + resolve_branch_extraction_inputs, +) +from nemo_retriever.params import ASRParams + + +class _TagOperator(AbstractOperator): + def __init__(self, *, tag: str) -> None: + super().__init__(tag=tag) + self.tag = tag + + def preprocess(self, data: Any, **kwargs: Any) -> Any: + return data + + def process(self, data: Any, **kwargs: Any) -> Any: + return pd.DataFrame( + { + "path": list(data["path"]), + f"{self.tag}_value": [self.tag] * len(data), + } + ) + + def postprocess(self, data: Any, **kwargs: Any) -> Any: + return data + + +class _PostOperator(AbstractOperator): + def preprocess(self, data: Any, **kwargs: Any) -> Any: + return data + + def process(self, data: Any, **kwargs: Any) -> Any: + return data.assign(post_extract=True) + + def postprocess(self, data: Any, **kwargs: Any) -> Any: + return data + + +def _graph_with(operator: AbstractOperator) -> Graph: + return Graph() >> operator + + +def test_manifest_planner_pdf_doc_share_dedicated_pdf_branch(tmp_path) -> None: + pdf = tmp_path / "manual.pdf" + pptx = tmp_path / "deck.pptx" + pdf.write_bytes(b"pdf") + pptx.write_bytes(b"pptx") + + branches = plan_extraction_branches(build_input_manifest([str(pdf), str(pptx)])) + + assert [(branch.family, branch.extraction_mode, branch.input_paths) for branch in branches] == [ + ("pdf", "pdf", (str(pdf), str(pptx))), + ] + + +def test_manifest_planner_mixed_inputs_use_stable_family_order(tmp_path) -> None: + text = tmp_path / "notes.txt" + image = tmp_path / "scan.png" + pdf = tmp_path / "manual.pdf" + text.write_text("notes", encoding="utf-8") + image.write_bytes(b"png") + pdf.write_bytes(b"pdf") + + branches = plan_extraction_branches(build_input_manifest([str(text), str(image), str(pdf)])) + + assert [branch.family for branch in branches] == ["pdf", "image", "txt"] + + +def test_manifest_branch_specs_resolve_default_params(monkeypatch, tmp_path) -> None: + audio = tmp_path / "clip.wav" + video = tmp_path / "scene.mp4" + audio.write_bytes(b"audio") + video.write_bytes(b"video") + monkeypatch.setattr("nemo_retriever.ingest_manifest._default_asr_params", lambda: ASRParams(segment_audio=False)) + + branches = plan_extraction_branches(build_input_manifest([str(video), str(audio)])) + by_family = {branch.family: branch for branch in branches} + + audio_inputs = resolve_branch_extraction_inputs( + by_family["audio"], + extract_params=None, + text_params=None, + html_params=None, + audio_chunk_params=None, + asr_params=None, + video_frame_params=None, + video_text_dedup_params=None, + av_fuse_params=None, + ) + video_inputs = resolve_branch_extraction_inputs( + by_family["video"], + extract_params=None, + text_params=None, + html_params=None, + audio_chunk_params=None, + asr_params=None, + video_frame_params=None, + video_text_dedup_params=None, + av_fuse_params=None, + ) + + assert audio_inputs.extraction_mode == "audio" + assert audio_inputs.audio_chunk_params.split_interval == 500000 + assert audio_inputs.asr_params.segment_audio is False + assert video_inputs.extraction_mode == "auto" + assert video_inputs.extract_params is not None + assert video_inputs.audio_chunk_params.enabled is True + assert video_inputs.video_frame_params.fps == 0.5 + assert video_inputs.video_frame_params.dedup is True + assert video_inputs.video_text_dedup_params.enabled is True + assert video_inputs.av_fuse_params.enabled is True + + +def test_manifest_planner_rejects_unsupported_concrete_extensions(tmp_path) -> None: + payload = tmp_path / "payload.bin" + payload.write_bytes(b"unknown") + + with pytest.raises(ValueError, match="payload.bin"): + plan_extraction_branches(build_input_manifest([str(payload)])) + + +def test_manifest_planner_empty_glob_does_not_invent_modal_branches(tmp_path) -> None: + branches = plan_extraction_branches(build_input_manifest([str(tmp_path / "*.wav")])) + + assert [(branch.family, branch.input_paths) for branch in branches] == [("pdf", (str(tmp_path / "*.wav"),))] + + +def test_explicit_extraction_mode_bypasses_manifest_planning(tmp_path) -> None: + image = tmp_path / "scan.png" + image.write_bytes(b"png") + ingestor = GraphIngestor(run_mode="inprocess").files([str(image)]).extract(extraction_mode="auto") + + assert ingestor._plan_default_extraction_branches() is None + assert ingestor._resolve_effective_extraction_inputs().extraction_mode == "auto" + + +def test_inprocess_branch_execution_unions_schemas_and_runs_post_once(monkeypatch, tmp_path) -> None: + pdf = tmp_path / "manual.pdf" + image = tmp_path / "scan.png" + text = tmp_path / "notes.txt" + pdf.write_bytes(b"pdf") + image.write_bytes(b"png") + text.write_text("notes", encoding="utf-8") + extraction_calls: list[dict[str, Any]] = [] + post_calls: list[dict[str, Any]] = [] + + def fake_build_graph(**kwargs: Any) -> Graph: + extraction_calls.append(kwargs) + return _graph_with(_TagOperator(tag=kwargs["extraction_mode"])) + + def fake_post_graph(**kwargs: Any) -> Graph: + post_calls.append(kwargs) + return _graph_with(_PostOperator()) + + monkeypatch.setattr("nemo_retriever.branch_extraction.build_graph", fake_build_graph) + monkeypatch.setattr("nemo_retriever.branch_extraction.build_post_extract_graph", fake_post_graph) + + result = ( + GraphIngestor(run_mode="inprocess", show_progress=False) + .files([str(text), str(image), str(pdf)]) + .extract() + .embed() + .ingest() + ) + + assert [call["extraction_mode"] for call in extraction_calls] == ["pdf", "image", "text"] + assert all(call.get("embed_params") is None for call in extraction_calls) + assert len(post_calls) == 1 + assert post_calls[0]["embed_params"] is not None + assert set(result.columns) == {"path", "pdf_value", "image_value", "text_value", "post_extract"} + assert result["post_extract"].tolist() == [True, True, True] + + +class _FakeDataset: + def __init__(self, columns: list[str]) -> None: + self.columns = columns + self.unioned: list[_FakeDataset] = [] + self.normalized_columns: tuple[str, ...] | None = None + + def schema(self) -> Any: + return SimpleNamespace(names=self.columns) + + def map_batches(self, *_args: Any, **kwargs: Any) -> "_FakeDataset": + self.normalized_columns = kwargs["fn_kwargs"]["columns"] + return self + + def union(self, other: "_FakeDataset") -> "_FakeDataset": + self.unioned.append(other) + return self + + +class _LazySchemaDataset: + def __init__(self) -> None: + self.map_batches_called = False + + def schema(self, *, fetch_if_missing: bool = True) -> None: + assert fetch_if_missing is False + return None + + def map_batches(self, *_args: Any, **_kwargs: Any) -> "_LazySchemaDataset": + self.map_batches_called = True + return self + + +def test_ray_schema_normalization_does_not_trigger_lazy_schema_fetch() -> None: + datasets = [_LazySchemaDataset(), _LazySchemaDataset()] + + normalized = normalize_ray_branch_datasets(datasets) + + assert normalized == datasets + assert all(not dataset.map_batches_called for dataset in datasets) + + +def test_batch_branch_execution_uses_dataset_union(monkeypatch, tmp_path) -> None: + pdf = tmp_path / "manual.pdf" + image = tmp_path / "scan.png" + pdf.write_bytes(b"pdf") + image.write_bytes(b"png") + datasets = [_FakeDataset(["path", "pdf_value"]), _FakeDataset(["path", "image_value"])] + executor_calls: list[dict[str, Any]] = [] + + class FakeCluster: + def available_gpu_count(self) -> int: + return 0 + + def total_cpu_count(self) -> int: + return 64 + + class FakeExecutor: + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass + + def build_dataset(self, data: Any, **kwargs: Any) -> Any: + executor_calls.append({"method": "build_dataset", "data": data}) + return datasets.pop(0) + + def ingest(self, data: Any, **kwargs: Any) -> Any: + executor_calls.append({"method": "ingest", "data": data}) + return pd.DataFrame({"done": [True]}) + + monkeypatch.setattr(GraphIngestor, "_ensure_batch_runtime", lambda self: (None, FakeCluster())) + monkeypatch.setattr("nemo_retriever.branch_extraction.RayDataExecutor", FakeExecutor) + monkeypatch.setattr("nemo_retriever.branch_extraction.build_graph", lambda **_kwargs: Graph()) + monkeypatch.setattr("nemo_retriever.branch_extraction.build_post_extract_graph", lambda **_kwargs: Graph()) + + result = GraphIngestor(run_mode="batch").files([str(pdf), str(image)]).extract().ingest() + + assert [call["method"] for call in executor_calls] == ["build_dataset", "build_dataset", "ingest"] + combined = executor_calls[2]["data"] + assert isinstance(combined, _FakeDataset) + assert len(combined.unioned) == 1 + assert combined.normalized_columns == ("path", "pdf_value", "image_value") + assert result["done"].tolist() == [True] diff --git a/nemo_retriever/tests/test_ingest_plans.py b/nemo_retriever/tests/test_ingest_plans.py index b2e66494ac..aaf520fcc0 100644 --- a/nemo_retriever/tests/test_ingest_plans.py +++ b/nemo_retriever/tests/test_ingest_plans.py @@ -347,34 +347,6 @@ def test_batch_tuning_to_node_overrides_auto_cpu_only_when_no_gpus(ocr_version: assert overrides["NemotronParseActor"]["concurrency"] == 2 -def test_batch_tuning_to_node_overrides_honors_table_structure_tuning() -> None: - cluster = ClusterResources( - total_resources=Resources(cpu_count=64, gpu_count=8), - available_resources=Resources(cpu_count=64, gpu_count=8), - ) - extract_params = ExtractParams( - use_table_structure=True, - batch_tuning=BatchTuningParams( - table_structure_workers=6, - table_structure_batch_size=12, - table_structure_cpus_per_actor=0.4, - gpu_table_structure=0.25, - ), - ) - - overrides = batch_tuning_to_node_overrides( - extract_params=extract_params, - embed_params=None, - cluster_resources=cluster, - ) - - assert overrides["TableStructureActor"]["batch_size"] == 12 - assert overrides["TableStructureActor"]["target_num_rows_per_block"] == 12 - assert overrides["TableStructureActor"]["concurrency"] == 6 - assert overrides["TableStructureActor"]["num_cpus"] == 0.4 - assert overrides["TableStructureActor"]["num_gpus"] == 0.25 - - def test_batch_tuning_to_node_overrides_adds_default_store_tuning() -> None: overrides = batch_tuning_to_node_overrides( extract_params=None, diff --git a/nemo_retriever/tests/test_pipeline_graph.py b/nemo_retriever/tests/test_pipeline_graph.py index 929a80db9b..5d1fa26ff7 100644 --- a/nemo_retriever/tests/test_pipeline_graph.py +++ b/nemo_retriever/tests/test_pipeline_graph.py @@ -15,14 +15,42 @@ from nemo_retriever.graph import FileListLoaderOperator, MultiTypeExtractOperator, UDFOperator from nemo_retriever.graph.cpu_operator import CPUOperator from nemo_retriever.graph.executor import AbstractExecutor, InprocessExecutor, RayDataExecutor +from nemo_retriever.graph.ingestor_runtime import build_graph, build_post_extract_graph from nemo_retriever.graph.gpu_operator import GPUOperator from nemo_retriever.graph.pipeline_graph import Graph, Node -from nemo_retriever.params import ASRParams -from nemo_retriever.params import ExtractParams -from nemo_retriever.params import VideoFrameTextDedupParams +from nemo_retriever.params import EmbedParams, ExtractParams, TextChunkParams from nemo_retriever.utils.ray_resource_hueristics import Resources +def _graph_node_names(graph: Graph) -> list[str]: + names: list[str] = [] + + def visit(node: Node) -> None: + names.append(getattr(node.operator, "name", node.name)) + for child in node.children: + visit(child) + + for root in graph.roots: + visit(root) + return names + + +def test_post_extract_graph_uses_explicit_content_reshape_flag() -> None: + graph = build_post_extract_graph(embed_params=EmbedParams()) + + assert "ExplodeContentToRows" in _graph_node_names(graph) + + +def test_text_build_graph_does_not_use_modal_content_reshape() -> None: + graph = build_graph( + extraction_mode="text", + text_params=TextChunkParams(), + embed_params=EmbedParams(), + ) + + assert "ExplodeContentToRows" not in _graph_node_names(graph) + + # --------------------------------------------------------------------------- # Concrete operator stubs for testing # --------------------------------------------------------------------------- @@ -623,44 +651,6 @@ def test_group_files_by_type(self): assert grouped["audio"] == ["/folder/audio.mp3"] assert grouped["video"] == ["/folder/video.mp4"] - def test_default_media_params_match_root_ingest_defaults(self, monkeypatch): - """Mixed auto uses the same audio/video defaults as root CLI typed media ingest.""" - import nemo_retriever.graph.multi_type_extract_operator as multitype - - monkeypatch.setattr( - multitype, - "asr_params_from_env", - lambda: ASRParams(audio_endpoints=("grpc.example:443", None), segment_audio=True), - ) - - op = multitype.MultiTypeExtractCPUActor() - - assert op.audio_chunk_params.split_type == "size" - assert op.audio_chunk_params.split_interval == 500000 - assert op.asr_params.audio_endpoints == ("grpc.example:443", None) - assert op.asr_params.segment_audio is False - assert op.video_frame_params.enabled is True - assert op.video_frame_params.fps == 0.5 - assert op.video_frame_params.dedup is True - assert op.video_text_dedup_params.enabled is True - assert op.video_text_dedup_params.max_dropped_frames == 2 - assert op.av_fuse_params.enabled is True - - def test_build_graph_forwards_video_text_dedup_params_to_multitype(self): - from nemo_retriever.graph.ingestor_runtime import build_graph - - text_dedup_params = VideoFrameTextDedupParams(enabled=False, max_dropped_frames=7) - - graph = build_graph( - extraction_mode="auto", - extract_params=ExtractParams(), - video_text_dedup_params=text_dedup_params, - ) - - op = graph.roots[0].operator - assert isinstance(op, MultiTypeExtractOperator) - assert op.video_text_dedup_params is text_dedup_params - def test_auto_mode_logs_and_skips_unsupported_extension_in_file_list(self, caplog): op = MultiTypeExtractOperator(extraction_mode="auto") @@ -970,6 +960,46 @@ def _fake_read_binary_files(paths, include_paths=True): assert captured["paths"] == [str(pdf_path)] assert captured["include_paths"] is True + def test_build_dataset_returns_lazy_dataset_without_materializing(self, tmp_path, monkeypatch): + import sys + from types import SimpleNamespace + + pdf_path = tmp_path / "sample.pdf" + pdf_path.write_bytes(b"pdf") + + class _FakeDataset: + def to_pandas(self): + raise AssertionError("to_pandas should not be called by build_dataset") + + class _FakeDataContext: + enable_rich_progress_bars = False + use_ray_tqdm = True + + @classmethod + def get_current(cls): + return cls() + + fake_dataset = _FakeDataset() + fake_ray_data = SimpleNamespace( + Dataset=_FakeDataset, + DataContext=_FakeDataContext, + read_binary_files=lambda paths, include_paths=True: fake_dataset, + ) + fake_ray = SimpleNamespace(is_initialized=lambda: True, init=lambda **kwargs: None, data=fake_ray_data) + + monkeypatch.setitem(sys.modules, "ray", fake_ray) + monkeypatch.setitem(sys.modules, "ray.data", fake_ray_data) + monkeypatch.setattr( + "nemo_retriever.graph.executor.gather_cluster_resources", + lambda ray: SimpleNamespace(available_gpu_count=lambda: 0), + ) + monkeypatch.setattr("nemo_retriever.graph.executor.resolve_graph", lambda graph, cluster: graph) + + executor = RayDataExecutor(Graph()) + result = executor.build_dataset([str(pdf_path)]) + + assert result is fake_dataset + def test_ingest_rejects_directory_paths_before_ray_read(self, tmp_path, monkeypatch): import sys from types import SimpleNamespace diff --git a/nemo_retriever/tests/test_root_cli_workflow.py b/nemo_retriever/tests/test_root_cli_workflow.py index 8f3bd784e0..dfc3e8975a 100644 --- a/nemo_retriever/tests/test_root_cli_workflow.py +++ b/nemo_retriever/tests/test_root_cli_workflow.py @@ -18,27 +18,17 @@ import nemo_retriever.adapters.cli.sdk_workflow as sdk_workflow from nemo_retriever.graph_ingestor import GraphIngestor -from nemo_retriever.params import AudioChunkParams, EmbedParams, ExtractParams, TextChunkParams, VideoFrameParams +from nemo_retriever.params import EmbedParams, ExtractParams RUNNER = CliRunner() cli_main = importlib.import_module("nemo_retriever.adapters.cli.main") -class _FakeAsrParams: - def model_copy(self, *, update: dict[str, Any]) -> dict[str, Any]: - return update - - def _make_fake_ingestor() -> Any: fake_ingestor = create_autospec(GraphIngestor, instance=True, spec_set=True) fake_ingestor.files.return_value = fake_ingestor fake_ingestor.extract.return_value = fake_ingestor - fake_ingestor.extract_txt.return_value = fake_ingestor - fake_ingestor.extract_html.return_value = fake_ingestor - fake_ingestor.extract_image_files.return_value = fake_ingestor - fake_ingestor.extract_audio.return_value = fake_ingestor - fake_ingestor.extract_video.return_value = fake_ingestor fake_ingestor.embed.return_value = fake_ingestor fake_ingestor.vdb_upload.return_value = fake_ingestor fake_ingestor.ingest.return_value = [{"status": "ok"}] @@ -70,7 +60,7 @@ def fake_create_ingestor(**kwargs: Any) -> Any: ] assert fake_ingestor.files.call_args.args == ([str(document)],) assert isinstance(fake_ingestor.extract.call_args.args[0], ExtractParams) - assert fake_ingestor.extract.call_args.kwargs == {"extraction_mode": "pdf"} + assert fake_ingestor.extract.call_args.kwargs == {} assert fake_ingestor.embed.call_args.args == () vdb_upload_params = fake_ingestor.vdb_upload.call_args.args[0] assert vdb_upload_params.vdb_op == "lancedb" @@ -112,7 +102,7 @@ def fake_create_ingestor(**kwargs: Any) -> Any: assert create_calls == [{"run_mode": "batch"}] assert fake_ingestor.files.call_args.args == ([str(first_document), str(globbed_document)],) assert isinstance(fake_ingestor.extract.call_args.args[0], ExtractParams) - assert fake_ingestor.extract.call_args.kwargs == {"extraction_mode": "pdf"} + assert fake_ingestor.extract.call_args.kwargs == {} assert fake_ingestor.vdb_upload.call_args.args[0].vdb_kwargs == { "uri": "/tmp/lancedb", "table_name": "docs", @@ -178,8 +168,6 @@ def fake_create_ingestor(**_kwargs: Any) -> Any: assert extract_params.ocr_version == "v1" assert extract_params.graphic_elements_invoke_url == "http://graphic-elements:8000/v1/infer" assert extract_params.table_structure_invoke_url == "http://table-structure:8000/v1/infer" - assert extract_params.use_table_structure is True - assert extract_params.table_output_format == "markdown" embed_params = fake_ingestor.embed.call_args.args[0] assert isinstance(embed_params, EmbedParams) @@ -189,79 +177,6 @@ def fake_create_ingestor(**_kwargs: Any) -> Any: assert embed_params.embed_model_name == "nvidia/llama-nemotron-embed-1b-v2" -def test_root_ingest_table_output_markdown_enables_local_table_structure(monkeypatch, tmp_path) -> None: - fake_ingestor = _make_fake_ingestor() - document = tmp_path / "table-structure.pdf" - document.write_bytes(b"%PDF-1.4\n") - - monkeypatch.setattr(sdk_workflow, "create_ingestor", lambda **_kwargs: fake_ingestor) - - result = RUNNER.invoke(cli_main.app, ["ingest", str(document), "--table-output-format", "markdown"]) - - assert result.exit_code == 0 - extract_params = fake_ingestor.extract.call_args.args[0] - assert isinstance(extract_params, ExtractParams) - assert extract_params.use_table_structure is True - assert extract_params.table_output_format == "markdown" - assert extract_params.table_structure_invoke_url is None - - -def test_root_ingest_table_output_pseudo_markdown_does_not_enable_table_structure(monkeypatch, tmp_path) -> None: - fake_ingestor = _make_fake_ingestor() - document = tmp_path / "plain-table.pdf" - document.write_bytes(b"%PDF-1.4\n") - - monkeypatch.setattr(sdk_workflow, "create_ingestor", lambda **_kwargs: fake_ingestor) - - result = RUNNER.invoke(cli_main.app, ["ingest", str(document), "--table-output-format", "pseudo_markdown"]) - - assert result.exit_code == 0 - extract_params = fake_ingestor.extract.call_args.args[0] - assert isinstance(extract_params, ExtractParams) - assert extract_params.use_table_structure is False - assert extract_params.table_output_format == "pseudo_markdown" - - -def test_root_ingest_table_structure_url_auto_enables_table_structure(monkeypatch, tmp_path) -> None: - fake_ingestor = _make_fake_ingestor() - document = tmp_path / "remote-table-structure.pdf" - document.write_bytes(b"%PDF-1.4\n") - - monkeypatch.setattr(sdk_workflow, "create_ingestor", lambda **_kwargs: fake_ingestor) - - result = RUNNER.invoke( - cli_main.app, - [ - "ingest", - str(document), - "--table-structure-invoke-url", - "http://table-structure:8000/v1/infer", - ], - ) - - assert result.exit_code == 0 - extract_params = fake_ingestor.extract.call_args.args[0] - assert isinstance(extract_params, ExtractParams) - assert extract_params.table_structure_invoke_url == "http://table-structure:8000/v1/infer" - assert extract_params.use_table_structure is True - assert extract_params.table_output_format == "markdown" - - -def test_root_ingest_passes_local_hf_embed_backend(monkeypatch, tmp_path) -> None: - fake_ingestor = _make_fake_ingestor() - document = tmp_path / "local-hf.pdf" - document.write_bytes(b"%PDF-1.4\n") - - monkeypatch.setattr(sdk_workflow, "create_ingestor", lambda **_kwargs: fake_ingestor) - - result = RUNNER.invoke(cli_main.app, ["ingest", str(document), "--local-ingest-embed-backend", "hf"]) - - assert result.exit_code == 0 - embed_params = fake_ingestor.embed.call_args.args[0] - assert isinstance(embed_params, EmbedParams) - assert embed_params.local_ingest_embed_backend == "hf" - - def test_root_ingest_passes_ocr_lang_option(monkeypatch, tmp_path) -> None: fake_ingestor = _make_fake_ingestor() document = tmp_path / "english-ocr.pdf" @@ -331,32 +246,18 @@ def fake_create_ingestor(**kwargs: Any) -> Any: "8", "--page-elements-cpus-per-actor", "0.5", - "--page-elements-gpus-per-actor", - "0.2", "--ocr-workers", "5", "--ocr-batch-size", "6", "--ocr-cpus-per-actor", "0.75", - "--ocr-gpus-per-actor", - "0.3", - "--table-structure-workers", - "6", - "--table-structure-batch-size", - "12", - "--table-structure-cpus-per-actor", - "0.4", - "--table-structure-gpus-per-actor", - "0.25", "--embed-workers", "7", "--embed-batch-size", "16", "--embed-cpus-per-actor", "0.25", - "--embed-gpus-per-actor", - "0.5", ], ) @@ -377,22 +278,15 @@ def fake_create_ingestor(**kwargs: Any) -> Any: assert extract_params.batch_tuning.page_elements_workers == 3 assert extract_params.batch_tuning.page_elements_batch_size == 8 assert extract_params.batch_tuning.page_elements_cpus_per_actor == 0.5 - assert extract_params.batch_tuning.gpu_page_elements == 0.2 assert extract_params.batch_tuning.ocr_workers == 5 assert extract_params.batch_tuning.ocr_inference_batch_size == 6 assert extract_params.batch_tuning.ocr_cpus_per_actor == 0.75 - assert extract_params.batch_tuning.gpu_ocr == 0.3 - assert extract_params.batch_tuning.table_structure_workers == 6 - assert extract_params.batch_tuning.table_structure_batch_size == 12 - assert extract_params.batch_tuning.table_structure_cpus_per_actor == 0.4 - assert extract_params.batch_tuning.gpu_table_structure == 0.25 embed_params = fake_ingestor.embed.call_args.args[0] assert isinstance(embed_params, EmbedParams) assert embed_params.batch_tuning.embed_workers == 7 assert embed_params.batch_tuning.embed_batch_size == 16 assert embed_params.batch_tuning.embed_cpus_per_actor == 0.25 - assert embed_params.batch_tuning.gpu_embed == 0.5 assert "Ingested 1 document(s) into LanceDB lancedb/nv-ingest." in result.output @@ -413,7 +307,7 @@ def test_root_ingest_reports_unknown_default_input_type(tmp_path) -> None: assert "Unsupported input file type(s) for retriever ingest" in result.output -def test_root_ingest_routes_text_inputs_by_default(monkeypatch, tmp_path) -> None: +def test_root_ingest_routes_text_inputs_by_default_to_auto_planner(monkeypatch, tmp_path) -> None: fake_ingestor = _make_fake_ingestor() document = tmp_path / "notes.txt" document.write_text("not a pdf", encoding="utf-8") @@ -424,27 +318,18 @@ def test_root_ingest_routes_text_inputs_by_default(monkeypatch, tmp_path) -> Non assert result.exit_code == 0 assert fake_ingestor.files.call_args.args == ([str(document)],) - text_params = fake_ingestor.extract_txt.call_args.args[0] - assert isinstance(text_params, TextChunkParams) - assert fake_ingestor.extract.call_count == 0 + assert isinstance(fake_ingestor.extract.call_args.args[0], ExtractParams) + assert fake_ingestor.extract.call_args.kwargs == {} -def test_root_ingest_routes_explicit_image_inputs(monkeypatch, tmp_path) -> None: - fake_ingestor = _make_fake_ingestor() - document = tmp_path / "figure.svg" - document.write_text("", encoding="utf-8") - - monkeypatch.setattr(sdk_workflow, "create_ingestor", lambda **_kwargs: fake_ingestor) - - result = RUNNER.invoke(cli_main.app, ["ingest", str(document), "--input-type", "image"]) +def test_root_ingest_help_does_not_expose_input_type() -> None: + result = RUNNER.invoke(cli_main.app, ["ingest", "--help"]) assert result.exit_code == 0 - extract_params = fake_ingestor.extract_image_files.call_args.args[0] - assert isinstance(extract_params, ExtractParams) - assert fake_ingestor.extract.call_count == 0 + assert "--input-type" not in result.output -def test_root_ingest_routes_tiff_inputs_by_default(monkeypatch, tmp_path) -> None: +def test_root_ingest_routes_tiff_inputs_by_default_to_auto_planner(monkeypatch, tmp_path) -> None: fake_ingestor = _make_fake_ingestor() document = tmp_path / "scan.tiff" document.write_bytes(b"tiff") @@ -455,44 +340,8 @@ def test_root_ingest_routes_tiff_inputs_by_default(monkeypatch, tmp_path) -> Non assert result.exit_code == 0 assert fake_ingestor.files.call_args.args == ([str(document)],) - extract_params = fake_ingestor.extract_image_files.call_args.args[0] - assert isinstance(extract_params, ExtractParams) - assert fake_ingestor.extract.call_count == 0 - - -def test_root_ingest_routes_audio_inputs(monkeypatch, tmp_path) -> None: - fake_ingestor = _make_fake_ingestor() - document = tmp_path / "meeting.m4a" - document.write_bytes(b"audio") - - monkeypatch.setattr(sdk_workflow, "create_ingestor", lambda **_kwargs: fake_ingestor) - monkeypatch.setattr(sdk_workflow, "_default_asr_params", _FakeAsrParams) - - result = RUNNER.invoke(cli_main.app, ["ingest", str(document), "--input-type", "audio"]) - - assert result.exit_code == 0 - audio_params = fake_ingestor.extract_audio.call_args.kwargs["params"] - assert isinstance(audio_params, AudioChunkParams) - assert audio_params.split_type == "size" - assert audio_params.split_interval == 500000 - assert fake_ingestor.extract_audio.call_args.kwargs["asr_params"] == {"segment_audio": False} - - -def test_root_ingest_routes_video_inputs(monkeypatch, tmp_path) -> None: - fake_ingestor = _make_fake_ingestor() - document = tmp_path / "demo.mp4" - document.write_bytes(b"video") - - monkeypatch.setattr(sdk_workflow, "create_ingestor", lambda **_kwargs: fake_ingestor) - monkeypatch.setattr(sdk_workflow, "_default_asr_params", _FakeAsrParams) - - result = RUNNER.invoke(cli_main.app, ["ingest", str(document), "--input-type", "video"]) - - assert result.exit_code == 0 - video_frame_params = fake_ingestor.extract_video.call_args.kwargs["video_frame_params"] - assert isinstance(video_frame_params, VideoFrameParams) - assert video_frame_params.fps == 0.5 - assert video_frame_params.enabled is True + assert isinstance(fake_ingestor.extract.call_args.args[0], ExtractParams) + assert fake_ingestor.extract.call_args.kwargs == {} def test_root_ingest_auto_mixed_directory_uses_auto_extraction(monkeypatch, tmp_path) -> None: @@ -508,17 +357,13 @@ def test_root_ingest_auto_mixed_directory_uses_auto_extraction(monkeypatch, tmp_ image.write_bytes(b"png") monkeypatch.setattr(sdk_workflow, "create_ingestor", lambda **_kwargs: fake_ingestor) - monkeypatch.setattr(sdk_workflow, "_default_asr_params", _FakeAsrParams) result = RUNNER.invoke(cli_main.app, ["ingest", str(dataset)]) assert result.exit_code == 0 assert set(fake_ingestor.files.call_args.args[0]) == {str(pdf.resolve()), str(text.resolve()), str(image.resolve())} - assert fake_ingestor.extract.call_args.kwargs["extraction_mode"] == "auto" - assert isinstance(fake_ingestor.extract.call_args.kwargs["text_params"], TextChunkParams) - assert "asr_params" not in fake_ingestor.extract.call_args.kwargs - assert "video_frame_params" not in fake_ingestor.extract.call_args.kwargs assert isinstance(fake_ingestor.extract.call_args.args[0], ExtractParams) + assert fake_ingestor.extract.call_args.kwargs == {} def test_root_ingest_reports_os_errors(monkeypatch) -> None: