From 6d80fd990ebf15003ecea513c12ea24f00cffb8b Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 18 May 2026 14:09:13 +0000 Subject: [PATCH 1/2] Pass merge levels to remote OCR --- .../src/nemo_retriever/chart/shared.py | 1 + .../src/nemo_retriever/ocr/shared.py | 25 +++++- .../tests/test_chart_graphic_elements.py | 33 ++++++- nemo_retriever/tests/test_table_structure.py | 86 ++++++++++++++++++- .../tests/test_video_frame_ocr_actor.py | 17 ++++ 5 files changed, 158 insertions(+), 4 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/chart/shared.py b/nemo_retriever/src/nemo_retriever/chart/shared.py index 0ac8625489..0ed780a58a 100644 --- a/nemo_retriever/src/nemo_retriever/chart/shared.py +++ b/nemo_retriever/src/nemo_retriever/chart/shared.py @@ -521,6 +521,7 @@ def graphic_elements_ocr_page_elements( _ocr_kw = dict( invoke_url=ocr_url, image_b64_list=flat_crop_b64s, + merge_levels=["word"] * len(flat_crop_b64s), api_key=api_key or None, timeout_s=float(request_timeout_s), max_batch_size=inference_batch_size, diff --git a/nemo_retriever/src/nemo_retriever/ocr/shared.py b/nemo_retriever/src/nemo_retriever/ocr/shared.py index 9b4b6fcff1..f334cf9b28 100644 --- a/nemo_retriever/src/nemo_retriever/ocr/shared.py +++ b/nemo_retriever/src/nemo_retriever/ocr/shared.py @@ -45,12 +45,28 @@ # content like tables/charts/infographics). Used by the OCR stage to # decide which detections contribute to the page's ``text`` column. _TEXT_LABELS: frozenset[str] = frozenset({"text", "title", "header_footer"}) +_MERGE_LEVEL_BY_LABEL: dict[str, str] = { + "table": "word", + "chart": "paragraph", + "infographic": "paragraph", + "text": "paragraph", + "title": "paragraph", + "header_footer": "paragraph", +} # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- +def _merge_level_for_ocr_label(label_name: str) -> str: + """Return the OCR merge level used by the local path for this element label.""" + try: + return _MERGE_LEVEL_BY_LABEL[label_name] + except KeyError: + raise ValueError(f"Unsupported OCR label for merge level: {label_name!r}") from None + + def _error_payload(*, stage: str, exc: BaseException) -> Dict[str, Any]: return { "timing": None, @@ -410,6 +426,7 @@ def ocr_b64_to_text( invoke_kw = dict( invoke_url=invoke_url, image_b64_list=valid_b64, + merge_levels=[merge_level] * len(valid_b64), api_key=api_key, timeout_s=float(timeout_s), max_batch_size=int(batch_size), @@ -807,11 +824,13 @@ def ocr_page_elements( crops = _crop_all_from_page(page_image_b64, dets, row_wanted, as_b64=True) crop_b64s: List[str] = [b64 for _label, _bbox, b64 in crops] crop_meta: List[Tuple[str, List[float]]] = [(label, bbox) for label, bbox, _b64 in crops] + merge_levels = [_merge_level_for_ocr_label(label) for label, _bbox in crop_meta] if crop_b64s: _invoke_kw = dict( invoke_url=invoke_url, image_b64_list=crop_b64s, + merge_levels=merge_levels, api_key=api_key, timeout_s=float(request_timeout_s), max_batch_size=int(kwargs.get("inference_batch_size", 8)), @@ -891,9 +910,11 @@ def ocr_page_elements( # Tables require word-level merging; charts/infographics use paragraph-level. # Group by merge level so each batched invoke uses one consistent setting. - local_jobs: Dict[str, List[Tuple[str, List[float], np.ndarray]]] = {"word": [], "paragraph": []} + local_jobs: Dict[str, List[Tuple[str, List[float], np.ndarray]]] = { + ml: [] for ml in _MERGE_LEVEL_BY_LABEL.values() + } for label_name, bbox, crop_array in crops: - ml = "word" if label_name == "table" else "paragraph" + ml = _merge_level_for_ocr_label(label_name) local_jobs[ml].append((label_name, bbox, crop_array)) def _append_local_result( diff --git a/nemo_retriever/tests/test_chart_graphic_elements.py b/nemo_retriever/tests/test_chart_graphic_elements.py index 88e4433a4d..cea724ff0d 100644 --- a/nemo_retriever/tests/test_chart_graphic_elements.py +++ b/nemo_retriever/tests/test_chart_graphic_elements.py @@ -9,7 +9,7 @@ import base64 import importlib import io -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pandas as pd import pytest @@ -192,6 +192,37 @@ def test_with_mocked_models_produces_text(self) -> None: assert "ChartTitle" in chart_entries[0]["text"] assert chart_entries[0]["bbox_xyxy_norm"] == [0.0, 0.0, 1.0, 1.0] + def test_remote_ocr_uses_word_merge_level(self) -> None: + from nemo_retriever.chart.chart_detection import graphic_elements_ocr_page_elements + + import torch + + df = _make_chart_page_df(width=200, height=100) + + mock_ge_model = MagicMock() + mock_ge_model._model = MagicMock() + mock_ge_model._model.labels = ["chart_title"] + mock_ge_model.preprocess.return_value = torch.zeros(1, 3, 100, 200) + mock_ge_model.invoke.return_value = { + "boxes": torch.tensor([[0.0, 0.0, 1.0, 0.3]]), + "labels": torch.tensor([0]), + "scores": torch.tensor([0.9]), + } + + ocr_response = [[{"left": 0.1, "right": 0.9, "upper": 0.05, "lower": 0.25, "text": "ChartTitle"}]] + with patch( + "nemo_retriever.chart.shared.invoke_image_inference_batches", + return_value=ocr_response, + ) as remote_ocr: + result = graphic_elements_ocr_page_elements( + df, + graphic_elements_model=mock_ge_model, + ocr_invoke_url="http://fake-ocr", + ) + + assert "ChartTitle" in result.iloc[0]["chart"][0]["text"] + assert remote_ocr.call_args.kwargs["merge_levels"] == ["word"] + def test_fallback_when_no_ge_detections(self) -> None: """When GE model returns no detections, should fall back to OCR-only text.""" from nemo_retriever.chart.chart_detection import graphic_elements_ocr_page_elements diff --git a/nemo_retriever/tests/test_table_structure.py b/nemo_retriever/tests/test_table_structure.py index 1a55a51ea5..efbbb7e6f7 100644 --- a/nemo_retriever/tests/test_table_structure.py +++ b/nemo_retriever/tests/test_table_structure.py @@ -378,6 +378,42 @@ def _make_page_df_with_ts_regions( class TestOCRJoinsTableStructure: """When use_table_structure=True, OCR stage should join structure + OCR.""" + def test_merge_level_mapping_rejects_unknown_labels(self) -> None: + from nemo_retriever.ocr.shared import _merge_level_for_ocr_label + + assert _merge_level_for_ocr_label("table") == "word" + assert _merge_level_for_ocr_label("chart") == "paragraph" + assert _merge_level_for_ocr_label("infographic") == "paragraph" + assert _merge_level_for_ocr_label("text") == "paragraph" + assert _merge_level_for_ocr_label("title") == "paragraph" + assert _merge_level_for_ocr_label("header_footer") == "paragraph" + with pytest.raises(ValueError, match="Unsupported OCR label"): + _merge_level_for_ocr_label("typo") + + def test_local_jobs_follow_merge_level_mapping_values(self, monkeypatch) -> None: + from nemo_retriever.ocr import shared as ocr_shared + + monkeypatch.setitem(ocr_shared._MERGE_LEVEL_BY_LABEL, "chart", "line") + df = pd.DataFrame( + [ + { + "page_image": {"image_b64": _make_b64_png(200, 100)}, + "page_elements_v3": { + "detections": [ + {"label_name": "chart", "bbox_xyxy_norm": [0.0, 0.0, 1.0, 1.0], "score": 0.95} + ] + }, + } + ] + ) + ocr_model = MagicMock() + ocr_model.invoke.return_value = [[{"left": 0.1, "right": 0.9, "upper": 0.1, "lower": 0.2, "text": "chart"}]] + + result = ocr_shared.ocr_page_elements(df, model=ocr_model, extract_charts=True) + + assert result.iloc[0]["chart"][0]["text"] == "chart" + assert ocr_model.invoke.call_args.kwargs["merge_level"] == "line" + def _structure_2x2(self) -> list[dict]: return [ {"bbox_xyxy_norm": [0.0, 0.0, 1.0, 0.5], "label_name": "row", "score": 0.9}, @@ -482,7 +518,7 @@ def test_remote_path_joins_structure_and_ocr(self) -> None: with patch( "nemo_retriever.ocr.shared.invoke_image_inference_batches", return_value=[self._ocr_preds_abcd()], - ): + ) as remote_invoke: result = ocr_page_elements( df, invoke_url="http://fake-ocr", @@ -497,6 +533,54 @@ def test_remote_path_joins_structure_and_ocr(self) -> None: assert cell in text, f"missing cell '{cell}' in joined markdown: {text!r}" # Structure-aware markdown includes a header separator; pseudo-markdown does not. assert "---" in text, f"expected structure-aware markdown, got: {text!r}" + assert remote_invoke.call_args.kwargs["merge_levels"] == ["word"] + + def test_remote_path_sends_merge_levels_by_modality(self) -> None: + """Remote OCR should request word-level tables and paragraph-level non-table crops.""" + from nemo_retriever.ocr.shared import ocr_page_elements + + image_b64 = _make_b64_png(240, 160) + df = pd.DataFrame( + [ + { + "page_image": {"image_b64": image_b64}, + "page_elements_v3": { + "detections": [ + {"label_name": "table", "bbox_xyxy_norm": [0.0, 0.0, 0.5, 0.5], "score": 0.95}, + {"label_name": "chart", "bbox_xyxy_norm": [0.5, 0.0, 1.0, 0.5], "score": 0.95}, + {"label_name": "infographic", "bbox_xyxy_norm": [0.0, 0.5, 0.5, 1.0], "score": 0.95}, + {"label_name": "text", "bbox_xyxy_norm": [0.5, 0.5, 1.0, 1.0], "score": 0.95}, + ] + }, + "metadata": {"needs_ocr_for_text": True}, + } + ] + ) + remote_items = [ + [{"left": 0.1, "right": 0.2, "upper": 0.1, "lower": 0.2, "text": "table"}], + [{"left": 0.1, "right": 0.2, "upper": 0.1, "lower": 0.2, "text": "chart"}], + [{"left": 0.1, "right": 0.2, "upper": 0.1, "lower": 0.2, "text": "info"}], + [{"left": 0.1, "right": 0.2, "upper": 0.1, "lower": 0.2, "text": "body"}], + ] + + with patch( + "nemo_retriever.ocr.shared.invoke_image_inference_batches", + return_value=remote_items, + ) as remote_invoke: + result = ocr_page_elements( + df, + invoke_url="http://fake-ocr", + extract_text=True, + extract_tables=True, + extract_charts=True, + extract_infographics=True, + ) + + assert result.iloc[0]["table"] + assert result.iloc[0]["chart"] + assert result.iloc[0]["infographic"] + assert result.iloc[0]["text"] == "body" + assert remote_invoke.call_args.kwargs["merge_levels"] == ["word", "paragraph", "paragraph", "paragraph"] # --------------------------------------------------------------------------- diff --git a/nemo_retriever/tests/test_video_frame_ocr_actor.py b/nemo_retriever/tests/test_video_frame_ocr_actor.py index 6f3d4528b9..9bc5d6f31c 100644 --- a/nemo_retriever/tests/test_video_frame_ocr_actor.py +++ b/nemo_retriever/tests/test_video_frame_ocr_actor.py @@ -70,6 +70,23 @@ def test_cpu_actor_calls_remote_batched_with_b64_list() -> None: call_kwargs = nim_client.invoke_image_inference_batches.call_args.kwargs assert call_kwargs["image_b64_list"] == ["AAA", "BBB"] assert call_kwargs["invoke_url"] == "https://example/ocr" + assert call_kwargs["merge_levels"] == ["paragraph", "paragraph"] + + +def test_cpu_actor_forwards_configured_merge_level_to_remote() -> None: + df = _make_frame_df(["AAA"]) + + nim_client = MagicMock() + nim_client.invoke_image_inference_batches = MagicMock(return_value=[[{"text_prediction": {"text": "word"}}]]) + + actor = VideoFrameOCRCPUActor(ocr_invoke_url="https://example/ocr", merge_level="word") + actor._nim_client = nim_client + + out = actor.run(df) + + assert out["text"].tolist() == ["word"] + call_kwargs = nim_client.invoke_image_inference_batches.call_args.kwargs + assert call_kwargs["merge_levels"] == ["word"] def test_cpu_actor_defaults_to_hosted_ocr_v1() -> None: From caeb210754de04d9f902c081e289b12f4538dc81 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 18 May 2026 14:45:31 +0000 Subject: [PATCH 2/2] Linting --- nemo_retriever/tests/test_table_structure.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nemo_retriever/tests/test_table_structure.py b/nemo_retriever/tests/test_table_structure.py index efbbb7e6f7..cb35da7c91 100644 --- a/nemo_retriever/tests/test_table_structure.py +++ b/nemo_retriever/tests/test_table_structure.py @@ -399,9 +399,7 @@ def test_local_jobs_follow_merge_level_mapping_values(self, monkeypatch) -> None { "page_image": {"image_b64": _make_b64_png(200, 100)}, "page_elements_v3": { - "detections": [ - {"label_name": "chart", "bbox_xyxy_norm": [0.0, 0.0, 1.0, 1.0], "score": 0.95} - ] + "detections": [{"label_name": "chart", "bbox_xyxy_norm": [0.0, 0.0, 1.0, 1.0], "score": 0.95}] }, } ]