Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
result = model(doc)
```

### Detecting the document layout

You can additionally run a layout detection model as part of the pipeline by passing `detect_layout=True`. The detected regions (e.g. `Title`, `Text`, `Table`, `Page-header`, `Page-footer`) are attached to every page and rendered by `.show()`:

```python
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

model = ocr_predictor(pretrained=True, detect_layout=True)
doc = DocumentFile.from_images("path/to/your/doc.jpg")
result = model(doc)

# Access the detected layout regions of the first page
for region in result.pages[0].layout:
    print(region.type, region.confidence, region.geometry)
```

### Dealing with rotated documents

Should you use docTR on documents that include rotated pages, or pages with multiple box orientations,
Expand Down
10 changes: 9 additions & 1 deletion api/app/routes/kie.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status

from app.schemas import KIEElement, KIEIn, KIEOut
from app.schemas import KIEElement, KIEIn, KIEOut, LayoutElementOut
from app.utils import get_documents, resolve_geometry
from app.vision import init_predictor

Expand All @@ -30,6 +30,14 @@ async def perform_kie(request: KIEIn = Depends(), files: list[UploadFile] = [Fil
orientation=page.orientation,
language=page.language,
dimensions=page.dimensions,
layout=[
LayoutElementOut(
type=region.type,
geometry=resolve_geometry(region.geometry),
confidence=round(region.confidence, 2),
)
for region in page.layout
],
predictions=[
KIEElement(
class_name=class_name,
Expand Down
10 changes: 9 additions & 1 deletion api/app/routes/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status

from app.schemas import OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord
from app.schemas import LayoutElementOut, OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord
from app.utils import get_documents, resolve_geometry
from app.vision import init_predictor

Expand All @@ -31,6 +31,14 @@ async def perform_ocr(request: OCRIn = Depends(), files: list[UploadFile] = [Fil
orientation=page.orientation,
language=page.language,
dimensions=page.dimensions,
layout=[
LayoutElementOut(
type=region.type,
geometry=resolve_geometry(region.geometry),
confidence=round(region.confidence, 2),
)
for region in page.layout
],
items=[
OCRPage(
blocks=[
Expand Down
16 changes: 16 additions & 0 deletions api/app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class KIEIn(BaseModel):
preserve_aspect_ratio: bool = Field(default=True, examples=[True])
detect_orientation: bool = Field(default=False, examples=[False])
detect_language: bool = Field(default=False, examples=[False])
detect_layout: bool = Field(default=False, examples=[False])
layout_arch: str = Field(default="lw_detr_s", examples=["lw_detr_s"])
symmetric_pad: bool = Field(default=True, examples=[True])
straighten_pages: bool = Field(default=False, examples=[False])
det_bs: int = Field(default=2, examples=[2])
Expand Down Expand Up @@ -131,11 +133,21 @@ class OCRPage(BaseModel):
)


class LayoutElementOut(BaseModel):
    # One detected layout region, as returned in the `layout` list of the OCR and KIE responses.

    # Label of the region, e.g. "Title"
    type: str = Field(default=..., examples=["Title"])
    # Flat list of coordinates, as produced by `resolve_geometry` in the routes
    geometry: list[float] = Field(default=..., examples=[[0.0, 0.0, 0.0, 0.0]])
    # Score of the region; the routes round it to 2 decimals before building this model
    confidence: float = Field(default=..., examples=[0.99])


class OCROut(BaseModel):
name: str = Field(..., examples=["example.jpg"])
orientation: dict[str, float | None] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}])
language: dict[str, str | float | None] = Field(..., examples=[{"value": "en", "confidence": 0.99}])
dimensions: tuple[int, int] = Field(..., examples=[(100, 100)])
layout: list[LayoutElementOut] = Field(
default=[],
examples=[[{"type": "Title", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}]],
)
items: list[OCRPage] = Field(
...,
examples=[
Expand Down Expand Up @@ -183,4 +195,8 @@ class KIEOut(BaseModel):
orientation: dict[str, float | None] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}])
language: dict[str, str | float | None] = Field(..., examples=[{"value": "en", "confidence": 0.99}])
dimensions: tuple[int, int] = Field(..., examples=[(100, 100)])
layout: list[LayoutElementOut] = Field(
default=[],
examples=[[{"type": "Title", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}]],
)
predictions: list[KIEElement]
4 changes: 4 additions & 0 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def mock_kie_response():
"orientation": {"value": None, "confidence": None},
"language": {"value": None, "confidence": None},
"dimensions": [2339, 1654],
"layout": [],
"predictions": [
{
"class_name": "words",
Expand Down Expand Up @@ -104,6 +105,7 @@ def mock_kie_response():
"orientation": {"value": None, "confidence": None},
"language": {"value": None, "confidence": None},
"dimensions": [2339, 1654],
"layout": [],
"predictions": [
{
"class_name": "words",
Expand Down Expand Up @@ -155,6 +157,7 @@ def mock_ocr_response():
"orientation": {"value": None, "confidence": None},
"language": {"value": None, "confidence": None},
"dimensions": [2339, 1654],
"layout": [],
"items": [
{
"blocks": [
Expand Down Expand Up @@ -203,6 +206,7 @@ def mock_ocr_response():
"orientation": {"value": None, "confidence": None},
"language": {"value": None, "confidence": None},
"dimensions": [2339, 1654],
"layout": [],
"items": [
{
"blocks": [
Expand Down
26 changes: 26 additions & 0 deletions api/tests/routes/test_kie.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
and len(first_pred["dimensions"]) == 2
and all(isinstance(dim, int) for dim in first_pred["dimensions"])
)
assert isinstance(first_pred["layout"], list)

Check warning on line 13 in api/tests/routes/test_kie.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_kie.py#L13

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
for region in first_pred["layout"]:
assert isinstance(region["type"], str)

Check warning on line 15 in api/tests/routes/test_kie.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_kie.py#L15

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert isinstance(region["confidence"], (int, float))

Check warning on line 16 in api/tests/routes/test_kie.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_kie.py#L16

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert isinstance(region["geometry"], (tuple, list))

Check warning on line 17 in api/tests/routes/test_kie.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_kie.py#L17

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert isinstance(first_pred["predictions"], list)
assert isinstance(expected_response["predictions"], list)

Expand Down Expand Up @@ -67,6 +72,27 @@
common_test(json_response, expected_poly_response)


@pytest.mark.asyncio
async def test_kie_layout(test_app_asyncio, mock_detection_image):
headers = {
"accept": "application/json",
}
params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn", "detect_layout": True}
files = [
("files", ("test.jpg", mock_detection_image, "image/jpeg")),
]
response = await test_app_asyncio.post("/kie", params=params, files=files, headers=headers)
assert response.status_code == 200
json_response = response.json()

assert isinstance(json_response, list) and len(json_response) == 1

Check warning on line 88 in api/tests/routes/test_kie.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_kie.py#L88

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert "layout" in json_response[0] and isinstance(json_response[0]["layout"], list)

Check warning on line 89 in api/tests/routes/test_kie.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_kie.py#L89

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
for region in json_response[0]["layout"]:
assert isinstance(region["type"], str)
assert isinstance(region["confidence"], (int, float))
assert len(region["geometry"]) in (4, 8)

Check warning on line 93 in api/tests/routes/test_kie.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_kie.py#L93

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.


@pytest.mark.asyncio
async def test_kie_invalid_file(test_app_asyncio, mock_txt_file):
headers = {
Expand Down
26 changes: 26 additions & 0 deletions api/tests/routes/test_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
and len(first_pred["dimensions"]) == 2
and all(isinstance(dim, int) for dim in first_pred["dimensions"])
)
assert isinstance(first_pred["layout"], list)

Check warning on line 14 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L14

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
for region in first_pred["layout"]:
assert isinstance(region["type"], str)

Check warning on line 16 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L16

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert isinstance(region["confidence"], (int, float))

Check warning on line 17 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L17

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert isinstance(region["geometry"], (tuple, list))

Check warning on line 18 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L18

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
for item, expected_item in zip(first_pred["items"], expected_response["items"]):
for block, expected_block in zip(item["blocks"], expected_item["blocks"]):
np.testing.assert_allclose(block["geometry"], expected_block["geometry"], rtol=1e-2)
Expand Down Expand Up @@ -67,6 +72,27 @@
common_test(json_response, expected_poly_response)


@pytest.mark.asyncio
async def test_ocr_layout(test_app_asyncio, mock_detection_image):
headers = {
"accept": "application/json",
}
params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn", "detect_layout": True}
files = [
("files", ("test.jpg", mock_detection_image, "image/jpeg")),
]
response = await test_app_asyncio.post("/ocr", params=params, files=files, headers=headers)
assert response.status_code == 200
json_response = response.json()

assert isinstance(json_response, list) and len(json_response) == 1

Check warning on line 88 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L88

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert "layout" in json_response[0] and isinstance(json_response[0]["layout"], list)

Check warning on line 89 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L89

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
for region in json_response[0]["layout"]:
assert isinstance(region["type"], str)
assert isinstance(region["confidence"], (int, float))
assert len(region["geometry"]) in (4, 8)

Check warning on line 93 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L93

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.


@pytest.mark.asyncio
async def test_ocr_invalid_file(test_app_asyncio, mock_txt_file):
headers = {
Expand Down
11 changes: 8 additions & 3 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
import numpy as np
import streamlit as st
import torch
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
from backend.pytorch import DET_ARCHS, LAYOUT_ARCHS, RECO_ARCHS, forward_image, load_predictor

from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page

forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def main(det_archs, reco_archs):
def main(det_archs, reco_archs, layout_archs):
"""Build a streamlit layout"""
# Wide mode
st.set_page_config(layout="wide")
Expand Down Expand Up @@ -67,6 +67,9 @@ def main(det_archs, reco_archs):
straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
# Export as straight boxes
export_straight_boxes = st.sidebar.checkbox("Export as straight boxes", value=False)
# Layout detection
detect_layout = st.sidebar.checkbox("Detect layout", value=False)
layout_arch = st.sidebar.selectbox("Layout detection model", layout_archs, disabled=not detect_layout)
st.sidebar.write("\n")
# Binarization threshold
bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
Expand All @@ -92,6 +95,8 @@ def main(det_archs, reco_archs):
bin_thresh=bin_thresh,
box_thresh=box_thresh,
device=forward_device,
detect_layout=detect_layout,
layout_arch=layout_arch,
)

with st.spinner("Analyzing..."):
Expand Down Expand Up @@ -123,4 +128,4 @@ def main(det_archs, reco_archs):


if __name__ == "__main__":
main(DET_ARCHS, RECO_ARCHS)
main(DET_ARCHS, RECO_ARCHS, LAYOUT_ARCHS)
10 changes: 10 additions & 0 deletions demo/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
"parseq",
"viptr_tiny",
]
# Layout detection architectures offered in the demo's "Layout detection model" select box
LAYOUT_ARCHS = [
    "lw_detr_s",
    "lw_detr_m",
]


def load_predictor(
Expand All @@ -44,6 +48,8 @@ def load_predictor(
bin_thresh: float,
box_thresh: float,
device: torch.device,
detect_layout: bool,
layout_arch: str,
) -> OCRPredictor:
"""Load a predictor from doctr.models
Expand All @@ -58,6 +64,8 @@ def load_predictor(
bin_thresh: binarization threshold for the segmentation map
box_thresh: minimal objectness score to consider a box
device: torch.device, the device to load the predictor on
detect_layout: whether to run a layout detection model and attach the regions to each page
layout_arch: layout architecture to use when detect_layout is True
Returns:
instance of OCRPredictor
Expand All @@ -72,6 +80,8 @@ def load_predictor(
detect_orientation=not assume_straight_pages,
disable_page_orientation=disable_page_orientation,
disable_crop_orientation=disable_crop_orientation,
detect_layout=detect_layout,
layout_arch=layout_arch,
).to(device)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
Expand Down
7 changes: 7 additions & 0 deletions docs/source/modules/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ An Artefact is a non-textual element (e.g. QR code, picture, chart, signature, l

.. autoclass:: Artefact

LayoutElement
^^^^^^^^^^^^^

A LayoutElement is a region predicted by a layout detection model (e.g. Title, Text, Table, Page-header, Page-footer). Layout regions are attached to each :class:`Page` when ``ocr_predictor`` or ``kie_predictor`` is run with ``detect_layout=True``.

.. autoclass:: LayoutElement

Block
^^^^^
A Block is a collection of Lines (e.g. an address written on several lines) and Artefacts (e.g. a graph with its title underneath).
Expand Down
2 changes: 2 additions & 0 deletions docs/source/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ doctr.models.layout

.. autofunction:: doctr.models.layout.lw_detr_m

.. autofunction:: doctr.models.layout.layout_predictor


doctr.models.recognition
------------------------
Expand Down
20 changes: 20 additions & 0 deletions docs/source/using_doctr/custom_models_training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@ Load a custom layout analysis model trained on another set of classes as the def

predictor = layout_predictor(layout_arch=layout_model, pretrained=True)


Plug a custom layout analysis model (trained on another set of classes) directly into the OCR pipeline so the detected regions are attached to every page:

.. code:: python3

import torch
from doctr.models import ocr_predictor, lw_detr_s

# Custom layout model with your own class names
layout_model = lw_detr_s(pretrained=False, class_names=["heading", "paragraph", "figure", "table"])
layout_model.from_pretrained('<path_to_pt>')

# Pass it through `layout_arch`, exactly as for the detection / recognition models
predictor = ocr_predictor(pretrained=True, detect_layout=True, layout_arch=layout_model)

# `doc` is a document loaded beforehand, e.g. DocumentFile.from_images("path/to/your/doc.jpg")
result = predictor(doc)
# The regions (with your custom class names) are available on each page
print([(region.type, region.confidence) for region in result.pages[0].layout])


Load a custom trained KIE detection model:

.. code:: python3
Expand Down
12 changes: 12 additions & 0 deletions docs/source/using_doctr/sharing_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ We suggest using the following naming conventions for your models:

**Recognition:** ``doctr-<architecture>-<vocab>``

**Layout:** ``doctr-<architecture>``


Classification
--------------
Expand Down Expand Up @@ -101,3 +103,13 @@ Recognition
+---------------------------------+---------------------------------------------------+---------------------+------------------------+
| parseq | rania-sr/doctr-model-v1-arabic | arabic | PyTorch |
+---------------------------------+---------------------------------------------------+---------------------+------------------------+


Layout
------

+---------------------------------+---------------------------------------------------+------------------------+
| **Architecture** | **Repo_ID** | **Framework** |
+=================================+===================================================+========================+
| lw_detr_s (dummy) | Felix92/doctr-dummy-torch-lw-detr-s | PyTorch |
+---------------------------------+---------------------------------------------------+------------------------+
Loading
Loading