From 6410d755299a8e8882f2738a1971589ada2168bd Mon Sep 17 00:00:00 2001 From: jioffe502 Date: Thu, 14 May 2026 18:30:41 +0000 Subject: [PATCH 1/3] Implement LanceDB hybrid retrieval --- .../src/nemo_retriever/vdb/lancedb.py | 45 +++++++++- .../src/nemo_retriever/vdb/operators.py | 3 + .../tests/test_lancedb_retrieval_where.py | 85 ++++++++++++++++++- .../tests/test_nv_ingest_vdb_operator.py | 32 +++++++ 4 files changed, 159 insertions(+), 6 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/vdb/lancedb.py b/nemo_retriever/src/nemo_retriever/vdb/lancedb.py index 74f6b509d4..5cc3c839d0 100644 --- a/nemo_retriever/src/nemo_retriever/vdb/lancedb.py +++ b/nemo_retriever/src/nemo_retriever/vdb/lancedb.py @@ -546,10 +546,12 @@ def retrieval(self, vectors, **kwargs): ``table.search`` (e.g. ``query_type``, ``fts_columns``). Do not pass ``vector_column_name`` here; use the top-level ``vector_column_name`` retrieval argument instead. + query_texts: + Raw query strings aligned with ``vectors``. Required for + ``hybrid=True`` and ignored for dense-only retrieval. """ hybrid = kwargs.pop("hybrid", self.hybrid) - if hybrid: - raise NotImplementedError("LanceDB hybrid retrieval with precomputed vectors is not implemented yet.") + query_texts = kwargs.pop("query_texts", None) table_path = kwargs.pop("table_path", self.uri) table_name = kwargs.pop("table_name", self.table_name) @@ -567,6 +569,23 @@ def retrieval(self, vectors, **kwargs): else: search_kwargs = dict(search_kwargs_raw) + if hybrid: + if query_texts is None: + raise ValueError( + "LanceDB hybrid retrieval requires query_texts because it needs raw query text " + "in addition to precomputed vectors." + ) + query_type = search_kwargs.get("query_type") + if query_type is not None: + query_type_value = getattr(query_type, "value", query_type) + if str(query_type_value).lower() != "hybrid": + raise ValueError( + "LanceDB hybrid retrieval requires search_kwargs['query_type']='hybrid'; " + f"got {query_type!r}." + ) + search_kwargs["query_type"] = "hybrid" + search_kwargs.setdefault("fts_columns", "text") + where_clause = kwargs.pop("where", None) _filter_fallback = kwargs.pop("_filter", None) if where_clause is None: @@ -576,9 +595,27 @@ def retrieval(self, vectors, **kwargs): table = lancedb.connect(uri=table_path).open_table(table_name) + vectors_list = list(vectors) + if hybrid: + query_texts_list = [query_texts] if isinstance(query_texts, str) else list(query_texts) + if len(query_texts_list) != len(vectors_list): + raise ValueError( + "LanceDB hybrid retrieval requires query_texts length to match vectors length; " + f"got query_texts={len(query_texts_list)} vectors={len(vectors_list)}." + ) + else: + query_texts_list = [] + search_results = [] - for vector in vectors: - query = table.search([vector], vector_column_name=vector_column_name, **search_kwargs) + for idx, vector in enumerate(vectors_list): + if hybrid: + query = ( + table.search(vector_column_name=vector_column_name, **search_kwargs) + .vector(vector) + .text(str(query_texts_list[idx])) + ) + else: + query = table.search([vector], vector_column_name=vector_column_name, **search_kwargs) if where_clause is not None: query = query.where(where_clause) query = query.limit(top_k).refine_factor(refine_factor).nprobes(n_probe) diff --git a/nemo_retriever/src/nemo_retriever/vdb/operators.py b/nemo_retriever/src/nemo_retriever/vdb/operators.py index caf5c87710..07fe88baed 100644 --- a/nemo_retriever/src/nemo_retriever/vdb/operators.py +++ b/nemo_retriever/src/nemo_retriever/vdb/operators.py @@ -147,6 +147,7 @@ def __init__( ) -> None: merged = dict(vdb_kwargs or {}) clean_kwargs, _sidecar = split_sidecar_from_vdb_kwargs(merged) + clean_kwargs.pop("query_texts", None) super().__init__(vdb=vdb, vdb_op=vdb_op, vdb_kwargs=clean_kwargs, explode_for_rerank=explode_for_rerank) self._vdb_kwargs = clean_kwargs self._retrieval_vdb_kwargs = clean_kwargs @@ -162,6 +163,8 @@ def process(self, data: Any, **kwargs: Any) -> list[list[dict[str, Any]]]: from nemo_retriever.retriever_graph_utils import filter_retrieval_kwargs retrieval_kwargs = {**self._retrieval_vdb_kwargs, **filter_retrieval_kwargs(kwargs)} + if retrieval_kwargs.get("hybrid") and "query_texts" in kwargs: + retrieval_kwargs["query_texts"] = kwargs["query_texts"] return normalize_retrieval_results(self._vdb.retrieval(data, **retrieval_kwargs)) def postprocess(self, data: Any, **kwargs: Any) -> Any: diff --git a/nemo_retriever/tests/test_lancedb_retrieval_where.py b/nemo_retriever/tests/test_lancedb_retrieval_where.py index ffdb534f2f..a0fcc7d89e 100644 --- a/nemo_retriever/tests/test_lancedb_retrieval_where.py +++ b/nemo_retriever/tests/test_lancedb_retrieval_where.py @@ -17,7 +17,7 @@ from nemo_retriever.vdb.lancedb import LanceDB -def _tiny_table(uri: str) -> None: +def _tiny_table(uri: str, *, create_fts_index: bool = False) -> None: schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), 2)), @@ -41,7 +41,9 @@ def _tiny_table(uri: str) -> None: }, ] db = lancedb.connect(uri) - db.create_table("t", rows, schema=schema, mode="overwrite") + table = db.create_table("t", rows, schema=schema, mode="overwrite") + if create_fts_index: + table.create_fts_index("text", replace=True) def test_retrieval_where_filters_rows() -> None: @@ -101,3 +103,82 @@ def test_retrieval_search_kwargs_must_be_dict() -> None: op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) with pytest.raises(TypeError, match="search_kwargs"): op.retrieval([[1.0, 0.0]], top_k=5, table_path=d, table_name="t", search_kwargs="bad") + + +def test_hybrid_retrieval_uses_query_texts() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + results = op.retrieval( + [[1.0, 0.0]], + top_k=2, + table_path=d, + table_name="t", + hybrid=True, + query_texts=["alpha"], + ) + + assert results[0] + assert results[0][0]["text"] == "alpha" + + +def test_hybrid_retrieval_requires_query_texts() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + with pytest.raises(ValueError, match="requires query_texts"): + op.retrieval([[1.0, 0.0]], top_k=2, table_path=d, table_name="t", hybrid=True) + + +def test_hybrid_retrieval_requires_query_texts_aligned_with_vectors() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + with pytest.raises(ValueError, match="length to match vectors length"): + op.retrieval( + [[1.0, 0.0]], + top_k=2, + table_path=d, + table_name="t", + hybrid=True, + query_texts=["alpha", "beta"], + ) + + +def test_hybrid_retrieval_where_filters_rows() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + filtered = op.retrieval( + [[1.0, 0.0]], + top_k=10, + table_path=d, + table_name="t", + hybrid=True, + query_texts=["beta"], + where="text = 'beta'", + ) + + assert len(filtered[0]) == 1 + assert filtered[0][0]["text"] == "beta" + + +def test_hybrid_retrieval_rejects_non_hybrid_query_type() -> None: + d = tempfile.mkdtemp() + _tiny_table(d, create_fts_index=True) + op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False) + + with pytest.raises(ValueError, match="query_type"): + op.retrieval( + [[1.0, 0.0]], + top_k=2, + table_path=d, + table_name="t", + hybrid=True, + query_texts=["alpha"], + search_kwargs={"query_type": "vector"}, + ) diff --git a/nemo_retriever/tests/test_nv_ingest_vdb_operator.py b/nemo_retriever/tests/test_nv_ingest_vdb_operator.py index 8213405eea..ee882d9fa0 100644 --- a/nemo_retriever/tests/test_nv_ingest_vdb_operator.py +++ b/nemo_retriever/tests/test_nv_ingest_vdb_operator.py @@ -161,6 +161,38 @@ def test_retrieve_operator_delegates_vectors_to_retrieval() -> None: assert vdb.retrieval_calls == [([[0.1, 0.2]], {"collection_name": "docs", "model_name": "embedder", "top_k": 3})] +def test_retrieve_operator_forwards_runtime_query_texts() -> None: + vdb = FakeVDB() + operator = RetrieveVdbOperator( + vdb=vdb, + vdb_kwargs={"collection_name": "docs", "model_name": "embedder", "hybrid": True, "query_texts": ["stale"]}, + ) + + operator.process([[0.1, 0.2]], top_k=3, query_texts=["current"]) + + assert vdb.retrieval_calls == [ + ( + [[0.1, 0.2]], + { + "collection_name": "docs", + "model_name": "embedder", + "hybrid": True, + "top_k": 3, + "query_texts": ["current"], + }, + ) + ] + + +def test_retrieve_operator_does_not_forward_query_texts_for_dense_retrieval() -> None: + vdb = FakeVDB() + operator = RetrieveVdbOperator(vdb=vdb, vdb_kwargs={"collection_name": "docs", "model_name": "embedder"}) + + operator.process([[0.1, 0.2]], top_k=3, query_texts=["current"]) + + assert vdb.retrieval_calls == [([[0.1, 0.2]], {"collection_name": "docs", "model_name": "embedder", "top_k": 3})] + + def test_constructor_requires_exactly_one_vdb_source() -> None: with pytest.raises(ValueError, match="Either vdb or vdb_op is required"): IngestVdbOperator() From 1bf91683211a94d0f995e6f6f54f6e17cc465f9c Mon Sep 17 00:00:00 2001 From: jioffe502 Date: Thu, 14 May 2026 21:03:41 +0000 Subject: [PATCH 2/3] Document hybrid query text ordering assumption --- nemo_retriever/src/nemo_retriever/retriever.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo_retriever/src/nemo_retriever/retriever.py b/nemo_retriever/src/nemo_retriever/retriever.py index 8f6d36da58..47d991fe14 100644 --- a/nemo_retriever/src/nemo_retriever/retriever.py +++ b/nemo_retriever/src/nemo_retriever/retriever.py @@ -178,6 +178,9 @@ def _execute_queries_graph( text_col = str(embed_params.text_column) df = pd.DataFrame({text_col: query_texts}) + # Hybrid retrieval relies on these ordered query strings staying aligned + # with the embedded rows produced from ``df``. If this query graph grows + # distributed/shuffled stages, carry row-local query text or IDs instead. graph = self._get_graph(embed_extra=embed_extra) if not callable(getattr(graph, "resolve_for_local_execution", None)): raise TypeError("graph must provide resolve_for_local_execution() (e.g. pipeline_graph.Graph)") From d96dcf470b907afd9efd06486b28d74c3b850dc3 Mon Sep 17 00:00:00 2001 From: jioffe502 Date: Mon, 18 May 2026 15:07:59 +0000 Subject: [PATCH 3/3] Address LanceDB hybrid review comments --- nemo_retriever/src/nemo_retriever/vdb/lancedb.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/vdb/lancedb.py b/nemo_retriever/src/nemo_retriever/vdb/lancedb.py index 5cc3c839d0..9087006d3a 100644 --- a/nemo_retriever/src/nemo_retriever/vdb/lancedb.py +++ b/nemo_retriever/src/nemo_retriever/vdb/lancedb.py @@ -7,6 +7,7 @@ import os import time +from collections.abc import Iterable, Sequence from datetime import timedelta from typing import Any, Final, FrozenSet @@ -529,7 +530,7 @@ def run(self, records): logger.info("Skipping LanceDB index creation for table %r because build_index=False.", self.table_name) return records - def retrieval(self, vectors, **kwargs): + def retrieval(self, vectors: Iterable[Sequence[float]], **kwargs: Any) -> list[list[dict[str, Any]]]: """Search LanceDB with precomputed query vectors. Keyword arguments @@ -572,8 +573,8 @@ def retrieval(self, vectors, **kwargs): if hybrid: if query_texts is None: raise ValueError( - "LanceDB hybrid retrieval requires query_texts because it needs raw query text " - "in addition to precomputed vectors." + "LanceDB hybrid retrieval requires query_texts. Pass query_texts=your_queries " + "alongside vectors when calling retrieval() with hybrid=True." ) query_type = search_kwargs.get("query_type") if query_type is not None: @@ -595,19 +596,20 @@ def retrieval(self, vectors, **kwargs): table = lancedb.connect(uri=table_path).open_table(table_name) - vectors_list = list(vectors) if hybrid: + vectors_for_search = list(vectors) query_texts_list = [query_texts] if isinstance(query_texts, str) else list(query_texts) - if len(query_texts_list) != len(vectors_list): + if len(query_texts_list) != len(vectors_for_search): raise ValueError( "LanceDB hybrid retrieval requires query_texts length to match vectors length; " - f"got query_texts={len(query_texts_list)} vectors={len(vectors_list)}." + f"got query_texts={len(query_texts_list)} vectors={len(vectors_for_search)}." ) else: + vectors_for_search = vectors query_texts_list = [] search_results = [] - for idx, vector in enumerate(vectors_list): + for idx, vector in enumerate(vectors_for_search): if hybrid: query = ( table.search(vector_column_name=vector_column_name, **search_kwargs)