Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doctr/datasets/ic03.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import os
import xml.etree.ElementTree as ET # nosec B405
from typing import Any

import defusedxml.ElementTree as ET
import numpy as np
from tqdm import tqdm

Expand Down
2 changes: 1 addition & 1 deletion doctr/datasets/svt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import os
import xml.etree.ElementTree as ET # nosec B405
from typing import Any

import defusedxml.ElementTree as ET
import numpy as np
from tqdm import tqdm

Expand Down
86 changes: 70 additions & 16 deletions doctr/io/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any

from defusedxml import defuse_stdlib

defuse_stdlib()
from xml.etree import ElementTree as ET
from xml.etree.ElementTree import Element as ETElement
from xml.etree.ElementTree import SubElement
Expand All @@ -26,7 +22,7 @@
except ModuleNotFoundError:
pass

__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"]
__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document", "LayoutElement"]


class Element(NestedObject):
Expand Down Expand Up @@ -70,7 +66,7 @@ class Word(Element):
value: the text string of the word
confidence: the confidence associated with the text prediction
geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
the page's size
the page's size
objectness_score: the objectness score of the detection
crop_orientation: the general orientation of the crop in degrees and its confidence
"""
Expand Down Expand Up @@ -126,8 +122,8 @@ def __init__(self, artefact_type: str, confidence: float, geometry: BoundingBox)
self.confidence = confidence

def render(self) -> str:
"""Renders the full text of the element"""
return f"[{self.type.upper()}]"
"""Renders the region as a tag"""
return f"<[{self.type.upper()}]>"

def extra_repr(self) -> str:
return f"type='{self.type}', confidence={self.confidence:.2}"
Expand All @@ -138,6 +134,38 @@ def from_dict(cls, save_dict: dict[str, Any], **kwargs):
return cls(**kwargs)


class LayoutElement(Element):
"""Implements a layout region predicted by a layout detection model

Args:
layout_type: the predicted region class (e.g. 'Title', 'Text', 'Table', 'Page-header')
confidence: the confidence of the region prediction
geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
the page's size
"""

_exported_keys: list[str] = ["geometry", "type", "confidence"]
_children_names: list[str] = []

def __init__(self, layout_type: str, confidence: float, geometry: BoundingBox | np.ndarray) -> None:
super().__init__()
self.geometry = geometry
self.type = layout_type
self.confidence = confidence

def render(self) -> str:
"""Renders the region as a tag"""
return f"<[{self.type.upper()}]>"

def extra_repr(self) -> str:
return f"type='{self.type}', confidence={self.confidence:.2}"

@classmethod
def from_dict(cls, save_dict: dict[str, Any], **kwargs):
kwargs = {k: save_dict[k] for k in cls._exported_keys}
return cls(layout_type=kwargs["type"], confidence=kwargs["confidence"], geometry=kwargs["geometry"])


class Line(Element):
"""Implements a line element as a collection of words

Expand Down Expand Up @@ -258,11 +286,13 @@ class Page(Element):
dimensions: the page size in pixels in format (height, width)
orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction
language: a dictionary with the language value and confidence of the prediction
layout: optional list of layout regions detected on the page
"""

_exported_keys: list[str] = ["page_idx", "dimensions", "orientation", "language"]
_children_names: list[str] = ["blocks"]
_children_names: list[str] = ["blocks", "layout"]
blocks: list[Block] = []
layout: list[LayoutElement] = []

def __init__(
self,
Expand All @@ -272,8 +302,9 @@ def __init__(
dimensions: tuple[int, int],
orientation: dict[str, Any] | None = None,
language: dict[str, Any] | None = None,
layout: list[LayoutElement] | None = None,
) -> None:
super().__init__(blocks=blocks)
super().__init__(blocks=blocks, layout=layout if layout is not None else [])
self.page = page
self.page_idx = page_idx
self.dimensions = dimensions
Expand All @@ -294,12 +325,20 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **
interactive: whether the display should be interactive
preserve_aspect_ratio: pass True if you passed True to the predictor
**kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method
(e.g. ``display_layout=False`` to hide detected layout regions)
"""
requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed")
requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
import matplotlib.pyplot as plt

visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
show_kwargs = {k: kwargs.pop(k) for k in ("words_only", "display_artefacts", "display_layout") if k in kwargs}
visualize_page(
self.export(),
self.page,
interactive=interactive,
preserve_aspect_ratio=preserve_aspect_ratio,
**show_kwargs,
)
plt.show(**kwargs)

def synthesize(self, **kwargs) -> np.ndarray:
Expand Down Expand Up @@ -420,7 +459,10 @@ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> tuple[
@classmethod
def from_dict(cls, save_dict: dict[str, Any], **kwargs):
kwargs = {k: save_dict[k] for k in cls._exported_keys}
kwargs.update({"blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]]})
kwargs.update({
"blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]],
"layout": [LayoutElement.from_dict(region_dict) for region_dict in save_dict.get("layout", [])],
})
return cls(**kwargs)


Expand All @@ -434,11 +476,13 @@ class KIEPage(Element):
dimensions: the page size in pixels in format (height, width)
orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction
language: a dictionary with the language value and confidence of the prediction
layout: optional list of layout regions detected on the page
"""

_exported_keys: list[str] = ["page_idx", "dimensions", "orientation", "language"]
_children_names: list[str] = ["predictions"]
_children_names: list[str] = ["predictions", "layout"]
predictions: dict[str, list[Prediction]] = {}
layout: list[LayoutElement] = []

def __init__(
self,
Expand All @@ -448,8 +492,9 @@ def __init__(
dimensions: tuple[int, int],
orientation: dict[str, Any] | None = None,
language: dict[str, Any] | None = None,
layout: list[LayoutElement] | None = None,
) -> None:
super().__init__(predictions=predictions)
super().__init__(predictions=predictions, layout=layout if layout is not None else [])
self.page = page
self.page_idx = page_idx
self.dimensions = dimensions
Expand Down Expand Up @@ -477,8 +522,13 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **
requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
import matplotlib.pyplot as plt

show_kwargs = {k: kwargs.pop(k) for k in ("words_only", "display_artefacts", "display_layout") if k in kwargs}
visualize_kie_page(
self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio
self.export(),
self.page,
interactive=interactive,
preserve_aspect_ratio=preserve_aspect_ratio,
**show_kwargs,
)
plt.show(**kwargs)

Expand Down Expand Up @@ -593,7 +643,11 @@ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> tuple[
def from_dict(cls, save_dict: dict[str, Any], **kwargs):
kwargs = {k: save_dict[k] for k in cls._exported_keys}
kwargs.update({
"predictions": [Prediction.from_dict(predictions_dict) for predictions_dict in save_dict["predictions"]]
"predictions": {
class_name: [Prediction.from_dict(pred) for pred in preds]
for class_name, preds in save_dict["predictions"].items()
},
"layout": [LayoutElement.from_dict(region_dict) for region_dict in save_dict.get("layout", [])],
})
return cls(**kwargs)

Expand Down
48 changes: 48 additions & 0 deletions doctr/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def visualize_page(
image: np.ndarray,
words_only: bool = True,
display_artefacts: bool = True,
display_layout: bool = True,
scale: float = 10,
interactive: bool = True,
add_labels: bool = True,
Expand All @@ -179,6 +180,7 @@ def visualize_page(
image: np array of the page, needs to have the same shape than page['dimensions']
words_only: whether only words should be displayed
display_artefacts: whether artefacts should be displayed
display_layout: whether detected layout regions should be displayed
scale: figsize of the largest windows side
interactive: whether the plot should be interactive
add_labels: for static plot, adds text labels on top of bounding box
Expand All @@ -199,6 +201,33 @@ def visualize_page(
if interactive:
artists: list[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page)

# Draw layout regions first so text boxes are overlaid on top of them
if display_layout and page.get("layout"):
region_classes = sorted({region["type"] for region in page["layout"]})
layout_colors = {cls: color for color, cls in zip(get_colors(max(len(region_classes), 1)), region_classes)}
for region in page["layout"]:
rect = create_obj_patch(
region["geometry"],
page["dimensions"],
label=f"{region['type']} (confidence: {region['confidence']:.2%})",
color=layout_colors[region["type"]],
linewidth=2,
fill=False,
**kwargs,
)
ax.add_patch(rect)
if interactive:
artists.append(rect)
elif add_labels and len(region["geometry"]) == 2:
ax.text(
int(page["dimensions"][1] * region["geometry"][0][0]),
int(page["dimensions"][0] * region["geometry"][0][1]),
region["type"],
size=9,
alpha=0.7,
color=layout_colors[region["type"]],
)

for block in page["blocks"]:
if not words_only:
rect = create_obj_patch(
Expand Down Expand Up @@ -281,6 +310,7 @@ def visualize_kie_page(
image: np.ndarray,
words_only: bool = False,
display_artefacts: bool = True,
display_layout: bool = True,
scale: float = 10,
interactive: bool = True,
add_labels: bool = True,
Expand All @@ -303,6 +333,7 @@ def visualize_kie_page(
image: np array of the page, needs to have the same shape than page['dimensions']
words_only: whether only words should be displayed
display_artefacts: whether artefacts should be displayed
display_layout: whether detected layout regions should be displayed
scale: figsize of the largest windows side
interactive: whether the plot should be interactive
add_labels: for static plot, adds text labels on top of bounding box
Expand All @@ -323,6 +354,23 @@ def visualize_kie_page(
if interactive:
artists: list[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page)

if display_layout and page.get("layout"):
region_classes = sorted({region["type"] for region in page["layout"]})
layout_colors = {cls: color for color, cls in zip(get_colors(max(len(region_classes), 1)), region_classes)}
for region in page["layout"]:
rect = create_obj_patch(
region["geometry"],
page["dimensions"],
label=f"{region['type']} (confidence: {region['confidence']:.2%})",
color=layout_colors[region["type"]],
linewidth=2,
fill=False,
**kwargs,
)
ax.add_patch(rect)
if interactive:
artists.append(rect)

colors = {k: color for color, k in zip(get_colors(len(page["predictions"])), page["predictions"])}
for key, value in page["predictions"].items():
for prediction in value:
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ dependencies = [
"rapidfuzz>=3.0.0,<4.0.0",
"huggingface-hub>=0.20.0,<2.0.0",
"Pillow>=9.2.0",
"defusedxml>=0.7.0",
"anyascii>=0.3.2",
"validators>=0.18.0",
"tqdm>=4.30.0",
Expand Down Expand Up @@ -159,7 +158,6 @@ module = [
"pyclipper.*",
"shapely.*",
"mplcursors.*",
"defusedxml.*",
"weasyprint.*",
"huggingface_hub.*",
"pypdfium2.*",
Expand Down
Loading
Loading