Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions nemo_retriever/src/nemo_retriever/rerank/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

from __future__ import annotations

from collections.abc import Hashable
import json
import logging
import traceback
Expand Down Expand Up @@ -502,19 +503,50 @@ def _rerank_batch(
images_b64 = batch_df[image_column].tolist()

if rerank_invoke_url:
# Remote endpoint: score pair-by-pair (each row may have a different query).
scores: List[float] = []
# Remote endpoint: batch all passages that share a query into one request.
# The long-form DataFrame can contain different queries in the same Ray
# batch, so keep per-row score alignment when expanding grouped responses.
groups: dict[Any, dict[str, Any]] = {}
for i, (q, d) in enumerate(pairs):
img = [images_b64[i]] if images_b64 else None
if not isinstance(q, Hashable):
logger.warning(
"Query at row %d is not hashable (%s); it will be sent in its own request "
"and cannot be batched with identical queries.",
i,
type(q).__name__,
)
key = q if isinstance(q, Hashable) else ("__unhashable_query__", i)
Comment thread
ChrisJar marked this conversation as resolved.
group = groups.setdefault(
key,
{
"query": q,
"indices": [],
"documents": [],
"images_b64": [] if images_b64 is not None else None,
},
)
group["indices"].append(i)
group["documents"].append(d)
if images_b64 is not None:
group["images_b64"].append(images_b64[i])

scores = [float("-inf")] * len(pairs)
for group in groups.values():
row_scores = _rerank_via_endpoint(
q,
[d],
group["query"],
group["documents"],
endpoint=rerank_invoke_url,
model_name=model_name,
api_key=api_key,
images_b64=img,
images_b64=group["images_b64"],
)
scores.append(row_scores[0])
if len(row_scores) != len(group["indices"]):
raise RuntimeError(
f"Endpoint returned {len(row_scores)} scores for a batch of "
f"{len(group['indices'])} documents; score alignment is broken."
)
for row_index, score in zip(group["indices"], row_scores):
scores[row_index] = score
Comment thread
ChrisJar marked this conversation as resolved.
elif model is not None:
if images_b64 is not None:
scores = model.score_pairs(pairs, images_b64=images_b64, max_length=max_length, batch_size=batch_size)
Expand Down
79 changes: 74 additions & 5 deletions nemo_retriever/tests/test_nemotron_rerank_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,17 +572,86 @@ def test_actor_call_sorts_descending_by_default(self):

mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.side_effect = [
{"rankings": [{"index": 0, "logit": 0.1}]},
{"rankings": [{"index": 0, "logit": 0.9}]},
]
mock_resp.json.return_value = {
"rankings": [
{"index": 0, "logit": 0.1},
{"index": 1, "logit": 0.9},
]
}

with patch("requests.post", return_value=mock_resp):
with patch("requests.post", return_value=mock_resp) as mock_post:
out = actor(df)

mock_post.assert_called_once()
payload = mock_post.call_args[1]["json"]
assert payload["query"] == {"text": "q"}
assert payload["passages"] == [{"text": "low relevance"}, {"text": "high relevance"}]
scores = out["rerank_score"].tolist()
assert scores == sorted(scores, reverse=True)

def test_actor_call_batches_remote_rows_by_query(self):
    """Check that remote reranking issues one request per distinct query.

    Three rows carry two distinct queries ("q1" twice, "q2" once). The test
    expects exactly two POSTs, each carrying every passage for its query,
    and expects the returned scores to be placed back on the originating
    rows even though the first response lists its rankings out of order.
    """
    import pandas as pd
    from nemo_retriever.rerank.rerank import NemotronRerankActor

    # First reply deliberately lists passage index 1 before index 0 so that
    # the final assertion only passes if scores are mapped back by index.
    first_reply = MagicMock()
    first_reply.raise_for_status = MagicMock()
    first_reply.json.return_value = {
        "rankings": [
            {"index": 1, "logit": 0.7},
            {"index": 0, "logit": 0.2},
        ]
    }

    second_reply = MagicMock()
    second_reply.raise_for_status = MagicMock()
    second_reply.json.return_value = {"rankings": [{"index": 0, "logit": 0.9}]}

    frame = pd.DataFrame(
        {
            "query": ["q1", "q1", "q2"],
            "text": ["doc A", "doc B", "doc C"],
        }
    )
    reranker = NemotronRerankActor(rerank_invoke_url="http://localhost:8000", sort_results=False)

    with patch("requests.post", side_effect=[first_reply, second_reply]) as mock_post:
        result = reranker(frame)

    assert mock_post.call_count == 2

    # Index [1] of a recorded call is its keyword-argument mapping.
    first_call, second_call = mock_post.call_args_list
    first_payload = first_call[1]["json"]
    second_payload = second_call[1]["json"]

    assert first_payload["query"] == {"text": "q1"}
    assert first_payload["passages"] == [{"text": "doc A"}, {"text": "doc B"}]
    assert second_payload["query"] == {"text": "q2"}
    assert second_payload["passages"] == [{"text": "doc C"}]
    assert result["rerank_score"].tolist() == [0.2, 0.7, 0.9]

def test_rerank_batch_raises_when_endpoint_score_count_mismatches(self):
    """Check that a short score list from the endpoint raises RuntimeError.

    Two rows share one query, but the patched ``_rerank_via_endpoint``
    returns a single score; ``_rerank_batch`` is expected to raise with a
    message containing "score alignment is broken".
    """
    import pandas as pd
    from nemo_retriever.rerank.rerank import _rerank_batch

    frame = pd.DataFrame({"query": ["q", "q"], "text": ["doc A", "doc B"]})

    with patch("nemo_retriever.rerank.rerank._rerank_via_endpoint", return_value=[0.2]):
        with pytest.raises(RuntimeError, match="score alignment is broken"):
            _rerank_batch(frame, rerank_invoke_url="http://localhost:8000", sort_results=False)

def test_rerank_batch_warns_when_unhashable_queries_cannot_batch(self, caplog):
    """Check that unhashable queries are sent one by one and logged.

    Both rows hold the list ``["q"]`` as their query. The test expects the
    patched ``_rerank_via_endpoint`` to be called once per row, a warning
    naming each row index to be logged, and the scores to stay in row order.
    """
    import logging

    import pandas as pd
    from nemo_retriever.rerank.rerank import _rerank_batch

    caplog.set_level(logging.WARNING, logger="nemo_retriever.rerank.rerank")
    frame = pd.DataFrame({"query": [["q"], ["q"]], "text": ["doc A", "doc B"]})

    endpoint_patch = patch(
        "nemo_retriever.rerank.rerank._rerank_via_endpoint",
        side_effect=[[0.2], [0.7]],
    )
    with endpoint_patch as mock_rerank:
        result = _rerank_batch(frame, rerank_invoke_url="http://localhost:8000", sort_results=False)

    assert mock_rerank.call_count == 2

    captured = caplog.text
    assert "Query at row 0 is not hashable (list)" in captured
    assert "Query at row 1 is not hashable (list)" in captured
    assert result["rerank_score"].tolist() == [0.2, 0.7]

def test_actor_call_returns_error_payload_on_exception(self):
import pandas as pd
from nemo_retriever.rerank.rerank import NemotronRerankActor
Expand Down
Loading