Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nemo_retriever/src/nemo_retriever/chart/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ def graphic_elements_ocr_page_elements(
_ocr_kw = dict(
invoke_url=ocr_url,
image_b64_list=flat_crop_b64s,
merge_levels=["word"] * len(flat_crop_b64s),
api_key=api_key or None,
timeout_s=float(request_timeout_s),
max_batch_size=inference_batch_size,
Expand Down
25 changes: 23 additions & 2 deletions nemo_retriever/src/nemo_retriever/ocr/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,28 @@
# content like tables/charts/infographics). Used by the OCR stage to
# decide which detections contribute to the page's ``text`` column.
_TEXT_LABELS: frozenset[str] = frozenset({"text", "title", "header_footer"})
_MERGE_LEVEL_BY_LABEL: dict[str, str] = {
"table": "word",
"chart": "paragraph",
"infographic": "paragraph",
"text": "paragraph",
"title": "paragraph",
"header_footer": "paragraph",
}

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _merge_level_for_ocr_label(label_name: str) -> str:
"""Return the OCR merge level used by the local path for this element label."""
try:
return _MERGE_LEVEL_BY_LABEL[label_name]
except KeyError:
raise ValueError(f"Unsupported OCR label for merge level: {label_name!r}") from None


def _error_payload(*, stage: str, exc: BaseException) -> Dict[str, Any]:
return {
"timing": None,
Expand Down Expand Up @@ -410,6 +426,7 @@ def ocr_b64_to_text(
invoke_kw = dict(
invoke_url=invoke_url,
image_b64_list=valid_b64,
merge_levels=[merge_level] * len(valid_b64),
api_key=api_key,
timeout_s=float(timeout_s),
max_batch_size=int(batch_size),
Expand Down Expand Up @@ -807,11 +824,13 @@ def ocr_page_elements(
crops = _crop_all_from_page(page_image_b64, dets, row_wanted, as_b64=True)
crop_b64s: List[str] = [b64 for _label, _bbox, b64 in crops]
crop_meta: List[Tuple[str, List[float]]] = [(label, bbox) for label, bbox, _b64 in crops]
merge_levels = [_merge_level_for_ocr_label(label) for label, _bbox in crop_meta]

if crop_b64s:
_invoke_kw = dict(
invoke_url=invoke_url,
image_b64_list=crop_b64s,
merge_levels=merge_levels,
api_key=api_key,
timeout_s=float(request_timeout_s),
max_batch_size=int(kwargs.get("inference_batch_size", 8)),
Expand Down Expand Up @@ -891,9 +910,11 @@ def ocr_page_elements(

# Tables require word-level merging; charts/infographics use paragraph-level.
# Group by merge level so each batched invoke uses one consistent setting.
local_jobs: Dict[str, List[Tuple[str, List[float], np.ndarray]]] = {"word": [], "paragraph": []}
local_jobs: Dict[str, List[Tuple[str, List[float], np.ndarray]]] = {
ml: [] for ml in _MERGE_LEVEL_BY_LABEL.values()
}
for label_name, bbox, crop_array in crops:
ml = "word" if label_name == "table" else "paragraph"
ml = _merge_level_for_ocr_label(label_name)
local_jobs[ml].append((label_name, bbox, crop_array))

def _append_local_result(
Expand Down
33 changes: 32 additions & 1 deletion nemo_retriever/tests/test_chart_graphic_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import base64
import importlib
import io
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest
Expand Down Expand Up @@ -192,6 +192,37 @@ def test_with_mocked_models_produces_text(self) -> None:
assert "ChartTitle" in chart_entries[0]["text"]
assert chart_entries[0]["bbox_xyxy_norm"] == [0.0, 0.0, 1.0, 1.0]

def test_remote_ocr_uses_word_merge_level(self) -> None:
from nemo_retriever.chart.chart_detection import graphic_elements_ocr_page_elements

import torch

df = _make_chart_page_df(width=200, height=100)

mock_ge_model = MagicMock()
mock_ge_model._model = MagicMock()
mock_ge_model._model.labels = ["chart_title"]
mock_ge_model.preprocess.return_value = torch.zeros(1, 3, 100, 200)
mock_ge_model.invoke.return_value = {
"boxes": torch.tensor([[0.0, 0.0, 1.0, 0.3]]),
"labels": torch.tensor([0]),
"scores": torch.tensor([0.9]),
}

ocr_response = [[{"left": 0.1, "right": 0.9, "upper": 0.05, "lower": 0.25, "text": "ChartTitle"}]]
with patch(
"nemo_retriever.chart.shared.invoke_image_inference_batches",
return_value=ocr_response,
) as remote_ocr:
result = graphic_elements_ocr_page_elements(
df,
graphic_elements_model=mock_ge_model,
ocr_invoke_url="http://fake-ocr",
)

assert "ChartTitle" in result.iloc[0]["chart"][0]["text"]
assert remote_ocr.call_args.kwargs["merge_levels"] == ["word"]

def test_fallback_when_no_ge_detections(self) -> None:
"""When GE model returns no detections, should fall back to OCR-only text."""
from nemo_retriever.chart.chart_detection import graphic_elements_ocr_page_elements
Expand Down
84 changes: 83 additions & 1 deletion nemo_retriever/tests/test_table_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,40 @@ def _make_page_df_with_ts_regions(
class TestOCRJoinsTableStructure:
"""When use_table_structure=True, OCR stage should join structure + OCR."""

def test_merge_level_mapping_rejects_unknown_labels(self) -> None:
from nemo_retriever.ocr.shared import _merge_level_for_ocr_label

assert _merge_level_for_ocr_label("table") == "word"
assert _merge_level_for_ocr_label("chart") == "paragraph"
assert _merge_level_for_ocr_label("infographic") == "paragraph"
assert _merge_level_for_ocr_label("text") == "paragraph"
assert _merge_level_for_ocr_label("title") == "paragraph"
assert _merge_level_for_ocr_label("header_footer") == "paragraph"
with pytest.raises(ValueError, match="Unsupported OCR label"):
_merge_level_for_ocr_label("typo")

def test_local_jobs_follow_merge_level_mapping_values(self, monkeypatch) -> None:
from nemo_retriever.ocr import shared as ocr_shared

monkeypatch.setitem(ocr_shared._MERGE_LEVEL_BY_LABEL, "chart", "line")
df = pd.DataFrame(
[
{
"page_image": {"image_b64": _make_b64_png(200, 100)},
"page_elements_v3": {
"detections": [{"label_name": "chart", "bbox_xyxy_norm": [0.0, 0.0, 1.0, 1.0], "score": 0.95}]
},
}
]
)
ocr_model = MagicMock()
ocr_model.invoke.return_value = [[{"left": 0.1, "right": 0.9, "upper": 0.1, "lower": 0.2, "text": "chart"}]]

result = ocr_shared.ocr_page_elements(df, model=ocr_model, extract_charts=True)

assert result.iloc[0]["chart"][0]["text"] == "chart"
assert ocr_model.invoke.call_args.kwargs["merge_level"] == "line"

def _structure_2x2(self) -> list[dict]:
return [
{"bbox_xyxy_norm": [0.0, 0.0, 1.0, 0.5], "label_name": "row", "score": 0.9},
Expand Down Expand Up @@ -482,7 +516,7 @@ def test_remote_path_joins_structure_and_ocr(self) -> None:
with patch(
"nemo_retriever.ocr.shared.invoke_image_inference_batches",
return_value=[self._ocr_preds_abcd()],
):
) as remote_invoke:
result = ocr_page_elements(
df,
invoke_url="http://fake-ocr",
Expand All @@ -497,6 +531,54 @@ def test_remote_path_joins_structure_and_ocr(self) -> None:
assert cell in text, f"missing cell '{cell}' in joined markdown: {text!r}"
# Structure-aware markdown includes a header separator; pseudo-markdown does not.
assert "---" in text, f"expected structure-aware markdown, got: {text!r}"
assert remote_invoke.call_args.kwargs["merge_levels"] == ["word"]

def test_remote_path_sends_merge_levels_by_modality(self) -> None:
"""Remote OCR should request word-level tables and paragraph-level non-table crops."""
from nemo_retriever.ocr.shared import ocr_page_elements

image_b64 = _make_b64_png(240, 160)
df = pd.DataFrame(
[
{
"page_image": {"image_b64": image_b64},
"page_elements_v3": {
"detections": [
{"label_name": "table", "bbox_xyxy_norm": [0.0, 0.0, 0.5, 0.5], "score": 0.95},
{"label_name": "chart", "bbox_xyxy_norm": [0.5, 0.0, 1.0, 0.5], "score": 0.95},
{"label_name": "infographic", "bbox_xyxy_norm": [0.0, 0.5, 0.5, 1.0], "score": 0.95},
{"label_name": "text", "bbox_xyxy_norm": [0.5, 0.5, 1.0, 1.0], "score": 0.95},
]
},
"metadata": {"needs_ocr_for_text": True},
}
]
)
remote_items = [
[{"left": 0.1, "right": 0.2, "upper": 0.1, "lower": 0.2, "text": "table"}],
[{"left": 0.1, "right": 0.2, "upper": 0.1, "lower": 0.2, "text": "chart"}],
[{"left": 0.1, "right": 0.2, "upper": 0.1, "lower": 0.2, "text": "info"}],
[{"left": 0.1, "right": 0.2, "upper": 0.1, "lower": 0.2, "text": "body"}],
]

with patch(
"nemo_retriever.ocr.shared.invoke_image_inference_batches",
return_value=remote_items,
) as remote_invoke:
result = ocr_page_elements(
df,
invoke_url="http://fake-ocr",
extract_text=True,
extract_tables=True,
extract_charts=True,
extract_infographics=True,
)

assert result.iloc[0]["table"]
assert result.iloc[0]["chart"]
assert result.iloc[0]["infographic"]
assert result.iloc[0]["text"] == "body"
assert remote_invoke.call_args.kwargs["merge_levels"] == ["word", "paragraph", "paragraph", "paragraph"]


# ---------------------------------------------------------------------------
Expand Down
17 changes: 17 additions & 0 deletions nemo_retriever/tests/test_video_frame_ocr_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,23 @@ def test_cpu_actor_calls_remote_batched_with_b64_list() -> None:
call_kwargs = nim_client.invoke_image_inference_batches.call_args.kwargs
assert call_kwargs["image_b64_list"] == ["AAA", "BBB"]
assert call_kwargs["invoke_url"] == "https://example/ocr"
assert call_kwargs["merge_levels"] == ["paragraph", "paragraph"]


def test_cpu_actor_forwards_configured_merge_level_to_remote() -> None:
df = _make_frame_df(["AAA"])

nim_client = MagicMock()
nim_client.invoke_image_inference_batches = MagicMock(return_value=[[{"text_prediction": {"text": "word"}}]])

actor = VideoFrameOCRCPUActor(ocr_invoke_url="https://example/ocr", merge_level="word")
actor._nim_client = nim_client

out = actor.run(df)

assert out["text"].tolist() == ["word"]
call_kwargs = nim_client.invoke_image_inference_batches.call_args.kwargs
assert call_kwargs["merge_levels"] == ["word"]


def test_cpu_actor_defaults_to_hosted_ocr_v1() -> None:
Expand Down
Loading