diff --git a/nemo_retriever/src/nemo_retriever/rerank/rerank.py b/nemo_retriever/src/nemo_retriever/rerank/rerank.py index a98e1e414..58ecb68bc 100644 --- a/nemo_retriever/src/nemo_retriever/rerank/rerank.py +++ b/nemo_retriever/src/nemo_retriever/rerank/rerank.py @@ -51,6 +51,7 @@ from __future__ import annotations +from collections.abc import Hashable import json import logging import traceback @@ -502,19 +503,50 @@ def _rerank_batch( images_b64 = batch_df[image_column].tolist() if rerank_invoke_url: - # Remote endpoint: score pair-by-pair (each row may have a different query). - scores: List[float] = [] + # Remote endpoint: batch all passages that share a query into one request. + # The long-form DataFrame can contain different queries in the same Ray + # batch, so keep per-row score alignment when expanding grouped responses. + groups: dict[Any, dict[str, Any]] = {} for i, (q, d) in enumerate(pairs): - img = [images_b64[i]] if images_b64 else None + if not isinstance(q, Hashable): + logger.warning( + "Query at row %d is not hashable (%s); it will be sent in its own request " + "and cannot be batched with identical queries.", + i, + type(q).__name__, + ) + key = q if isinstance(q, Hashable) else ("__unhashable_query__", i) + group = groups.setdefault( + key, + { + "query": q, + "indices": [], + "documents": [], + "images_b64": [] if images_b64 is not None else None, + }, + ) + group["indices"].append(i) + group["documents"].append(d) + if images_b64 is not None: + group["images_b64"].append(images_b64[i]) + + scores = [float("-inf")] * len(pairs) + for group in groups.values(): row_scores = _rerank_via_endpoint( - q, - [d], + group["query"], + group["documents"], endpoint=rerank_invoke_url, model_name=model_name, api_key=api_key, - images_b64=img, + images_b64=group["images_b64"], ) - scores.append(row_scores[0]) + if len(row_scores) != len(group["indices"]): + raise RuntimeError( + f"Endpoint returned {len(row_scores)} scores for a batch of " + f"{len(group['indices'])} documents; score alignment is broken." + ) + for row_index, score in zip(group["indices"], row_scores): + scores[row_index] = score elif model is not None: if images_b64 is not None: scores = model.score_pairs(pairs, images_b64=images_b64, max_length=max_length, batch_size=batch_size) diff --git a/nemo_retriever/tests/test_nemotron_rerank_v2.py b/nemo_retriever/tests/test_nemotron_rerank_v2.py index 6901a7c7e..1288681a9 100644 --- a/nemo_retriever/tests/test_nemotron_rerank_v2.py +++ b/nemo_retriever/tests/test_nemotron_rerank_v2.py @@ -572,17 +572,86 @@ def test_actor_call_sorts_descending_by_default(self): mock_resp = MagicMock() mock_resp.raise_for_status = MagicMock() - mock_resp.json.side_effect = [ - {"rankings": [{"index": 0, "logit": 0.1}]}, - {"rankings": [{"index": 0, "logit": 0.9}]}, - ] + mock_resp.json.return_value = { + "rankings": [ + {"index": 0, "logit": 0.1}, + {"index": 1, "logit": 0.9}, + ] + } - with patch("requests.post", return_value=mock_resp): + with patch("requests.post", return_value=mock_resp) as mock_post: out = actor(df) + mock_post.assert_called_once() + payload = mock_post.call_args[1]["json"] + assert payload["query"] == {"text": "q"} + assert payload["passages"] == [{"text": "low relevance"}, {"text": "high relevance"}] scores = out["rerank_score"].tolist() assert scores == sorted(scores, reverse=True) + def test_actor_call_batches_remote_rows_by_query(self): + import pandas as pd + from nemo_retriever.rerank.rerank import NemotronRerankActor + + actor = NemotronRerankActor(rerank_invoke_url="http://localhost:8000", sort_results=False) + df = pd.DataFrame( + { + "query": ["q1", "q1", "q2"], + "text": ["doc A", "doc B", "doc C"], + } + ) + + resp_q1 = MagicMock() + resp_q1.raise_for_status = MagicMock() + resp_q1.json.return_value = { + "rankings": [ + {"index": 1, "logit": 0.7}, + {"index": 0, "logit": 0.2}, + ] + } + resp_q2 = MagicMock() + resp_q2.raise_for_status = MagicMock() + resp_q2.json.return_value = {"rankings": [{"index": 0, "logit": 0.9}]} + + with patch("requests.post", side_effect=[resp_q1, resp_q2]) as mock_post: + out = actor(df) + + assert mock_post.call_count == 2 + assert mock_post.call_args_list[0][1]["json"]["query"] == {"text": "q1"} + assert mock_post.call_args_list[0][1]["json"]["passages"] == [{"text": "doc A"}, {"text": "doc B"}] + assert mock_post.call_args_list[1][1]["json"]["query"] == {"text": "q2"} + assert mock_post.call_args_list[1][1]["json"]["passages"] == [{"text": "doc C"}] + assert out["rerank_score"].tolist() == [0.2, 0.7, 0.9] + + def test_rerank_batch_raises_when_endpoint_score_count_mismatches(self): + import pandas as pd + from nemo_retriever.rerank.rerank import _rerank_batch + + df = pd.DataFrame({"query": ["q", "q"], "text": ["doc A", "doc B"]}) + + with ( + patch("nemo_retriever.rerank.rerank._rerank_via_endpoint", return_value=[0.2]), + pytest.raises(RuntimeError, match="score alignment is broken"), + ): + _rerank_batch(df, rerank_invoke_url="http://localhost:8000", sort_results=False) + + def test_rerank_batch_warns_when_unhashable_queries_cannot_batch(self, caplog): + import logging + + import pandas as pd + from nemo_retriever.rerank.rerank import _rerank_batch + + caplog.set_level(logging.WARNING, logger="nemo_retriever.rerank.rerank") + df = pd.DataFrame({"query": [["q"], ["q"]], "text": ["doc A", "doc B"]}) + + with patch("nemo_retriever.rerank.rerank._rerank_via_endpoint", side_effect=[[0.2], [0.7]]) as mock_rerank: + out = _rerank_batch(df, rerank_invoke_url="http://localhost:8000", sort_results=False) + + assert mock_rerank.call_count == 2 + assert "Query at row 0 is not hashable (list)" in caplog.text + assert "Query at row 1 is not hashable (list)" in caplog.text + assert out["rerank_score"].tolist() == [0.2, 0.7] + def test_actor_call_returns_error_payload_on_exception(self): import pandas as pd from nemo_retriever.rerank.rerank import NemotronRerankActor