diff --git a/doctr/datasets/ic03.py b/doctr/datasets/ic03.py index e61c2049c6..0216720322 100644 --- a/doctr/datasets/ic03.py +++ b/doctr/datasets/ic03.py @@ -4,9 +4,9 @@ # See LICENSE or go to for full license details. import os +import xml.etree.ElementTree as ET # nosec B405 from typing import Any -import defusedxml.ElementTree as ET import numpy as np from tqdm import tqdm diff --git a/doctr/datasets/svt.py b/doctr/datasets/svt.py index 1a7f5d2a6a..bcbac7769a 100644 --- a/doctr/datasets/svt.py +++ b/doctr/datasets/svt.py @@ -4,9 +4,9 @@ # See LICENSE or go to for full license details. import os +import xml.etree.ElementTree as ET # nosec B405 from typing import Any -import defusedxml.ElementTree as ET import numpy as np from tqdm import tqdm diff --git a/doctr/io/elements.py b/doctr/io/elements.py index 3402efbde2..c9cad3a12a 100644 --- a/doctr/io/elements.py +++ b/doctr/io/elements.py @@ -4,10 +4,6 @@ # See LICENSE or go to for full license details. from typing import Any - -from defusedxml import defuse_stdlib - -defuse_stdlib() from xml.etree import ElementTree as ET from xml.etree.ElementTree import Element as ETElement from xml.etree.ElementTree import SubElement @@ -26,7 +22,7 @@ except ModuleNotFoundError: pass -__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"] +__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document", "LayoutElement"] class Element(NestedObject): @@ -70,7 +66,7 @@ class Word(Element): value: the text string of the word confidence: the confidence associated with the text prediction geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to - the page's size + the page's size objectness_score: the objectness score of the detection crop_orientation: the general orientation of the crop in degrees and its confidence """ @@ -126,8 +122,8 @@ def __init__(self, artefact_type: str, confidence: 
float, geometry: BoundingBox) self.confidence = confidence def render(self) -> str: - """Renders the full text of the element""" - return f"[{self.type.upper()}]" + """Renders the region as a tag""" + return f"<[{self.type.upper()}]>" def extra_repr(self) -> str: return f"type='{self.type}', confidence={self.confidence:.2}" @@ -138,6 +134,38 @@ def from_dict(cls, save_dict: dict[str, Any], **kwargs): return cls(**kwargs) +class LayoutElement(Element): + """Implements a layout region predicted by a layout detection model + + Args: + layout_type: the predicted region class (e.g. 'Title', 'Text', 'Table', 'Page-header') + confidence: the confidence of the region prediction + geometry: bounding box of the region in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size + """ + + _exported_keys: list[str] = ["geometry", "type", "confidence"] + _children_names: list[str] = [] + + def __init__(self, layout_type: str, confidence: float, geometry: BoundingBox | np.ndarray) -> None: + super().__init__() + self.geometry = geometry + self.type = layout_type + self.confidence = confidence + + def render(self) -> str: + """Renders the region as a tag""" + return f"<[{self.type.upper()}]>" + + def extra_repr(self) -> str: + return f"type='{self.type}', confidence={self.confidence:.2}" + + @classmethod + def from_dict(cls, save_dict: dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + return cls(layout_type=kwargs["type"], confidence=kwargs["confidence"], geometry=kwargs["geometry"]) + + class Line(Element): """Implements a line element as a collection of words @@ -258,11 +286,13 @@ class Page(Element): dimensions: the page size in pixels in format (height, width) orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction language: a dictionary with the language value and confidence of the prediction + layout: optional list of layout regions detected on the page """ 
_exported_keys: list[str] = ["page_idx", "dimensions", "orientation", "language"] - _children_names: list[str] = ["blocks"] + _children_names: list[str] = ["blocks", "layout"] blocks: list[Block] = [] + layout: list[LayoutElement] = [] def __init__( self, @@ -272,8 +302,9 @@ def __init__( dimensions: tuple[int, int], orientation: dict[str, Any] | None = None, language: dict[str, Any] | None = None, + layout: list[LayoutElement] | None = None, ) -> None: - super().__init__(blocks=blocks) + super().__init__(blocks=blocks, layout=layout if layout is not None else []) self.page = page self.page_idx = page_idx self.dimensions = dimensions @@ -294,12 +325,20 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, ** interactive: whether the display should be interactive preserve_aspect_ratio: pass True if you passed True to the predictor **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method + (e.g. ``display_layout=False`` to hide detected layout regions) """ requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed") requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed") import matplotlib.pyplot as plt - visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio) + show_kwargs = {k: kwargs.pop(k) for k in ("words_only", "display_artefacts", "display_layout") if k in kwargs} + visualize_page( + self.export(), + self.page, + interactive=interactive, + preserve_aspect_ratio=preserve_aspect_ratio, + **show_kwargs, + ) plt.show(**kwargs) def synthesize(self, **kwargs) -> np.ndarray: @@ -420,7 +459,10 @@ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> tuple[ @classmethod def from_dict(cls, save_dict: dict[str, Any], **kwargs): kwargs = {k: save_dict[k] for k in cls._exported_keys} - kwargs.update({"blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]]}) + 
kwargs.update({ + "blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]], + "layout": [LayoutElement.from_dict(region_dict) for region_dict in save_dict.get("layout", [])], + }) return cls(**kwargs) @@ -434,11 +476,13 @@ class KIEPage(Element): dimensions: the page size in pixels in format (height, width) orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction language: a dictionary with the language value and confidence of the prediction + layout: optional list of layout regions detected on the page """ _exported_keys: list[str] = ["page_idx", "dimensions", "orientation", "language"] - _children_names: list[str] = ["predictions"] + _children_names: list[str] = ["predictions", "layout"] predictions: dict[str, list[Prediction]] = {} + layout: list[LayoutElement] = [] def __init__( self, @@ -448,8 +492,9 @@ def __init__( dimensions: tuple[int, int], orientation: dict[str, Any] | None = None, language: dict[str, Any] | None = None, + layout: list[LayoutElement] | None = None, ) -> None: - super().__init__(predictions=predictions) + super().__init__(predictions=predictions, layout=layout if layout is not None else []) self.page = page self.page_idx = page_idx self.dimensions = dimensions @@ -477,8 +522,13 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, ** requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed") import matplotlib.pyplot as plt + show_kwargs = {k: kwargs.pop(k) for k in ("words_only", "display_artefacts", "display_layout") if k in kwargs} visualize_kie_page( - self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio + self.export(), + self.page, + interactive=interactive, + preserve_aspect_ratio=preserve_aspect_ratio, + **show_kwargs, ) plt.show(**kwargs) @@ -593,7 +643,11 @@ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> tuple[ def from_dict(cls, 
save_dict: dict[str, Any], **kwargs): kwargs = {k: save_dict[k] for k in cls._exported_keys} kwargs.update({ - "predictions": [Prediction.from_dict(predictions_dict) for predictions_dict in save_dict["predictions"]] + "predictions": { + class_name: [Prediction.from_dict(pred) for pred in preds] + for class_name, preds in save_dict["predictions"].items() + }, + "layout": [LayoutElement.from_dict(region_dict) for region_dict in save_dict.get("layout", [])], }) return cls(**kwargs) diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py index 33e3a19bd5..bf634fe1d9 100644 --- a/doctr/utils/visualization.py +++ b/doctr/utils/visualization.py @@ -157,6 +157,7 @@ def visualize_page( image: np.ndarray, words_only: bool = True, display_artefacts: bool = True, + display_layout: bool = True, scale: float = 10, interactive: bool = True, add_labels: bool = True, @@ -179,6 +180,7 @@ def visualize_page( image: np array of the page, needs to have the same shape than page['dimensions'] words_only: whether only words should be displayed display_artefacts: whether artefacts should be displayed + display_layout: whether detected layout regions should be displayed scale: figsize of the largest windows side interactive: whether the plot should be interactive add_labels: for static plot, adds text labels on top of bounding box @@ -199,6 +201,33 @@ def visualize_page( if interactive: artists: list[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page) + # Draw layout regions first so text boxes are overlaid on top of them + if display_layout and page.get("layout"): + region_classes = sorted({region["type"] for region in page["layout"]}) + layout_colors = {cls: color for color, cls in zip(get_colors(max(len(region_classes), 1)), region_classes)} + for region in page["layout"]: + rect = create_obj_patch( + region["geometry"], + page["dimensions"], + label=f"{region['type']} (confidence: {region['confidence']:.2%})", + 
color=layout_colors[region["type"]], + linewidth=2, + fill=False, + **kwargs, + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + elif add_labels and len(region["geometry"]) == 2: + ax.text( + int(page["dimensions"][1] * region["geometry"][0][0]), + int(page["dimensions"][0] * region["geometry"][0][1]), + region["type"], + size=9, + alpha=0.7, + color=layout_colors[region["type"]], + ) + for block in page["blocks"]: if not words_only: rect = create_obj_patch( @@ -281,6 +310,7 @@ def visualize_kie_page( image: np.ndarray, words_only: bool = False, display_artefacts: bool = True, + display_layout: bool = True, scale: float = 10, interactive: bool = True, add_labels: bool = True, @@ -303,6 +333,7 @@ def visualize_kie_page( image: np array of the page, needs to have the same shape than page['dimensions'] words_only: whether only words should be displayed display_artefacts: whether artefacts should be displayed + display_layout: whether detected layout regions should be displayed scale: figsize of the largest windows side interactive: whether the plot should be interactive add_labels: for static plot, adds text labels on top of bounding box @@ -323,6 +354,23 @@ def visualize_kie_page( if interactive: artists: list[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page) + if display_layout and page.get("layout"): + region_classes = sorted({region["type"] for region in page["layout"]}) + layout_colors = {cls: color for color, cls in zip(get_colors(max(len(region_classes), 1)), region_classes)} + for region in page["layout"]: + rect = create_obj_patch( + region["geometry"], + page["dimensions"], + label=f"{region['type']} (confidence: {region['confidence']:.2%})", + color=layout_colors[region["type"]], + linewidth=2, + fill=False, + **kwargs, + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + colors = {k: color for color, k in zip(get_colors(len(page["predictions"])), page["predictions"])} for key, value in 
page["predictions"].items(): for prediction in value: diff --git a/pyproject.toml b/pyproject.toml index 9e5c8418d7..14c2452ad3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ dependencies = [ "rapidfuzz>=3.0.0,<4.0.0", "huggingface-hub>=0.20.0,<2.0.0", "Pillow>=9.2.0", - "defusedxml>=0.7.0", "anyascii>=0.3.2", "validators>=0.18.0", "tqdm>=4.30.0", @@ -159,7 +158,6 @@ module = [ "pyclipper.*", "shapely.*", "mplcursors.*", - "defusedxml.*", "weasyprint.*", "huggingface_hub.*", "pypdfium2.*", diff --git a/tests/common/test_io_elements.py b/tests/common/test_io_elements.py index c0158a50e9..7da0f4a004 100644 --- a/tests/common/test_io_elements.py +++ b/tests/common/test_io_elements.py @@ -40,6 +40,13 @@ def _mock_artefacts(size=(1, 1), offset=(0, 0), confidence=0.8): ] +def _mock_layout(): + return [ + elements.LayoutElement("Title", 0.95, ((0.1, 0.05), (0.9, 0.15))), + elements.LayoutElement("Text", 0.88, ((0.1, 0.2), (0.9, 0.9))), + ] + + def _mock_lines(size=(1, 1), offset=(0, 0)): sub_size = (size[0] / 2, size[1] / 2) return [ @@ -229,7 +236,7 @@ def test_artefact(): assert artefact.geometry == geom # Render - assert artefact.render() == "[QR_CODE]" + assert artefact.render() == "<[QR_CODE]>" # Export assert artefact.export() == {"type": artefact_type, "confidence": conf, "geometry": geom} @@ -238,6 +245,32 @@ def test_artefact(): assert artefact.__repr__() == f"Artefact(type='{artefact_type}', confidence={conf:.2})" +def test_layout_element(): + layout_type = "Title" + conf = 0.9 + geom = ((0, 0), (1, 1)) + region = elements.LayoutElement(layout_type, conf, geom) + + # Attribute checks + assert region.type == layout_type + assert region.confidence == conf + assert region.geometry == geom + + # Render + assert region.render() == "<[TITLE]>" + + # Export + assert region.export() == {"type": layout_type, "confidence": conf, "geometry": geom} + + # Repr + assert region.__repr__() == f"LayoutElement(type='{layout_type}', confidence={conf:.2})" + + # 
Class method + state_dict = {"geometry": ((0, 0), (0.5, 0.5)), "type": "Table", "confidence": 0.7} + region = elements.LayoutElement.from_dict(state_dict) + assert region.export() == state_dict + + def test_prediction(): prediction_str = "hello" conf = 0.8 @@ -314,11 +347,14 @@ def test_page(): orientation = {"value": 0.0, "confidence": 0.0} language = {"value": "EN", "confidence": 0.8} blocks = _mock_blocks() - page = elements.Page(page, blocks, page_idx, page_size, orientation, language) + layout = _mock_layout() + page = elements.Page(page, blocks, page_idx, page_size, orientation, language, layout=layout) # Attribute checks assert len(page.blocks) == len(blocks) assert all(isinstance(b, elements.Block) for b in page.blocks) + assert len(page.layout) == len(layout) + assert all(isinstance(r, elements.LayoutElement) for r in page.layout) assert isinstance(page.page, np.ndarray) assert page.page_idx == page_idx assert page.dimensions == page_size @@ -335,6 +371,7 @@ def test_page(): "dimensions": page_size, "orientation": orientation, "language": language, + "layout": [r.export() for r in layout], } # Export XML @@ -356,6 +393,15 @@ def test_page(): assert img.shape == (*page_size, 3) +def test_page_without_layout(): + # Backward compatibility: layout defaults to an empty list + page = np.zeros((300, 200, 3), dtype=np.uint8) + page = elements.Page(page, _mock_blocks(), 0, (300, 200)) + + assert page.layout == [] + assert page.export()["layout"] == [] + + def test_kiepage(): page = np.zeros((300, 200, 3), dtype=np.uint8) page_idx = 0 @@ -363,11 +409,14 @@ def test_kiepage(): orientation = {"value": 0.0, "confidence": 0.0} language = {"value": "EN", "confidence": 0.8} predictions = {CLASS_NAME: _mock_prediction()} - kie_page = elements.KIEPage(page, predictions, page_idx, page_size, orientation, language) + layout = _mock_layout() + kie_page = elements.KIEPage(page, predictions, page_idx, page_size, orientation, language, layout=layout) # Attribute checks assert 
len(kie_page.predictions) == len(predictions) assert all(isinstance(b, elements.Prediction) for b in kie_page.predictions[CLASS_NAME]) + assert len(kie_page.layout) == len(layout) + assert all(isinstance(r, elements.LayoutElement) for r in kie_page.layout) assert isinstance(kie_page.page, np.ndarray) assert kie_page.page_idx == page_idx assert kie_page.dimensions == page_size @@ -384,6 +433,7 @@ def test_kiepage(): "dimensions": page_size, "orientation": orientation, "language": language, + "layout": [r.export() for r in layout], } # Export XML diff --git a/tests/common/test_utils_visualization.py b/tests/common/test_utils_visualization.py index ae232ebb85..cc003c416b 100644 --- a/tests/common/test_utils_visualization.py +++ b/tests/common/test_utils_visualization.py @@ -1,6 +1,6 @@ import numpy as np import pytest -from test_io_elements import _mock_pages +from test_io_elements import _mock_kie_pages, _mock_layout, _mock_pages from doctr.utils import visualization @@ -10,17 +10,35 @@ def test_visualize_page(): image = np.ones((300, 200, 3)) visualization.visualize_page(pages[0].export(), image, words_only=False) visualization.visualize_page(pages[0].export(), image, words_only=True, interactive=False) - # geometry checks + + # with detected layout regions + page_export = pages[0].export() + page_export["layout"] = [region.export() for region in _mock_layout()] + visualization.visualize_page(page_export, image, words_only=False, display_layout=True) + visualization.visualize_page(page_export, image, words_only=False, display_layout=True, interactive=False) + visualization.visualize_page(page_export, image, words_only=False, display_layout=False, interactive=False) + with pytest.raises(ValueError): visualization.create_obj_patch([1, 2], (100, 100)) - with pytest.raises(ValueError): visualization.create_obj_patch((1, 2), (100, 100)) - with pytest.raises(ValueError): visualization.create_obj_patch((1, 2, 3, 4, 5), (100, 100)) +def test_visualize_kie_page(): + pages = 
_mock_kie_pages() + image = np.ones((300, 200, 3)) + visualization.visualize_kie_page(pages[0].export(), image, words_only=False) + visualization.visualize_kie_page(pages[0].export(), image, words_only=True, interactive=False) + + # with detected layout regions + page_export = pages[0].export() + page_export["layout"] = [region.export() for region in _mock_layout()] + visualization.visualize_kie_page(page_export, image, words_only=False, display_layout=True) + visualization.visualize_kie_page(page_export, image, words_only=False, display_layout=False, interactive=False) + + def test_draw_boxes(): image = np.ones((256, 256, 3), dtype=np.float32) boxes = [