diff --git a/README.md b/README.md index a11b6e5d9c..9db68981bc 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,23 @@ doc = DocumentFile.from_pdf("path/to/your/doc.pdf") result = model(doc) ``` +### Detecting the document layout + +You can additionally run a layout detection model as part of the pipeline by passing `detect_layout=True`. The detected regions (e.g. `Title`, `Text`, `Table`, `Page-header`, `Page-footer`) are attached to every page and rendered by `.show()`: + +```python +from doctr.io import DocumentFile +from doctr.models import ocr_predictor + +model = ocr_predictor(pretrained=True, detect_layout=True) +doc = DocumentFile.from_images("path/to/your/doc.jpg") +result = model(doc) + +# Access the detected layout regions of the first page +for region in result.pages[0].layout: + print(region.type, region.confidence, region.geometry) +``` + ### Dealing with rotated documents Should you use docTR on documents that include rotated pages, or pages with multiple box orientations, diff --git a/api/app/routes/kie.py b/api/app/routes/kie.py index d329b34ab8..324688c073 100644 --- a/api/app/routes/kie.py +++ b/api/app/routes/kie.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import KIEElement, KIEIn, KIEOut +from app.schemas import KIEElement, KIEIn, KIEOut, LayoutElementOut from app.utils import get_documents, resolve_geometry from app.vision import init_predictor @@ -30,6 +30,14 @@ async def perform_kie(request: KIEIn = Depends(), files: list[UploadFile] = [Fil orientation=page.orientation, language=page.language, dimensions=page.dimensions, + layout=[ + LayoutElementOut( + type=region.type, + geometry=resolve_geometry(region.geometry), + confidence=round(region.confidence, 2), + ) + for region in page.layout + ], predictions=[ KIEElement( class_name=class_name, diff --git a/api/app/routes/ocr.py b/api/app/routes/ocr.py index 56a5f38733..16d6c7295b 100644 --- a/api/app/routes/ocr.py +++ b/api/app/routes/ocr.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord +from app.schemas import LayoutElementOut, OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord from app.utils import get_documents, resolve_geometry from app.vision import init_predictor @@ -31,6 +31,14 @@ async def perform_ocr(request: OCRIn = Depends(), files: list[UploadFile] = [Fil orientation=page.orientation, language=page.language, dimensions=page.dimensions, + layout=[ + LayoutElementOut( + type=region.type, + geometry=resolve_geometry(region.geometry), + confidence=round(region.confidence, 2), + ) + for region in page.layout + ], items=[ OCRPage( blocks=[ diff --git a/api/app/schemas.py b/api/app/schemas.py index 66743967a3..4e5f779168 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -15,6 +15,8 @@ class KIEIn(BaseModel): preserve_aspect_ratio: bool = Field(default=True, examples=[True]) detect_orientation: bool = Field(default=False, examples=[False]) detect_language: bool = Field(default=False, examples=[False]) + detect_layout: bool = Field(default=False, examples=[False]) + layout_arch: str = Field(default="lw_detr_s", examples=["lw_detr_s"]) symmetric_pad: bool = Field(default=True, examples=[True]) straighten_pages: bool = Field(default=False, examples=[False]) det_bs: int = Field(default=2, examples=[2]) @@ -131,11 +133,21 @@ class OCRPage(BaseModel): ) +class LayoutElementOut(BaseModel): + type: str = Field(..., examples=["Title"]) + geometry: list[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + confidence: float = Field(..., examples=[0.99]) + + class OCROut(BaseModel): name: str = Field(..., examples=["example.jpg"]) orientation: dict[str, float | None] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}]) language: dict[str, str | float | None] = Field(..., examples=[{"value": "en", "confidence": 0.99}]) dimensions: tuple[int, int] = Field(..., examples=[(100, 100)]) + layout: list[LayoutElementOut] = Field( + default=[], + examples=[[{"type": "Title", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}]], + ) items: list[OCRPage] = Field( ..., examples=[ @@ -183,4 +195,8 @@ class KIEOut(BaseModel): orientation: dict[str, float | None] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}]) language: dict[str, str | float | None] = Field(..., examples=[{"value": "en", "confidence": 0.99}]) dimensions: tuple[int, int] = Field(..., examples=[(100, 100)]) + layout: list[LayoutElementOut] = Field( + default=[], + examples=[[{"type": "Title", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}]], + ) predictions: list[KIEElement] diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 3aab987d4a..3d9d65a1aa 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -77,6 +77,7 @@ def mock_kie_response(): "orientation": {"value": None, "confidence": None}, "language": {"value": None, "confidence": None}, "dimensions": [2339, 1654], + "layout": [], "predictions": [ { "class_name": "words", @@ -104,6 +105,7 @@ def mock_kie_response(): "orientation": {"value": None, "confidence": None}, "language": {"value": None, "confidence": None}, "dimensions": [2339, 1654], + "layout": [], "predictions": [ { "class_name": "words", @@ -155,6 +157,7 @@ def mock_ocr_response(): "orientation": {"value": None, "confidence": None}, "language": {"value": None, "confidence": None}, "dimensions": [2339, 1654], + "layout": [], "items": [ { "blocks": [ @@ -203,6 +206,7 @@ def mock_ocr_response(): "orientation": {"value": None, "confidence": None}, "language": {"value": None, "confidence": None}, "dimensions": [2339, 1654], + "layout": [], "items": [ { "blocks": [ diff --git a/api/tests/routes/test_kie.py b/api/tests/routes/test_kie.py index 87e1614fe2..3a8e5cfb5f 100644 --- a/api/tests/routes/test_kie.py +++ b/api/tests/routes/test_kie.py @@ -10,6 +10,11 @@ def common_test(json_response, expected_response): and len(first_pred["dimensions"]) == 2 and all(isinstance(dim, int) for dim in first_pred["dimensions"]) ) + assert isinstance(first_pred["layout"], list) + for region in first_pred["layout"]: + assert isinstance(region["type"], str) + assert isinstance(region["confidence"], (int, float)) + assert isinstance(region["geometry"], (tuple, list)) assert isinstance(first_pred["predictions"], list) assert isinstance(expected_response["predictions"], list) @@ -67,6 +72,27 @@ async def test_kie_poly(test_app_asyncio, mock_detection_image, mock_kie_respons common_test(json_response, expected_poly_response) +@pytest.mark.asyncio +async def test_kie_layout(test_app_asyncio, mock_detection_image): + headers = { + "accept": "application/json", + } + params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn", "detect_layout": True} + files = [ + ("files", ("test.jpg", mock_detection_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/kie", params=params, files=files, headers=headers) + assert response.status_code == 200 + json_response = response.json() + + assert isinstance(json_response, list) and len(json_response) == 1 + assert "layout" in json_response[0] and isinstance(json_response[0]["layout"], list) + for region in json_response[0]["layout"]: + assert isinstance(region["type"], str) + assert isinstance(region["confidence"], (int, float)) + assert len(region["geometry"]) in (4, 8) + + @pytest.mark.asyncio async def test_kie_invalid_file(test_app_asyncio, mock_txt_file): headers = { diff --git a/api/tests/routes/test_ocr.py b/api/tests/routes/test_ocr.py index ac2b96ebf8..bdb54d0174 100644 --- a/api/tests/routes/test_ocr.py +++ b/api/tests/routes/test_ocr.py @@ -11,6 +11,11 @@ def common_test(json_response, expected_response): and len(first_pred["dimensions"]) == 2 and all(isinstance(dim, int) for dim in first_pred["dimensions"]) ) + assert isinstance(first_pred["layout"], list) + for region in first_pred["layout"]: + assert isinstance(region["type"], str) + assert isinstance(region["confidence"], (int, float)) + assert isinstance(region["geometry"], (tuple, list)) for item, expected_item in zip(first_pred["items"], expected_response["items"]): for block, expected_block in zip(item["blocks"], expected_item["blocks"]): np.testing.assert_allclose(block["geometry"], expected_block["geometry"], rtol=1e-2) @@ -67,6 +72,27 @@ async def test_ocr_poly(test_app_asyncio, mock_detection_image, mock_ocr_respons common_test(json_response, expected_poly_response) +@pytest.mark.asyncio +async def test_ocr_layout(test_app_asyncio, mock_detection_image): + headers = { + "accept": "application/json", + } + params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn", "detect_layout": True} + files = [ + ("files", ("test.jpg", mock_detection_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/ocr", params=params, files=files, headers=headers) + assert response.status_code == 200 + json_response = response.json() + + assert isinstance(json_response, list) and len(json_response) == 1 + assert "layout" in json_response[0] and isinstance(json_response[0]["layout"], list) + for region in json_response[0]["layout"]: + assert isinstance(region["type"], str) + assert isinstance(region["confidence"], (int, float)) + assert len(region["geometry"]) in (4, 8) + + @pytest.mark.asyncio async def test_ocr_invalid_file(test_app_asyncio, mock_txt_file): headers = { diff --git a/demo/app.py b/demo/app.py index 2a404d14c8..85446708bb 100644 --- a/demo/app.py +++ b/demo/app.py @@ -8,7 +8,7 @@ import numpy as np import streamlit as st import torch -from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor +from backend.pytorch import DET_ARCHS, LAYOUT_ARCHS, RECO_ARCHS, forward_image, load_predictor from doctr.io import DocumentFile from doctr.utils.visualization import visualize_page @@ -16,7 +16,7 @@ forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -def main(det_archs, reco_archs): +def main(det_archs, reco_archs, layout_archs): """Build a streamlit layout""" # Wide mode st.set_page_config(layout="wide") @@ -67,6 +67,9 @@ def main(det_archs, reco_archs): straighten_pages = st.sidebar.checkbox("Straighten pages", value=False) # Export as straight boxes export_straight_boxes = st.sidebar.checkbox("Export as straight boxes", value=False) + # Layout detection + detect_layout = st.sidebar.checkbox("Detect layout", value=False) + layout_arch = st.sidebar.selectbox("Layout detection model", layout_archs, disabled=not detect_layout) st.sidebar.write("\n") # Binarization threshold bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1) @@ -92,6 +95,8 @@ def main(det_archs, reco_archs): bin_thresh=bin_thresh, box_thresh=box_thresh, device=forward_device, + detect_layout=detect_layout, + layout_arch=layout_arch, ) with st.spinner("Analyzing..."): @@ -123,4 +128,4 @@ def main(det_archs, reco_archs): if __name__ == "__main__": - main(DET_ARCHS, RECO_ARCHS) + main(DET_ARCHS, RECO_ARCHS, LAYOUT_ARCHS) diff --git a/demo/backend/pytorch.py b/demo/backend/pytorch.py index e8cf1e7b4a..1c4c9d941d 100644 --- a/demo/backend/pytorch.py +++ b/demo/backend/pytorch.py @@ -31,6 +31,10 @@ "parseq", "viptr_tiny", ] +LAYOUT_ARCHS = [ + "lw_detr_s", + "lw_detr_m", +] def load_predictor( @@ -44,6 +48,8 @@ def load_predictor( bin_thresh: float, box_thresh: float, device: torch.device, + detect_layout: bool, + layout_arch: str, ) -> OCRPredictor: """Load a predictor from doctr.models @@ -58,6 +64,8 @@ def load_predictor( bin_thresh: binarization threshold for the segmentation map box_thresh: minimal objectness score to consider a box device: torch.device, the device to load the predictor on + detect_layout: whether to run a layout detection model and attach the regions to each page + layout_arch: layout architecture to use when detect_layout is True Returns: instance of OCRPredictor @@ -72,6 +80,8 @@ def load_predictor( detect_orientation=not assume_straight_pages, disable_page_orientation=disable_page_orientation, disable_crop_orientation=disable_crop_orientation, + detect_layout=detect_layout, + layout_arch=layout_arch, ).to(device) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh predictor.det_predictor.model.postprocessor.box_thresh = box_thresh diff --git a/docs/source/modules/io.rst b/docs/source/modules/io.rst index 7ac74025b0..b042e28e2a 100644 --- a/docs/source/modules/io.rst +++ b/docs/source/modules/io.rst @@ -33,6 +33,13 @@ An Artefact is a non-textual element (e.g. QR code, picture, chart, signature, l .. autoclass:: Artefact +LayoutElement +^^^^^^^^^^^^^ + +A LayoutElement is a region predicted by a layout detection model (e.g. Title, Text, Table, Page-header, Page-footer). Layout regions are attached to a :class:`Page` when the ``ocr_predictor`` / ``kie_predictor`` is run with ``detect_layout=True``. + +.. autoclass:: LayoutElement + Block ^^^^^ A Block is a collection of Lines (e.g. an address written on several lines) and Artefacts (e.g. a graph with its title underneath). diff --git a/docs/source/modules/models.rst b/docs/source/modules/models.rst index 55ce88a365..dc8fe51ef7 100644 --- a/docs/source/modules/models.rst +++ b/docs/source/modules/models.rst @@ -83,6 +83,8 @@ doctr.models.layout .. autofunction:: doctr.models.layout.lw_detr_m +.. autofunction:: doctr.models.layout.layout_predictor + doctr.models.recognition ------------------------ diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index 9b28df0fbb..802218296e 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -67,6 +67,26 @@ Load a custom layout analysis model trained on another set of classes as the def predictor = layout_predictor(layout_arch=layout_model, pretrained=True) + +Plug a custom layout analysis model (trained on another set of classes) directly into the OCR pipeline so the detected regions are attached to every page: + +.. code:: python3 + + import torch + from doctr.models import ocr_predictor, lw_detr_s + + # Custom layout model with your own class names + layout_model = lw_detr_s(pretrained=False, class_names=["heading", "paragraph", "figure", "table"]) + layout_model.from_pretrained('') + + # Pass it through `layout_arch`, exactly as for the detection / recognition models + predictor = ocr_predictor(pretrained=True, detect_layout=True, layout_arch=layout_model) + + result = predictor(doc) + # The regions (with your custom class names) are available on each page + print([(region.type, region.confidence) for region in result.pages[0].layout]) + + Load a custom trained KIE detection model: .. code:: python3 diff --git a/docs/source/using_doctr/sharing_models.rst b/docs/source/using_doctr/sharing_models.rst index b2dcbfbc6f..a99af99b88 100644 --- a/docs/source/using_doctr/sharing_models.rst +++ b/docs/source/using_doctr/sharing_models.rst @@ -68,6 +68,8 @@ We suggest using the following naming conventions for your models: **Recognition:** ``doctr--`` +**Layout:** ``doctr-`` + Classification -------------- @@ -101,3 +103,13 @@ Recognition +---------------------------------+---------------------------------------------------+---------------------+------------------------+ | parseq | rania-sr/doctr-model-v1-arabic | arabic | PyTorch | +---------------------------------+---------------------------------------------------+---------------------+------------------------+ + + +Layout +------ + ++---------------------------------+---------------------------------------------------+------------------------+ +| **Architecture** | **Repo_ID** | **Framework** | ++=================================+===================================================+========================+ +| lw_detr_s (dummy) | Felix92/doctr-dummy-torch-lw-detr-s | PyTorch | ++---------------------------------+---------------------------------------------------+------------------------+ diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index 4f67f73ead..f2fe02177e 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -308,6 +308,10 @@ Additional arguments which can be passed to the `ocr_predictor` are: * `export_as_straight_boxes`: If you work with rotated and skewed documents but you still want to export straight bounding boxes and not polygons, set it to True. * `straighten_pages`: If you want to straighten the pages before sending them to the detection model, set it to True. +* `detect_orientation`: If you want to estimate the general page orientation and add it to each page, set it to True. +* `detect_language`: If you want to predict the language of the text on each page, set it to True. +* `detect_layout`: If you want to run a layout detection model on each page and attach the detected regions to each page, set it to True (default: False). +* `layout_arch`: The layout architecture name (e.g. ``'lw_detr_s'``, ``'lw_detr_m'``) or your own (fine-tuned) layout model instance to use when ``detect_layout=True``. For instance, this snippet instantiates an end-to-end ocr_predictor working with rotated documents, which preserves the aspect ratio of the documents, and returns polygons: @@ -319,7 +323,7 @@ For instance, this snippet instantiates an end-to-end ocr_predictor working with Additionally, you can change the batch size of the underlying detection and recognition predictors to optimize the performance depending on your hardware: -* `det_bs`: batch size for the detection model (default: 2) +* `det_bs`: batch size for the detection model (default: 2) - will also be used for the layout model if ``detect_layout=True`` * `reco_bs`: batch size for the recognition model (default: 128) .. code:: python3 @@ -341,6 +345,34 @@ For example to disable the automatic grouping of lines into blocks: model = ocr_predictor(pretrained=True, resolve_blocks=False) +Detecting the document layout +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In addition to running the :py:meth:`layout_predictor ` standalone, you can plug a layout detection model directly into the end-to-end pipeline by setting ``detect_layout=True``. The detected regions (e.g. Title, Text, Table, Page-header, Page-footer) are attached to every :class:`Page ` and can be accessed through ``page.layout``, exported alongside the rest of the page, and rendered with :py:meth:`show `. + +.. code:: python3 + + from doctr.io import DocumentFile + from doctr.models import ocr_predictor + + model = ocr_predictor(pretrained=True, detect_layout=True) + doc = DocumentFile.from_images("path/to/your/doc.jpg") + result = model(doc) + + # Access the detected layout regions of the first page + for region in result.pages[0].layout: + print(region.type, region.confidence, region.geometry) + + # The layout is part of the exported representation + export = result.pages[0].export() + print(export["layout"]) + + # Overlay both text and layout regions (use display_layout=False to hide the regions) + result.pages[0].show() + +The same ``detect_layout`` / ``layout_arch`` arguments are available for the :py:meth:`kie_predictor `. + + Running the predictors on GPU ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -369,6 +401,7 @@ The same approach applies to all standalone predictors: * `detection_predictor` * `crop_orientation_predictor` * `page_orientation_predictor` +* `layout_predictor` Just create the predictor instance and move it to the appropriate device. To enable **half-precision inference**, you can append `.half()` after moving the predictor to the device. @@ -378,6 +411,7 @@ What should I do with the output? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The ocr_predictor returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`). +When ``detect_layout=True`` was passed, each `Page` additionally carries a list of `LayoutElement` regions under ``page.layout``. To get a better understanding of our document model, check our :ref:`document_structure` section Here is a typical `Document` layout:: diff --git a/doctr/models/builder.py b/doctr/models/builder.py index dca7ced5f8..685de2d215 100644 --- a/doctr/models/builder.py +++ b/doctr/models/builder.py @@ -9,7 +9,7 @@ import numpy as np from scipy.cluster.hierarchy import fclusterdata -from doctr.io.elements import Block, Document, KIEDocument, KIEPage, Line, Page, Prediction, Word +from doctr.io.elements import Block, Document, KIEDocument, KIEPage, LayoutElement, Line, Page, Prediction, Word from doctr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes from doctr.utils.repr import NestedObject @@ -211,6 +211,33 @@ def _resolve_blocks(boxes: np.ndarray, lines: list[list[int]]) -> list[list[list return blocks + @staticmethod + def _build_layout_elements(regions: dict[str, Any] | None) -> list[LayoutElement]: + """Convert a raw layout prediction into exportable ``LayoutElement`` objects. + + Args: + regions: a layout prediction ``{"boxes": (R, 4) | (R, 4, 2), "class_names": [...], "scores": [...]}`` + as returned by a ``LayoutPredictor``, or None. + + Returns: + list of ``LayoutElement`` (empty if no layout was provided). + """ + if regions is None or len(regions.get("boxes", [])) == 0: + return [] + boxes = np.asarray(regions["boxes"]) + class_names = regions.get("class_names") or ["" for _ in range(len(boxes))] + scores = regions.get("scores") + scores = scores if scores is not None else [1.0 for _ in range(len(boxes))] + + elements: list[LayoutElement] = [] + for box, cname, score in zip(boxes, class_names, scores): + if box.ndim == 2: # rotated polygon (4, 2) + geometry: Any = tuple(tuple(float(c) for c in pt) for pt in box.tolist()) + else: # straight (x1, y1, x2, y2) + geometry = ((float(box[0]), float(box[1])), (float(box[2]), float(box[3]))) + elements.append(LayoutElement(layout_type=str(cname), confidence=float(score), geometry=geometry)) + return elements + def _build_blocks( self, boxes: np.ndarray, @@ -292,6 +319,7 @@ def __call__( crop_orientations: list[dict[str, Any]], orientations: list[dict[str, Any]] | None = None, languages: list[dict[str, Any]] | None = None, + regions: list[dict[str, Any] | None] | None = None, ) -> Document: """Re-arrange detected words into structured blocks @@ -308,6 +336,8 @@ def __call__( where each element is a dictionary containing the orientation (orientation + confidence) languages: optional, list of N elements, where each element is a dictionary containing the language (language + confidence) + regions: optional, list of N elements, where each element is a layout prediction + ``{"boxes": (R, 4|4x2), "class_names": [...], "scores": [...]}`` attached to each page Returns: document object @@ -319,6 +349,7 @@ def __call__( _orientations = orientations if isinstance(orientations, list) else [None] * len(boxes) _languages = languages if isinstance(languages, list) else [None] * len(boxes) + _regions = regions if isinstance(regions, list) else [None] * len(boxes) if self.export_as_straight_boxes and len(boxes) > 0: # If boxes are already straight OK, else fit a bounding rect if boxes[0].ndim == 3: @@ -338,8 +369,9 @@ def __call__( shape, orientation, language, + self._build_layout_elements(page_regions), ) - for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip( # noqa: E501 + for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language, page_regions in zip( # noqa: E501 pages, range(len(boxes)), page_shapes, @@ -349,6 +381,7 @@ def __call__( crop_orientations, _orientations, _languages, + _regions, ) ] @@ -376,6 +409,7 @@ def __call__( # type: ignore[override] crop_orientations: list[dict[str, list[dict[str, Any]]]], orientations: list[dict[str, Any]] | None = None, languages: list[dict[str, Any]] | None = None, + regions: list[dict[str, Any] | None] | None = None, ) -> KIEDocument: """Re-arrange detected words into structured predictions @@ -392,6 +426,8 @@ def __call__( # type: ignore[override] where each element is a dictionary containing the orientation (orientation + confidence) languages: optional, list of N elements, where each element is a dictionary containing the language (language + confidence) + regions: optional, list of N elements, where each element is a layout prediction + ``{"boxes": (R, 4|4x2), "class_names": [...], "scores": [...]}`` attached to each page Returns: document object @@ -402,6 +438,7 @@ def __call__( # type: ignore[override] raise ValueError("All arguments are expected to be lists of the same size") _orientations = orientations if isinstance(orientations, list) else [None] * len(boxes) _languages = languages if isinstance(languages, list) else [None] * len(boxes) + _regions = regions if isinstance(regions, list) else [None] * len(boxes) if self.export_as_straight_boxes and len(boxes) > 0: # If boxes are already straight OK, else fit a bounding rect if next(iter(boxes[0].values())).ndim == 3: @@ -431,8 +468,9 @@ def __call__( # type: ignore[override] shape, orientation, language, + self._build_layout_elements(page_regions), ) - for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip( # noqa: E501 + for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language, page_regions in zip( # noqa: E501 pages, range(len(boxes)), page_shapes, @@ -442,6 +480,7 @@ def __call__( # type: ignore[override] crop_orientations, _orientations, _languages, + _regions, ) ] diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index 153d97f2d4..ad33cdd37d 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -12,6 +12,7 @@ from doctr.io.elements import Document from doctr.models._utils import get_language, invert_data_structure from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.layout.predictor import LayoutPredictor from doctr.models.recognition.predictor import RecognitionPredictor from doctr.utils.geometry import detach_scores @@ -35,6 +36,7 @@ class KIEPredictor(nn.Module, _KIEPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. + layout_predictor: optional layout detection module **kwargs: keyword args of `DocumentBuilder` """ @@ -48,6 +50,7 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, + layout_predictor: LayoutPredictor | None = None, **kwargs: Any, ) -> None: nn.Module.__init__(self) @@ -64,6 +67,7 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language + self.layout_predictor = layout_predictor.eval() if layout_predictor is not None else None @torch.inference_mode() def forward( @@ -164,6 +168,9 @@ def forward( else: languages_dict = None + # Detect layout regions on the (possibly straightened) pages + regions = self.layout_predictor(pages, **kwargs) if self.layout_predictor is not None else None + out = self.doc_builder( pages, boxes_per_page, @@ -173,6 +180,7 @@ def forward( crop_orientations_per_page, orientations, languages_dict, + regions, ) return out diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index 61a55fb4e3..ceab3e8414 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -12,6 +12,7 @@ from doctr.io.elements import Document from doctr.models._utils import get_language from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.layout.predictor import LayoutPredictor from doctr.models.recognition.predictor import RecognitionPredictor from doctr.utils.geometry import detach_scores @@ -35,6 +36,7 @@ class OCRPredictor(nn.Module, _OCRPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. + layout_predictor: optional layout detection module **kwargs: keyword args of `DocumentBuilder` """ @@ -48,6 +50,7 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, + layout_predictor: LayoutPredictor | None = None, **kwargs: Any, ) -> None: nn.Module.__init__(self) @@ -64,6 +67,7 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language + self.layout_predictor = layout_predictor.eval() if layout_predictor is not None else None @torch.inference_mode() def forward( @@ -142,6 +146,9 @@ def forward( else: languages_dict = None + # Detect layout regions on the (possibly straightened) pages + regions = self.layout_predictor(pages, **kwargs) if self.layout_predictor is not None else None + out = self.doc_builder( pages, boxes, @@ -151,5 +158,6 @@ def forward( crop_orientations, orientations, languages_dict, + regions, ) return out diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index bfa8026943..cf164bc6e3 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -7,6 +7,7 @@ from .detection.zoo import detection_predictor from .kie_predictor import KIEPredictor +from .layout.zoo import layout_predictor from .predictor import OCRPredictor from .recognition.zoo import recognition_predictor @@ -26,6 +27,8 @@ def _predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + detect_layout: bool = False, + layout_arch: Any = "lw_detr_s", **kwargs, ) -> OCRPredictor: # Detection @@ -47,6 +50,20 @@ def _predictor( batch_size=reco_bs, ) + # Layout - optional + layout_pred = ( + layout_predictor( + layout_arch, + pretrained=pretrained, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + batch_size=det_bs, + ) + if detect_layout + else None + ) + return OCRPredictor( det_predictor, reco_predictor, @@ -56,6 +73,7 @@ def _predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + layout_predictor=layout_pred, **kwargs, ) @@ -72,6 +90,8 @@ def ocr_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + detect_layout: bool = False, + layout_arch: Any = "lw_detr_s", **kwargs: Any, ) -> OCRPredictor: """End-to-end OCR architecture using one model for localization, and another for text recognition. @@ -104,6 +124,10 @@ def ocr_predictor( Doing so will improve performances for documents with page-uniform rotations. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. + detect_layout: if True, a layout detection model is run on each page and the detected regions are attached + to each page. + Doing so will slightly deteriorate the overall latency. + layout_arch: name of the layout architecture or the model itself to use. kwargs: keyword args of `OCRPredictor` Returns: @@ -121,6 +145,8 @@ def ocr_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + detect_layout=detect_layout, + layout_arch=layout_arch, **kwargs, ) @@ -138,6 +164,8 @@ def _kie_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + detect_layout: bool = False, + layout_arch: Any = "lw_detr_s", **kwargs, ) -> KIEPredictor: # Detection @@ -159,6 +187,20 @@ def _kie_predictor( batch_size=reco_bs, ) + # Layout - optional + layout_pred = ( + layout_predictor( + layout_arch, + pretrained=pretrained, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + batch_size=det_bs, + ) + if detect_layout + else None + ) + return KIEPredictor( det_predictor, reco_predictor, @@ -168,6 +210,7 @@ def _kie_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + layout_predictor=layout_pred, **kwargs, ) @@ -184,6 +227,8 @@ def kie_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + detect_layout: bool = False, + layout_arch: Any = "lw_detr_s", **kwargs: Any, ) -> KIEPredictor: """End-to-end KIE architecture using one model for localization, and another for text recognition. @@ -216,6 +261,10 @@ def kie_predictor( Doing so will improve performances for documents with page-uniform rotations. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. + detect_layout: if True, a layout detection model is run on each page and the detected regions are attached + to each page. + Doing so will slightly deteriorate the overall latency. + layout_arch: name of the layout architecture or the model itself to use. kwargs: keyword args of `OCRPredictor` Returns: @@ -233,5 +282,7 @@ def kie_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + detect_layout=detect_layout, + layout_arch=layout_arch, **kwargs, ) diff --git a/tests/common/test_models_builder.py b/tests/common/test_models_builder.py index dab41fc440..ad695eb368 100644 --- a/tests/common/test_models_builder.py +++ b/tests/common/test_models_builder.py @@ -3,7 +3,7 @@ from doctr.file_utils import CLASS_NAME from doctr.io import Document -from doctr.io.elements import KIEDocument +from doctr.io.elements import KIEDocument, LayoutElement from doctr.models import builder words_per_page = 10 @@ -184,6 +184,110 @@ def test_kiedocumentbuilder(): ) +def test_documentbuilder_layout(): + + doc_builder = builder.DocumentBuilder() + boxes = np.array([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]) + objectness_scores = np.array([0.9, 0.9]) + regions = [ + { + "boxes": np.array([[0.05, 0.02, 0.95, 0.08], [0.05, 0.2, 0.95, 0.5]], dtype=np.float32), + "class_names": ["Title", "Text"], + "scores": [0.95, 0.88], + } + ] + out = doc_builder( + [np.zeros((100, 100, 3))], + [boxes], + [objectness_scores], + [[("hello", 0.99), ("world", 0.99)]], + [(100, 100)], + [[{"value": 0, "confidence": None}] * 2], + regions=regions, + ) + page = out.pages[0] + # Layout regions are attached as LayoutElement and exported + assert len(page.layout) == 2 + assert all(isinstance(region, LayoutElement) for region in page.layout) + assert [region.type for region in page.layout] == ["Title", "Text"] + assert page.layout[0].confidence == pytest.approx(0.95) + assert np.allclose(np.array(page.layout[0].geometry), [[0.05, 0.02], [0.95, 0.08]], atol=1e-6) + assert page.export()["layout"] == [region.export() for region in page.layout] + + # no regions -> empty layout + out_no_layout = doc_builder( + [np.zeros((100, 100, 3))], + [boxes], + [objectness_scores], + [[("hello", 0.99), ("world", 0.99)]], + [(100, 100)], + [[{"value": 0, "confidence": None}] * 2], + ) + assert out_no_layout.pages[0].layout == [] + assert out_no_layout.pages[0].export()["layout"] == [] + + # Rotated layout polygons (4, 2) are converted to a 4-point geometry + rotated_regions = [ + { + "boxes": np.array([[[0.1, 0.1], [0.4, 0.12], [0.39, 0.3], [0.09, 0.28]]], dtype=np.float32), + "class_names": ["Table"], + "scores": [0.7], + } + ] + out_rot = doc_builder( + [np.zeros((100, 100, 3))], + [boxes], + [objectness_scores], + [[("hello", 0.99), ("world", 0.99)]], + [(100, 100)], + [[{"value": 0, "confidence": None}] * 2], + regions=rotated_regions, + ) + region = out_rot.pages[0].layout[0] + assert region.type == "Table" + assert isinstance(region.geometry, tuple) and len(region.geometry) == 4 + + +def test_kiedocumentbuilder_layout(): + from doctr.io.elements import LayoutElement + + doc_builder = builder.KIEDocumentBuilder() + predictions = {CLASS_NAME: np.array([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]])} + objectness_scores = {CLASS_NAME: np.array([0.9, 0.9])} + regions = [ + { + "boxes": np.array([[0.05, 0.02, 0.95, 0.08], [0.05, 0.2, 0.95, 0.5]], dtype=np.float32), + "class_names": ["Title", "Text"], + "scores": [0.95, 0.88], + } + ] + out = doc_builder( + [np.zeros((100, 100, 3))], + [predictions], + [objectness_scores], + [{CLASS_NAME: [("hello", 0.99), ("world", 0.99)]}], + [(100, 100)], + [{CLASS_NAME: [{"value": 0, "confidence": None}] * 2}], + regions=regions, + ) + page = out.pages[0] + assert len(page.layout) == 2 + assert all(isinstance(region, LayoutElement) for region in page.layout) + assert [region.type for region in page.layout] == ["Title", "Text"] + assert page.export()["layout"] == [region.export() for region in page.layout] + + # no regions -> empty layout + out_no_layout = doc_builder( + [np.zeros((100, 100, 3))], + [predictions], + [objectness_scores], + [{CLASS_NAME: [("hello", 0.99), ("world", 0.99)]}], + [(100, 100)], + [{CLASS_NAME: [{"value": 0, "confidence": None}] * 2}], + ) + assert out_no_layout.pages[0].layout == [] + + @pytest.mark.parametrize( "input_boxes, sorted_idxs", [ diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 868842ce82..9b91f2196a 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -6,13 +6,15 @@ from doctr import models from doctr.file_utils import CLASS_NAME from doctr.io import Document, DocumentFile -from doctr.io.elements import KIEDocument -from doctr.models import detection, recognition +from doctr.io.elements import KIEDocument, LayoutElement +from doctr.models import detection, layout, recognition from doctr.models.classification import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor from doctr.models.detection.predictor import DetectionPredictor from doctr.models.detection.zoo import detection_predictor from doctr.models.kie_predictor import KIEPredictor +from doctr.models.layout.predictor import LayoutPredictor +from doctr.models.layout.zoo import layout_predictor from doctr.models.predictor import OCRPredictor from doctr.models.preprocessor import PreProcessor from doctr.models.recognition.predictor import RecognitionPredictor @@ -120,7 +122,141 @@ def test_ocrpredictor( assert out.pages[0].orientation["value"] == orientation -def test_trained_ocr_predictor(mock_payslip): +def test_ocrpredictor_layout(mock_pdf, mock_vocab, mock_payslip): + det_predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=2), + detection.db_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, assume_straight_pages=True), + ) + reco_predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=32, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab), + ) + layout_pred = layout_predictor("lw_detr_s", pretrained=False) + + doc = DocumentFile.from_pdf(mock_pdf) + + # Without a layout predictor -> pages carry an empty layout + predictor = OCRPredictor(det_predictor, reco_predictor) + assert predictor.layout_predictor is None + out = predictor(doc) + assert all(page.layout == [] for page in out.pages) + assert all(page.export()["layout"] == [] for page in out.pages) + + # With a layout predictor -> detected regions are attached to every page + predictor = OCRPredictor(det_predictor, reco_predictor, layout_predictor=layout_pred) + assert isinstance(predictor.layout_predictor, LayoutPredictor) + out = predictor(doc) + assert isinstance(out, Document) + for page in out.pages: + assert isinstance(page.layout, list) + assert all(isinstance(region, LayoutElement) for region in page.layout) + # the layout is exported alongside the page + exported = page.export() + assert "layout" in exported + assert exported["layout"] == [region.export() for region in page.layout] + + doc = DocumentFile.from_images(mock_payslip) + + det_predictor = detection_predictor( + "fast_base", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + symmetric_pad=True, + preserve_aspect_ratio=False, + ) + reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=False, + resolve_blocks=True, + resolve_lines=True, + ) + + out = predictor(doc) + + assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr." + geometry_mr = np.array([[0.1083984375, 0.0634765625], [0.1494140625, 0.0859375]]) + assert np.allclose(np.array(out.pages[0].blocks[0].lines[0].words[0].geometry), geometry_mr, rtol=0.05) + + assert out.pages[0].blocks[1].lines[0].words[-1].value == "revised" + geometry_revised = np.array([[0.7548828125, 0.126953125], [0.8388671875, 0.1484375]]) + assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised, rtol=0.05) + + det_predictor = detection_predictor( + "fast_base", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + resolve_blocks=True, + resolve_lines=True, + ) + # test hooks + predictor.add_hook(_DummyCallback()) + + out = predictor(doc) + + assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr." + + +def test_trained_ocr_predictor(mock_pdf, mock_vocab, mock_payslip): + det_predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=2), + detection.db_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, assume_straight_pages=True), + ) + reco_predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=32, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab), + ) + layout_pred = layout_predictor("lw_detr_s", pretrained=True) + + doc = DocumentFile.from_pdf(mock_pdf) + + # Without a layout predictor -> pages carry an empty layout + predictor = OCRPredictor(det_predictor, reco_predictor) + assert predictor.layout_predictor is None + out = predictor(doc) + assert all(page.layout == [] for page in out.pages) + assert all(page.export()["layout"] == [] for page in out.pages) + + # With a layout predictor -> detected regions are attached to every page + predictor = OCRPredictor(det_predictor, reco_predictor, layout_predictor=layout_pred) + assert isinstance(predictor.layout_predictor, LayoutPredictor) + out = predictor(doc) + assert isinstance(out, Document) + for page in out.pages: + assert isinstance(page.layout, list) + assert all(isinstance(region, LayoutElement) for region in page.layout) + # the layout is exported alongside the page + exported = page.export() + assert "layout" in exported + assert exported["layout"] == [region.export() for region in page.layout] + + # Test KIE + predictor = KIEPredictor(det_predictor, reco_predictor, layout_predictor=layout_pred) + assert isinstance(predictor.layout_predictor, LayoutPredictor) + out = predictor(doc) + assert isinstance(out, KIEDocument) + for page in out.pages: + assert isinstance(page.layout, list) + assert all(isinstance(region, LayoutElement) for region in page.layout) + assert page.export()["layout"] == [region.export() for region in page.layout] + doc = DocumentFile.from_images(mock_payslip) det_predictor = detection_predictor( @@ -414,6 +550,26 @@ def test_zoo_models(det_arch, reco_arch): with pytest.raises(ValueError): models.kie_predictor(reco_arch=det_model, pretrained=True) + # Layout-aware OCR predictor via the factory (detect_layout flag) + predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True, detect_layout=True) + assert isinstance(predictor.layout_predictor, LayoutPredictor) + _test_predictor(predictor) + + # passing a (fine-tuned) layout model instance, like det/reco + layout_model = layout.lw_detr_s(pretrained=False) + predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True, detect_layout=True, layout_arch=layout_model) + assert isinstance(predictor.layout_predictor, LayoutPredictor) + assert predictor.layout_predictor.model is layout_model + + # disabled by default + predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True) + assert predictor.layout_predictor is None + + # Layout-aware KIE predictor via the factory + predictor = models.kie_predictor(det_arch, reco_arch, pretrained=True, detect_layout=True) + assert isinstance(predictor.layout_predictor, LayoutPredictor) + _test_kiepredictor(predictor) + @pytest.mark.parametrize( "det_arch, reco_arch",