diff --git a/nemo_retriever/src/nemo_retriever/graph/tabular_fetch_embeddings_operator.py b/nemo_retriever/src/nemo_retriever/graph/tabular_fetch_embeddings_operator.py index ecfceb0fbc..6db5240f30 100644 --- a/nemo_retriever/src/nemo_retriever/graph/tabular_fetch_embeddings_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/tabular_fetch_embeddings_operator.py @@ -2,28 +2,45 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Graph operator: fetch tabular entity descriptions from Neo4j into an embedding-ready DataFrame.""" +"""Graph operator: turn ``(tables_df, columns_df)`` into embedding-ready rows.""" from __future__ import annotations -from typing import Any +import logging +from typing import Any, Iterable import pandas as pd from nemo_retriever.graph.abstract_operator import AbstractOperator from nemo_retriever.graph.cpu_operator import CPUOperator +from nemo_retriever.tabular_data.ingestion.embeddings import fetch_tabular_embedding_dataframe +from nemo_retriever.tabular_data.ingestion.model.reserved_words import Labels + +logger = logging.getLogger(__name__) class TabularFetchEmbeddingsOp(AbstractOperator, CPUOperator): - """Fetch all tabular entity descriptions from Neo4j into an embedding-ready DataFrame. + """Build an embedding-ready DataFrame from ``(tables_df, columns_df)``. + + Expected input: a 2-tuple ``(tables_df, columns_df)``. Both DataFrames + carry the post-ingest UUIDs of the Table/Column nodes written to Neo4j; + ``tables_df`` is keyed by ``id`` (table UUID) with at least + ``table_name``, ``table_schema`` and ``description`` columns, and + ``columns_df`` carries one row per column with ``id``, ``table_name``, + ``column_name``, ``data_type``, ``description`` and ``sample_values``. + Multiple schemas can be concatenated into the same pair — the + ``table_schema`` column on each table row keeps them distinguishable. + + Output columns: ``text, _embed_modality, path, page_number, metadata``. + Two row types are produced: - This operator ignores its input — it always queries Neo4j directly and - returns a fresh DataFrame with columns: - ``text``, ``_embed_modality``, ``path``, ``page_number``, ``metadata``. + * one ``Table`` row per table, whose ``text`` joins the table description + with a compact list of its columns; and + * one ``Column`` row per column. - The output schema matches the format produced by the unstructured pipeline, - so the standard :class:`~nemo_retriever.text_embed.operators._BatchEmbedActor` - can be chained directly after this operator. + The text templates match the previous Neo4j-derived format, so the rest + of the pipeline (``_BatchEmbedActor`` → ``IngestVdbOperator``) keeps + working untouched. """ def __init__( @@ -39,9 +56,185 @@ def preprocess(self, data: Any, **kwargs: Any) -> Any: return data def process(self, data: Any, **kwargs: Any) -> pd.DataFrame: - from nemo_retriever.tabular_data.ingestion.embeddings import fetch_tabular_embedding_dataframe + if not (isinstance(data, tuple) and len(data) == 2): + logger.warning( + "TabularFetchEmbeddingsOp received no (tables_df, columns_df) input " + "for database %r (got %s); falling back to fetching embeddings from the database.", + self._database_name, + type(data).__name__, + ) + return fetch_tabular_embedding_dataframe(database_name=self._database_name) - return fetch_tabular_embedding_dataframe(database_name=self._database_name) + tables_df, columns_df = data + rows = list(self._build_rows(tables_df, columns_df)) + if rows: + return pd.DataFrame(rows) + return pd.DataFrame(columns=["text", "_embed_modality", "path", "page_number", "metadata"]) def postprocess(self, data: Any, **kwargs: Any) -> Any: return data + + def _build_rows(self, tables_df: pd.DataFrame, columns_df: pd.DataFrame) -> Iterable[dict[str, Any]]: + # Index columns by (schema, table_name) so duplicate table names across + # different schemas (e.g. two "users" tables in two schemas) don't get + # their columns merged into one bucket — which would both bloat each + # table's text and double-emit every column row per duplicate. + columns_by_table: dict[tuple[str, str], list[Any]] = {} + for _, col in columns_df.iterrows(): + key = ( + str(col.get("table_schema", "")).lower(), + str(col.get("table_name", "")).lower(), + ) + columns_by_table.setdefault(key, []).append(col) + + rows: list[dict[str, Any]] = [] + for _, table in tables_df.iterrows(): + table_id = str(table.get("id", "")) + table_name = str(table.get("table_name", "")) + table_description = "" if pd.isna(v := table.get("description")) else str(v).strip() + schema_name = str(table.get("table_schema", "")) + columns = columns_by_table.get((schema_name.lower(), table_name.lower()), []) + + table_text = _create_table_text( + table_name=table_name, + table_description=table_description, + columns=columns, + schema_name=schema_name, + database_name=self._database_name, + ) + rows.append( + _create_row( + text=table_text, + node_id=table_id, + label=Labels.TABLE, + name=table_name, + schema_name=schema_name, + database_name=self._database_name, + ) + ) + + for column in columns: + column_id = str(column.get("id", "")) + column_name = str(column.get("column_name", "")) + data_type = "" if pd.isna(v := column.get("data_type")) else str(v).strip() + column_description = "" if pd.isna(v := column.get("description")) else str(v).strip() + sample_values = (column.get("sample_values") or [])[:5] + column_text = _create_column_text( + column_name=column_name, + column_description=column_description, + data_type=data_type, + sample_values=sample_values, + schema_name=schema_name, + table_name=table_name, + database_name=self._database_name, + ) + rows.append( + _create_row( + text=column_text, + node_id=column_id, + label=Labels.COLUMN, + name=column_name, + schema_name=schema_name, + database_name=self._database_name, + ) + ) + return rows + + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _create_table_text( + *, + table_name: str, + table_description: str, + columns: list[Any], + schema_name: str, + database_name: str, +) -> str: + """Build the embedding text for a Table node. + + Returns just the text string; the caller is responsible for wrapping it + in an embed-row dict via :func:`_create_row`. + """ + text = f"db_name: {database_name}" f", schema_name: {schema_name}" f", table_name: {table_name}" + if table_description: + text += f", table_description: {table_description}" + + column_pieces: list[str] = [] + for column in columns: + column_name = column.get("column_name", "") + data_type = "" if pd.isna(v := column.get("data_type")) else str(v).strip() + piece = f"{{name: {column_name}, data_type: {data_type}" + + column_description = "" if pd.isna(v := column.get("description")) else str(v).strip() + if column_description: + piece += f", description: {column_description}" + piece += "}" + column_pieces.append(piece) + + text += f", columns: {','.join(column_pieces)}" + return text + + +def _create_column_text( + *, + column_name: str, + column_description: str, + data_type: str, + sample_values: list[Any], + table_name: str, + schema_name: str, + database_name: str, +) -> str: + """Build the embedding text for a Column node. + + Returns just the text string; the caller is responsible for wrapping it + in an embed-row dict via :func:`_create_row`. + """ + text = ( + f"db_name: {database_name}" + f", schema_name: {schema_name}" + f", table_name: {table_name}" + f", column_name: {column_name}" + f", data_type: {data_type}" + ) + if column_description: + text += f", column_description: {column_description}" + if len(sample_values) > 0: + text += f", sample_values: {', '.join(str(x) for x in sample_values)}" + return text + + +def _create_row( + *, + text: str, + node_id: str | None, + label: str, + name: str, + schema_name: str, + database_name: str, +) -> dict[str, Any]: + path = f"neo4j:{node_id}" if node_id else "neo4j:unknown" + # Nest tabular identifiers under content_metadata so they survive the + # IngestVdbOperator → LanceDB write path (which only persists + # content_metadata + source_metadata into the table's metadata column). + # Top-level copies are kept for any in-memory consumer of this DataFrame. + tabular_fields = { + "id": node_id, + "label": label, + "name": name, + "source_path": path, + "schema_name": schema_name, + "database_name": database_name, + } + return { + "text": text.strip(), + "_embed_modality": "text", + "path": path, + "page_number": -1, + "metadata": { + **tabular_fields, + "content_metadata": dict(tabular_fields), + }, + } diff --git a/nemo_retriever/src/nemo_retriever/graph/tabular_schema_extract_operator.py b/nemo_retriever/src/nemo_retriever/graph/tabular_schema_extract_operator.py index 439cef9ea2..eca68d435d 100644 --- a/nemo_retriever/src/nemo_retriever/graph/tabular_schema_extract_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/tabular_schema_extract_operator.py @@ -24,9 +24,15 @@ class TabularSchemaExtractOp(AbstractOperator, CPUOperator): connector stored in *tabular_params*. 2. Write the extracted entities as graph nodes and relationships into Neo4j. - The operator produces an empty DataFrame as output so it can be chained - with downstream operators (e.g. :class:`TabularFetchEmbeddingsOp`) via - ``>>``. All meaningful state lives in Neo4j after this step. + The operator returns ``(tables_df, columns_df)`` — concatenated across + every ingested :class:`Schema` — carrying the post-ingest UUIDs written + to Neo4j. The per-row ``table_schema`` column keeps schemas + distinguishable. Downstream operators (notably + :class:`TabularFetchEmbeddingsOp`) can build embedding text directly + from this pair without a Neo4j round-trip. + + Returns ``(empty_df, empty_df)`` when there is nothing to ingest, so + the chain still flows. """ def __init__( @@ -43,15 +49,26 @@ def preprocess(self, data: Any, **kwargs: Any) -> TabularExtractParams | None: return data return self._tabular_params - def process(self, data: TabularExtractParams | None, **kwargs: Any) -> pd.DataFrame: + def process(self, data: TabularExtractParams | None, **kwargs: Any) -> tuple[pd.DataFrame, pd.DataFrame]: from nemo_retriever.tabular_data.ingestion.extract_data import ( extract_tabular_db_data, store_relational_db_in_neo4j, ) + empty = (pd.DataFrame(), pd.DataFrame()) + if data is None or data.connector is None: + return empty + schema_data = extract_tabular_db_data(params=data) - store_relational_db_in_neo4j(data=schema_data, dialect=data.connector.dialect) - return pd.DataFrame() + schemas = store_relational_db_in_neo4j(data=schema_data, dialect=data.connector.dialect) or {} + if not schemas: + return empty + + tables = [s.tables_df for s in schemas.values() if s.tables_df is not None] + columns = [s.columns_df for s in schemas.values() if s.columns_df is not None] + tables_df = pd.concat(tables, ignore_index=True) if tables else pd.DataFrame() + columns_df = pd.concat(columns, ignore_index=True) if columns else pd.DataFrame() + return tables_df, columns_df def postprocess(self, data: Any, **kwargs: Any) -> Any: return data diff --git a/nemo_retriever/src/nemo_retriever/tabular_data/ingestion/extract_data.py b/nemo_retriever/src/nemo_retriever/tabular_data/ingestion/extract_data.py index ea0e39129a..7c3a3e49d5 100644 --- a/nemo_retriever/src/nemo_retriever/tabular_data/ingestion/extract_data.py +++ b/nemo_retriever/src/nemo_retriever/tabular_data/ingestion/extract_data.py @@ -60,18 +60,22 @@ def store_relational_db_in_neo4j(data, dialect: str, num_workers: int = 4): Args: data: Data dict returned by extract_tabular_db_data(). dialect: SQL dialect used by the connector (e.g. "sqlite", "duckdb", "snowflake"). - neo4j_conn: Active Neo4jConnectionManager instance (unused directly here; - populate_tabular_data uses its own DAL connection, but - accepted for API consistency with the other ingest steps). + num_workers: Worker count forwarded to populate_tabular_data. + + Returns: + ``{schema_name_lower: Schema}`` dict produced during ingestion, so + callers can recover the post-ingest ``tables_df`` / ``columns_df`` + (with the UUIDs assigned to each Table/Column node) without a + round-trip back to Neo4j. Returns ``{}`` when *data* is empty. """ if not data: - return + return {} from nemo_retriever.tabular_data.ingestion.write_to_graph import ( populate_tabular_data, ) - populate_tabular_data( + return populate_tabular_data( data, num_workers=num_workers, dialect=dialect, diff --git a/nemo_retriever/src/nemo_retriever/tabular_data/ingestion/write_to_graph.py b/nemo_retriever/src/nemo_retriever/tabular_data/ingestion/write_to_graph.py index b0907e35c1..af2b917ba0 100644 --- a/nemo_retriever/src/nemo_retriever/tabular_data/ingestion/write_to_graph.py +++ b/nemo_retriever/src/nemo_retriever/tabular_data/ingestion/write_to_graph.py @@ -38,23 +38,23 @@ def populate_tabular_data(data, num_workers, dialect): if tables_df is None or tables_df.empty: logger.warning("No tables found in source database; skipping graph population.") - return + return {} database = data["database_name"] logger.info(f"Started parsing db {database}.") - all_schemas = {} all_schemas = populate_db(tables_df, columns_df, database, num_workers) if "fks" in data: populate_fks(fks=data["fks"], database_name=database) + if "pks" in data: populate_pks(pks=data["pks"], database_name=database) if "queries" in data: populate_queries(all_schemas, data["queries"], num_workers, dialect) - return [] + return all_schemas def populate_db(tables_df, columns_df, database, num_workers): diff --git a/nemo_retriever/src/nemo_retriever/vdb/__init__.py b/nemo_retriever/src/nemo_retriever/vdb/__init__.py index 2c03d7a669..cd4c4f3548 100644 --- a/nemo_retriever/src/nemo_retriever/vdb/__init__.py +++ b/nemo_retriever/src/nemo_retriever/vdb/__init__.py @@ -6,7 +6,7 @@ from nemo_retriever.vdb.adt_vdb import VDB from nemo_retriever.vdb.factory import get_vdb_op_cls -from nemo_retriever.vdb.operators import IngestVdbOperator, RetrieveVdbOperator +from nemo_retriever.vdb.operators import IngestVdbOperator, PutVdbOperator, RetrieveVdbOperator from nemo_retriever.vdb.records import RetrievalHit, normalize_retrieval_results, to_client_vdb_records from nemo_retriever.vdb.sidecar_metadata import ( apply_sidecar_metadata_to_client_batches, @@ -22,6 +22,7 @@ "VDB", "get_vdb_op_cls", "IngestVdbOperator", + "PutVdbOperator", "RetrievalHit", "RetrieveVdbOperator", "normalize_retrieval_results", diff --git a/nemo_retriever/src/nemo_retriever/vdb/adt_vdb.py b/nemo_retriever/src/nemo_retriever/vdb/adt_vdb.py index 553ffd027d..d99378638d 100644 --- a/nemo_retriever/src/nemo_retriever/vdb/adt_vdb.py +++ b/nemo_retriever/src/nemo_retriever/vdb/adt_vdb.py @@ -22,6 +22,7 @@ """ from abc import ABC, abstractmethod +from typing import Any class VDB(ABC): @@ -133,6 +134,89 @@ def retrieval(self, queries: list, **kwargs): """ pass + def put(self, records: list, **kwargs: Any) -> dict[str, Any]: + """Replace a batch of existing rows in the target table/index. + + Note: this method is intentionally **not** decorated with + :func:`abc.abstractmethod`. Marking it abstract would cause + Python's ABC machinery to refuse instantiation of any concrete + :class:`VDB` subclass that does not override ``put`` — which + would in turn make the early-detection guard in + :class:`~nemo_retriever.vdb.operators.PutVdbOperator` (which + compares ``type(self._vdb).put is VDB.put``) permanently + unreachable, since instantiation would already have failed. + The default body below raises :class:`NotImplementedError` so + backends that have not implemented stable-key puts fail fast + and visibly at the first ``put`` call (and are caught by the + operator-level guard at construction time). + + ``put`` exists as a separate entry point from + :meth:`write_to_index` because it has fundamentally different + semantics. Where ``write_to_index`` is an *append* (or full + ingest) operation, ``put`` is a **strict in-place replace**: + + * Rows whose key value already exists in the target table are + **updated in place** — all stored columns (including the dense + vector) are replaced with the values from ``records``. + * Rows whose key value does not exist in the target table + MUST raise :class:`KeyError`. ``put`` MUST NOT insert new + rows; ingestion of new rows belongs in + :meth:`write_to_index` / :meth:`run`. + * Rows whose key value is empty or ``None`` MUST raise + :class:`KeyError` — a put has no stable identity to target + without a key. + * Rows that already exist in the target but are *not* referenced + by ``records`` are **left untouched**. ``put`` MUST NOT + delete rows. + + This contract makes ``put`` suitable for in-place metadata + patches where the caller knows the exact set of existing rows + it wants to change and would rather fail loudly than silently + no-op or duplicate data. + + Implementations are expected to: + + * Validate / transform records the same way :meth:`write_to_index` + does (e.g. enforce the embedding dimension, apply the + ``on_bad_vectors`` policy), so that a put row is + indistinguishable from one written via the full-ingest path. + * Raise :class:`FileNotFoundError` (or an equivalent) when the + target table does not yet exist. ``put`` MUST NOT create + tables on the fly. + * Avoid building heavy secondary structures (e.g. IVF/HNSW + vector indexes, FTS indexes) on the put path: incremental + batches are typically too small to train such indexes + meaningfully. Defer index builds to the next full + :meth:`write_to_index` / :meth:`create_index` call. + + Parameters: + - records (list): NV-Ingest-shaped batches (typically a list of + lists of record dicts) to put into the target. The shape + mirrors what :meth:`write_to_index` accepts. + - table_name (str, optional): override the operator's configured + target table/index name for this call. When ``None``, the + implementation should use its default target. + - key (str, optional): name of the column used as the stable + put key. Defaults to ``"id"``. Rows missing this column + (or with an empty value) MUST raise :class:`KeyError`. + + Returns: + - implementation-specific result describing what happened + (typical fields include the number of rows put). + Concrete implementations should document the exact return + shape. + + Backends that genuinely cannot support stable-key puts should + override this method and raise :class:`NotImplementedError` + explicitly so that :class:`PutVdbOperator` (and any other + caller) fails fast with a clear message instead of silently + no-oping or duplicating rows. + """ + raise NotImplementedError( + f"{type(self).__name__} does not implement put(); " + "in-place stable-key puts are not supported by this VDB backend." + ) + @abstractmethod def run(self, records): """Pipeline entry point: ensure the index exists, then ingest. diff --git a/nemo_retriever/src/nemo_retriever/vdb/lancedb.py b/nemo_retriever/src/nemo_retriever/vdb/lancedb.py index 1586597bf6..f43a3a1d55 100644 --- a/nemo_retriever/src/nemo_retriever/vdb/lancedb.py +++ b/nemo_retriever/src/nemo_retriever/vdb/lancedb.py @@ -12,6 +12,7 @@ import lancedb import pyarrow as pa +import pyarrow.compute as pc from nemo_retriever.vdb.adt_vdb import VDB @@ -119,6 +120,7 @@ def _lancedb_arrow_schema(vector_dim: int) -> pa.Schema: pa.field("text", pa.string()), pa.field("metadata", pa.string()), pa.field("source", pa.string()), + pa.field("id", pa.string()), ] ) @@ -275,12 +277,18 @@ def _create_lancedb_results( logger.debug(f"No text found for entity: {source_name} page: {pg_num} type: {doc_type}") continue + row_id = content_meta.get("id") if isinstance(content_meta, dict) else None + if row_id is None and isinstance(metadata, dict): + row_id = metadata.get("id") + row_id_str = str(row_id) if row_id is not None else "" + lancedb_rows.append( { "vector": embedding, "text": text, "metadata": _json_str(content_meta), "source": _json_str(metadata.get("source_metadata", {})), + "id": row_id_str, } ) accepted += 1 @@ -529,6 +537,95 @@ def run(self, records): logger.info("Skipping LanceDB index creation for table %r because build_index=False.", self.table_name) return records + def put( + self, + records, + table_name: str | None = None, + key: str = "id", + ) -> dict[str, int]: + """Replace existing rows of a LanceDB table in place, keyed by ``key``. + + Strict update-only semantics: + + * Rows matching an existing row by ``key`` are **updated in place** + (all columns, including ``vector``, are replaced). + * Rows whose ``key`` value is missing/empty raise :class:`KeyError` + — a put operation has no stable identity to target without a key. + * Rows whose ``key`` value does not match any row currently in the + table raise :class:`KeyError` — ``put`` never inserts new rows. + * Rows already in the table that are *not* referenced are **left + untouched** — ``put`` never deletes. + + If the target table does not exist, :class:`FileNotFoundError` is + raised; ``put`` will not create tables on the fly. + + Vector / FTS indexes are intentionally **not** rebuilt here: + incremental puts typically carry only a handful of rows. Indexes + will be (re)built by the next full :meth:`run` / + :meth:`write_to_index` call. + + Returns the row counts dict from :func:`_create_lancedb_results` + plus: ``put``. + """ + target_name = table_name or self.table_name + connect_start = time.perf_counter() + db = lancedb.connect(uri=self.uri) + _record_timing("lancedb.connect", time.perf_counter() - connect_start) + + if self.validate_vector_length and self.on_bad_vectors != "error": + expected_dim: int | None = self.vector_dim + else: + expected_dim = None + + rows, counts = _create_lancedb_results(records or [], expected_dim=expected_dim) + counts["put"] = 0 + + if not rows: + logger.info("LanceDB.put: nothing to put into table %r.", target_name) + return counts + + rows_missing_key = [r for r in rows if not r.get(key)] + if rows_missing_key: + raise KeyError( + f"LanceDB.put: {len(rows_missing_key)} row(s) have an empty {key!r} value; " + "put() requires a stable id for every row." + ) + + try: + table = db.open_table(target_name) + except (ValueError, FileNotFoundError) as exc: + if isinstance(exc, ValueError) and not _is_missing_lancedb_table_error(exc): + raise + raise FileNotFoundError( + f"LanceDB.put: table {target_name!r} not found at uri={self.uri!r}; " + "put() only updates existing rows and will not create tables." + ) from exc + + input_ids = [r[key] for r in rows] + unique_input_ids = list(dict.fromkeys(input_ids)) + + filter_expr = pc.field(key).isin(pa.array(unique_input_ids, type=pa.string())) + existing_arrow = table.to_lance().to_table(columns=[key], filter=filter_expr) + existing_ids = set(existing_arrow.column(key).to_pylist()) + + missing_ids = [i for i in unique_input_ids if i not in existing_ids] + if missing_ids: + raise KeyError( + f"LanceDB.put: row(s) with {key}={missing_ids!r} not found in table " + f"{target_name!r}; put() only updates existing rows." + ) + + put_start = time.perf_counter() + table.merge_insert(key).when_matched_update_all().execute(rows) + _record_timing( + "lancedb.put", + time.perf_counter() - put_start, + {"rows": len(rows), "table": target_name}, + ) + + counts["put"] = len(rows) + return counts + def retrieval(self, vectors, **kwargs): """Search LanceDB with precomputed query vectors. diff --git a/nemo_retriever/src/nemo_retriever/vdb/operators.py b/nemo_retriever/src/nemo_retriever/vdb/operators.py index 4b284d78cf..807ae7736b 100644 --- a/nemo_retriever/src/nemo_retriever/vdb/operators.py +++ b/nemo_retriever/src/nemo_retriever/vdb/operators.py @@ -134,6 +134,60 @@ def postprocess(self, data: Any, **kwargs: Any) -> Any: return data +class PutVdbOperator(IngestVdbOperator): + """Replace existing rows of a VDB table in place on a stable row key. + + Unlike :class:`IngestVdbOperator` (which orchestrates create_index + + write_to_index, optionally overwriting the whole table), this operator + calls ``vdb.put(records, ...)`` so that only rows whose ``key`` is in + ``records`` are touched. Existing rows that match by ``key`` are + replaced; rows in ``records`` whose ``key`` is not already present in + the table raise :class:`KeyError` (``put`` never inserts new rows), + and rows in the table that are not referenced are left untouched. + + The underlying VDB implementation must override + :meth:`~nemo_retriever.vdb.adt_vdb.VDB.put` with a real + stable-key in-place replace; currently this is implemented by + :class:`~nemo_retriever.vdb.lancedb.LanceDB`. ``VDB.put`` itself + raises :class:`NotImplementedError`, so backends that have not + overridden it are detected at construction time and fail fast rather + than silently no-oping at runtime. + """ + + def __init__( + self, + *, + vdb: VDB | None = None, + vdb_op: str | None = None, + vdb_kwargs: dict[str, Any] | None = None, + key: str = "id", + table_name: str | None = None, + ) -> None: + super().__init__(vdb=vdb, vdb_op=vdb_op, vdb_kwargs=vdb_kwargs) + # ``put`` is part of the abstract VDB contract, but the base + # class provides a NotImplementedError stub for backends that + # cannot support stable-key puts. Treat a not-overridden stub + # as "unsupported" so misuse surfaces here instead of at the + # first write. + if getattr(type(self._vdb), "put", None) is VDB.put: + raise NotImplementedError(f"VDB backend {type(self._vdb).__name__!r} does not implement put(); ") + self._key = key + self._table_name = table_name + + def process(self, data: Any, **kwargs: Any) -> Any: + records = to_client_vdb_records(data) + if self._sidecar_spec is not None and self._sidecar_lookup is not None: + records = apply_sidecar_metadata_to_client_batches( + records, + lookup=self._sidecar_lookup, + meta_fields=self._sidecar_spec["meta_fields"], + join_key=self._sidecar_spec["meta_join_key"], + ) + if records and any(batch for batch in records): + self._vdb.put(records, table_name=self._table_name, key=self._key) + return data + + class RetrieveVdbOperator(AbstractOperator): """Retrieve hits from an nv-ingest-client VDB using precomputed query vectors.""" diff --git a/nemo_retriever/tests/test_nv_ingest_vdb_operator.py b/nemo_retriever/tests/test_nv_ingest_vdb_operator.py index eb79bde282..70c1f5bf43 100644 --- a/nemo_retriever/tests/test_nv_ingest_vdb_operator.py +++ b/nemo_retriever/tests/test_nv_ingest_vdb_operator.py @@ -6,11 +6,13 @@ from typing import Any +import pandas as pd import pytest from nemo_retriever.vdb.adt_vdb import VDB from nemo_retriever.vdb import IngestVdbOperator, RetrieveVdbOperator from nemo_retriever.vdb import operators as vdb_operator_module +from nemo_retriever.vdb.operators import PutVdbOperator class FakeVDB(VDB): @@ -18,6 +20,7 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.run_calls: list[Any] = [] self.retrieval_calls: list[tuple[Any, dict[str, Any]]] = [] + self.put_calls: list[tuple[Any, dict[str, Any]]] = [] def create_index(self, **kwargs: Any) -> None: return None @@ -47,6 +50,10 @@ def run(self, records: Any) -> dict[str, Any]: self.run_calls.append(records) return {"records": records} + def put(self, records: list, **kwargs: Any) -> dict[str, Any]: + self.put_calls.append((records, dict(kwargs))) + return {"put": sum(len(b) for b in records)} + def _graph_rows() -> list[dict[str, Any]]: return [ @@ -167,3 +174,131 @@ def test_constructor_requires_exactly_one_vdb_source() -> None: with pytest.raises(ValueError, match="Pass either vdb or vdb_op"): IngestVdbOperator(vdb=FakeVDB(), vdb_op="lancedb") + + +# ────────────────────────────────────────────────────────────────────────────── +# PutVdbOperator +# ────────────────────────────────────────────────────────────────────────────── + + +class _StubPutVDB(VDB): + """VDB subclass that intentionally does NOT override ``put``. + + Used to exercise the construction-time guard in + :class:`PutVdbOperator.__init__`, which compares + ``type(self._vdb).put is VDB.put`` to detect backends that + inherit the base-class ``NotImplementedError`` stub. + + Note: this class being instantiable at all is itself a regression + check — :meth:`VDB.put` must NOT be decorated with + ``@abstractmethod``; otherwise ABC machinery would reject this class + before the operator-level guard could run, making the guard dead code. + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + def create_index(self, **kwargs: Any) -> None: + return None + + def write_to_index(self, records: list, **kwargs: Any) -> None: + return None + + def retrieval(self, queries: list, **kwargs: Any) -> list[list[dict[str, Any]]]: + return [] + + def run(self, records: Any) -> None: + return None + + +def test_put_operator_rejects_vdb_without_put_override() -> None: + """Backends inheriting the ``VDB.put`` stub fail fast at construction.""" + stub = _StubPutVDB() + # Sanity-check the precondition the guard relies on: the subclass really + # is using the inherited stub, not its own implementation. If this ever + # fails, the guard's identity comparison would silently never fire. + assert type(stub).put is VDB.put + + with pytest.raises(NotImplementedError, match=r"does not implement put"): + PutVdbOperator(vdb=stub) + + +def test_put_operator_delegates_records_with_configured_key_and_table_name() -> None: + """Happy path: nv-ingest-converted records reach ``vdb.put`` with the configured key/table.""" + vdb = FakeVDB() + operator = PutVdbOperator(vdb=vdb, key="entity_id", table_name="entities") + + data = [ + { + "text": "graph chunk", + "text_embeddings_1b_v2": {"embedding": [0.1] * 2048}, + "source_id": "/tmp/doc-a.pdf", + "page_number": 7, + } + ] + + assert operator(data) is data + + assert vdb.run_calls == [] + assert len(vdb.put_calls) == 1 + call_records, call_kwargs = vdb.put_calls[0] + assert call_kwargs == {"table_name": "entities", "key": "entity_id"} + # The records that reach the backend must already be in nv-ingest-client + # shape (same conversion IngestVdbOperator performs), not the flat graph rows. + assert call_records == [ + [ + { + "document_type": "text", + "metadata": { + "embedding": [0.1] * 2048, + "content": "graph chunk", + "content_metadata": {"page_number": 7}, + "source_metadata": { + "source_id": "/tmp/doc-a.pdf", + "source_name": "doc-a.pdf", + }, + }, + } + ] + ] + + +def test_put_operator_merges_sidecar_metadata_into_records_before_put() -> None: + """Sidecar kwargs are split out from ``vdb_kwargs`` and applied before delegation.""" + vdb = FakeVDB() + meta_df = pd.DataFrame( + { + "source_id": ["/tmp/doc-a.pdf"], + "category": ["legal"], + } + ) + operator = PutVdbOperator( + vdb=vdb, + vdb_kwargs={ + "meta_dataframe": meta_df, + "meta_source_field": "source_id", + "meta_fields": ["category"], + "meta_join_key": "source_id", + }, + key="id", + table_name="my_table", + ) + + data = [ + { + "text": "graph chunk", + "text_embeddings_1b_v2": {"embedding": [0.1] * 2048}, + "source_id": "/tmp/doc-a.pdf", + "page_number": 7, + } + ] + + assert operator.process(data) is data + + assert len(vdb.put_calls) == 1 + call_records, call_kwargs = vdb.put_calls[0] + assert call_kwargs == {"table_name": "my_table", "key": "id"} + merged_content_meta = call_records[0][0]["metadata"]["content_metadata"] + # Sidecar column merged in alongside the per-row ``page_number``. + assert merged_content_meta["category"] == "legal" + assert merged_content_meta["page_number"] == 7