diff --git a/nemo_retriever/src/nemo_retriever/retriever.py b/nemo_retriever/src/nemo_retriever/retriever.py index efafc0932..1d90293e6 100644 --- a/nemo_retriever/src/nemo_retriever/retriever.py +++ b/nemo_retriever/src/nemo_retriever/retriever.py @@ -179,6 +179,9 @@ def _execute_queries_graph( text_col = str(embed_params.text_column) df = pd.DataFrame({text_col: query_texts}) + # Hybrid retrieval relies on these ordered query strings staying aligned + # with the embedded rows produced from ``df``. If this query graph grows + # distributed/shuffled stages, carry row-local query text or IDs instead. graph = self._get_graph(embed_extra=embed_extra) if not callable(getattr(graph, "resolve_for_local_execution", None)): raise TypeError("graph must provide resolve_for_local_execution() (e.g. pipeline_graph.Graph)") diff --git a/nemo_retriever/src/nemo_retriever/vdb/lancedb.py b/nemo_retriever/src/nemo_retriever/vdb/lancedb.py index 1586597bf..1cc49d2f1 100644 --- a/nemo_retriever/src/nemo_retriever/vdb/lancedb.py +++ b/nemo_retriever/src/nemo_retriever/vdb/lancedb.py @@ -7,6 +7,7 @@ import os import time +from collections.abc import Iterable, Sequence from datetime import timedelta from typing import Any, Final, FrozenSet @@ -529,7 +530,7 @@ def run(self, records): logger.info("Skipping LanceDB index creation for table %r because build_index=False.", self.table_name) return records - def retrieval(self, vectors, **kwargs): + def retrieval(self, vectors: Iterable[Sequence[float]], **kwargs: Any) -> list[list[dict[str, Any]]]: """Search LanceDB with precomputed query vectors. Keyword arguments @@ -546,10 +547,12 @@ def retrieval(self, vectors, **kwargs): ``table.search`` (e.g. ``query_type``, ``fts_columns``). Do not pass ``vector_column_name`` here; use the top-level ``vector_column_name`` retrieval argument instead. + query_texts: + Raw query strings aligned with ``vectors``. Required for + ``hybrid=True`` and ignored for dense-only retrieval. """ hybrid = kwargs.pop("hybrid", self.hybrid) - if hybrid: - raise NotImplementedError("LanceDB hybrid retrieval with precomputed vectors is not implemented yet.") + query_texts = kwargs.pop("query_texts", None) table_path = kwargs.pop("table_path", self.uri) table_name = kwargs.pop("table_name", self.table_name) @@ -567,6 +570,23 @@ def retrieval(self, vectors, **kwargs): else: search_kwargs = dict(search_kwargs_raw) + if hybrid: + if query_texts is None: + raise ValueError( + "LanceDB hybrid retrieval requires query_texts. Pass query_texts=your_queries " + "alongside vectors when calling retrieval() with hybrid=True." + ) + query_type = search_kwargs.get("query_type") + if query_type is not None: + query_type_value = getattr(query_type, "value", query_type) + if str(query_type_value).lower() != "hybrid": + raise ValueError( + "LanceDB hybrid retrieval requires search_kwargs['query_type']='hybrid'; " + f"got {query_type!r}." + ) + search_kwargs["query_type"] = "hybrid" + search_kwargs.setdefault("fts_columns", "text") + where_clause = kwargs.pop("where", None) _filter_fallback = kwargs.pop("_filter", None) if where_clause is None: @@ -576,9 +596,28 @@ def retrieval(self, vectors, **kwargs): table = lancedb.connect(uri=table_path).open_table(table_name) + if hybrid: + vectors_for_search = list(vectors) + query_texts_list = [query_texts] if isinstance(query_texts, str) else list(query_texts) + if len(query_texts_list) != len(vectors_for_search): + raise ValueError( + "LanceDB hybrid retrieval requires query_texts length to match vectors length; " + f"got query_texts={len(query_texts_list)} vectors={len(vectors_for_search)}." + ) + else: + vectors_for_search = vectors + query_texts_list = [] + search_results = [] - for vector in vectors: - query = table.search([vector], vector_column_name=vector_column_name, **search_kwargs) + for idx, vector in enumerate(vectors_for_search): + if hybrid: + query = ( + table.search(vector_column_name=vector_column_name, **search_kwargs) + .vector(vector) + .text(str(query_texts_list[idx])) + ) + else: + query = table.search([vector], vector_column_name=vector_column_name, **search_kwargs) if where_clause is not None: query = query.where(where_clause) query = query.limit(top_k).refine_factor(refine_factor).nprobes(n_probe) diff --git a/nemo_retriever/src/nemo_retriever/vdb/operators.py b/nemo_retriever/src/nemo_retriever/vdb/operators.py index 4b284d78c..bece8253c 100644 --- a/nemo_retriever/src/nemo_retriever/vdb/operators.py +++ b/nemo_retriever/src/nemo_retriever/vdb/operators.py @@ -147,6 +147,7 @@ def __init__( ) -> None: merged = dict(vdb_kwargs or {}) clean_kwargs, _sidecar = split_sidecar_from_vdb_kwargs(merged) + clean_kwargs.pop("query_texts", None) super().__init__(vdb=vdb, vdb_op=vdb_op, vdb_kwargs=clean_kwargs, explode_for_rerank=explode_for_rerank) self._vdb_kwargs = clean_kwargs self._retrieval_vdb_kwargs = clean_kwargs @@ -162,6 +163,8 @@ def process(self, data: Any, **kwargs: Any) -> list[list[dict[str, Any]]]: from nemo_retriever.retriever_graph_utils import filter_retrieval_kwargs retrieval_kwargs = {**self._retrieval_vdb_kwargs, **filter_retrieval_kwargs(kwargs)} + if retrieval_kwargs.get("hybrid") and "query_texts" in kwargs: + retrieval_kwargs["query_texts"] = kwargs["query_texts"] return normalize_retrieval_results(self._vdb.retrieval(data, **retrieval_kwargs)) def postprocess(self, data: Any, **kwargs: Any) -> Any: diff --git a/nemo_retriever/tests/test_lancedb_retrieval_where.py b/nemo_retriever/tests/test_lancedb_retrieval_where.py index ffdb534f2..a0fcc7d89 100644 --- a/nemo_retriever/tests/test_lancedb_retrieval_where.py +++ b/nemo_retriever/tests/test_lancedb_retrieval_where.py @@ -17,7 +17,7 @@ from nemo_retriever.vdb.lancedb import LanceDB -def _tiny_table(uri: str) -> None: +def _tiny_table(uri: str, *, create_fts_index: bool = False) -> None: schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), 2)), @@ -41,7 +41,9 @@ def _tiny_table(uri: str) -> None: }, ] db = lancedb.connect(uri) - db.create_table("t", rows, schema=schema, mode="overwrite") + table = db.create_table("t", rows, schema=schema, mode="overwrite") + if create_fts_index: + table.create_fts_index("text", replace=True) def test_retrieval_where_filters_rows() -> None: @@ -101,3 +103,82 @@ def test_retrieval_search_kwargs_must_be_dict() -> None: op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) with pytest.raises(TypeError, match="search_kwargs"): op.retrieval([[1.0, 0.0]], top_k=5, table_path=d, table_name="t", search_kwargs="bad") + + +def test_hybrid_retrieval_uses_query_texts() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + results = op.retrieval( + [[1.0, 0.0]], + top_k=2, + table_path=d, + table_name="t", + hybrid=True, + query_texts=["alpha"], + ) + + assert results[0] + assert results[0][0]["text"] == "alpha" + + +def test_hybrid_retrieval_requires_query_texts() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + with pytest.raises(ValueError, match="requires query_texts"): + op.retrieval([[1.0, 0.0]], top_k=2, table_path=d, table_name="t", hybrid=True) + + +def test_hybrid_retrieval_requires_query_texts_aligned_with_vectors() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + with pytest.raises(ValueError, match="length to match vectors length"): + op.retrieval( + [[1.0, 0.0]], + top_k=2, + table_path=d, + table_name="t", + hybrid=True, + query_texts=["alpha", "beta"], + ) + + +def test_hybrid_retrieval_where_filters_rows() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + filtered = op.retrieval( + [[1.0, 0.0]], + top_k=10, + table_path=d, + table_name="t", + hybrid=True, + query_texts=["beta"], + where="text = 'beta'", + ) + + assert len(filtered[0]) == 1 + assert filtered[0][0]["text"] == "beta" + + +def test_hybrid_retrieval_rejects_non_hybrid_query_type() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + with pytest.raises(ValueError, match="query_type"): + op.retrieval( + [[1.0, 0.0]], + top_k=2, + table_path=d, + table_name="t", + hybrid=True, + query_texts=["alpha"], + search_kwargs={"query_type": "vector"}, + ) diff --git a/nemo_retriever/tests/test_nv_ingest_vdb_operator.py b/nemo_retriever/tests/test_nv_ingest_vdb_operator.py index eb79bde28..db1b2d625 100644 --- a/nemo_retriever/tests/test_nv_ingest_vdb_operator.py +++ b/nemo_retriever/tests/test_nv_ingest_vdb_operator.py @@ -161,6 +161,38 @@ def test_retrieve_operator_delegates_vectors_to_retrieval() -> None: assert vdb.retrieval_calls == [([[0.1, 0.2]], {"collection_name": "docs", "model_name": "embedder", "top_k": 3})] +def test_retrieve_operator_forwards_runtime_query_texts() -> None: + vdb = FakeVDB() + operator = RetrieveVdbOperator( + vdb=vdb, + vdb_kwargs={"collection_name": "docs", "model_name": "embedder", "hybrid": True, "query_texts": ["stale"]}, + ) + + operator.process([[0.1, 0.2]], top_k=3, query_texts=["current"]) + + assert vdb.retrieval_calls == [ + ( + [[0.1, 0.2]], + { + "collection_name": "docs", + "model_name": "embedder", + "hybrid": True, + "top_k": 3, + "query_texts": ["current"], + }, + ) + ] + + +def test_retrieve_operator_does_not_forward_query_texts_for_dense_retrieval() -> None: + vdb = FakeVDB() + operator = RetrieveVdbOperator(vdb=vdb, vdb_kwargs={"collection_name": "docs", "model_name": "embedder"}) + + operator.process([[0.1, 0.2]], top_k=3, query_texts=["current"]) + + assert vdb.retrieval_calls == [([[0.1, 0.2]], {"collection_name": "docs", "model_name": "embedder", "top_k": 3})] + + def test_constructor_requires_exactly_one_vdb_source() -> None: with pytest.raises(ValueError, match="Either vdb or vdb_op is required"): IngestVdbOperator()