Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,58 @@ This will only have an effect with `assume_straight_pages=False` and/or `straigh
# You can also add multiple hooks which will be executed sequentially
for hook in [my_hook, my_hook, my_hook]:
predictor.add_hook(hook)


* Restrict the recognition model to a subset of its vocabulary.

If you only expect text from one or more known languages, you can whitelist the corresponding vocabs so the
recognition model can no longer predict any character outside of them. This works with every recognition
architecture and with any predictor wrapping one (`ocr_predictor`, `kie_predictor`, `recognition_predictor`).
A whitelist can only restrict a model to characters it already knows: characters that are not part of the
model's own vocabulary are silently ignored, so make sure the model was trained on a vocab that covers the
languages you need (e.g. a multilingual model).

.. code:: python3

from doctr.datasets import VOCABS
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from doctr.models.utils import add_whitelist

predictor = ocr_predictor(pretrained=True)

# The recognition model can now only predict Polish/German characters
handle = add_whitelist(predictor, [VOCABS["polish"], VOCABS["german"]])

input_page = DocumentFile.from_images("path/to/your/image.png")
out = predictor(input_page)

# Restore the original, unconstrained decoding
handle.remove()

The returned handle can also be used as a context manager, in which case the whitelist is removed on exit:

.. code:: python3

with add_whitelist(predictor, VOCABS["german"]):
out = predictor(input_page) # only German characters can be predicted here
# the whitelist is automatically removed outside of the ``with`` block

By default forbidden characters are dropped (``strategy="mask"``), so decoding falls back to the highest-scoring
allowed character. Alternatively, ``strategy="nearest"`` folds each forbidden character onto the closest allowed
one (e.g. ``ä`` -> ``a``, ``ł`` -> ``l``), which is useful to normalize accents/diacritics onto a base alphabet.
The mapping is built by transliteration by default; pass ``mapping="weights"`` to derive it from the model's own
learned confusions, or a ``{forbidden_char: allowed_char}`` dict to override specific characters.

.. code:: python3

from doctr.datasets import VOCABS
from doctr.models import ocr_predictor
from doctr.models.utils import add_whitelist

predictor = ocr_predictor(pretrained=True)

# Fold any non-ASCII character onto its closest ASCII letter (e.g. é -> e, ł -> l)
handle = add_whitelist(predictor, VOCABS["latin"], strategy="nearest")
out = predictor(input_page)
handle.remove()
262 changes: 262 additions & 0 deletions doctr/models/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import logging
from collections.abc import Iterable
from typing import Any

import torch
import validators
from anyascii import anyascii
from torch import nn

from doctr.utils.data import download_from_url
Expand All @@ -17,6 +19,7 @@
"conv_sequence_pt",
"set_device_and_dtype",
"export_model_to_onnx",
"add_whitelist",
"_copy_tensor",
"_bf16_to_float32",
"_CompiledModule",
Expand Down Expand Up @@ -196,3 +199,262 @@
)
logging.info(f"Model exported to {model_name}.onnx")
return f"{model_name}.onnx"


# Location of the final vocabulary-projection layer for each recognition architecture.
_RECOGNITION_PROJECTIONS: dict[str, str] = {
"CRNN": "linear",
"SAR": "decoder.output_dense",
"MASTER": "linear",
"ViTSTR": "head",
"PARSeq": "head",
"VIPTR": "head",
}


class WhitelistHandle:
"""Removable registration returned by :func:`add_whitelist`.

Check notice on line 216 in doctr/models/utils/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/utils/pytorch.py#L216

1 blank line required before class docstring (found 0) (D203)

Check notice on line 216 in doctr/models/utils/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/utils/pytorch.py#L216

Multi-line docstring summary should start at the second line (D213)

Call :meth:`remove` to restore the model's original, unconstrained decoding. The
handle can also be used as a context manager, in which case the whitelist is removed
on exit.
"""

def __init__(self, handles: list[torch.utils.hooks.RemovableHandle]) -> None:
self._handles = handles

def remove(self) -> None:
"""Remove the whitelist and restore the model's unconstrained decoding."""
for handle in self._handles:
handle.remove()
self._handles = []

def __enter__(self) -> "WhitelistHandle":
return self

def __exit__(self, *_: Any) -> None:
self.remove()


def _recognition_model(model: nn.Module) -> nn.Module:
# Accept an ocr_predictor / kie_predictor / recognition_predictor or a recognition model
if hasattr(model, "vocab") and hasattr(model, "postprocessor"):
return model
reco_predictor = getattr(model, "reco_predictor", model)
reco_model = getattr(reco_predictor, "model", None)
if reco_model is None:
raise TypeError(
"Expected an ocr_predictor, kie_predictor, recognition_predictor or a recognition "
f"model, but could not find a recognition model on {type(model).__name__}."
)
return reco_model


def _vocab_projections(model: nn.Module, vocab_size: int) -> list[nn.Linear]:
path = _RECOGNITION_PROJECTIONS.get(type(model).__name__)
if path is not None:
layer: Any = model
for part in path.split("."):
layer = getattr(layer, part)
if isinstance(layer, nn.Linear):
return [layer]
# Fallback for unknown architectures: any Linear projecting to the vocab (+ up to 3 specials)
candidates = [
module
for module in model.modules()
if isinstance(module, nn.Linear) and module.out_features in {vocab_size + 1, vocab_size + 2, vocab_size + 3}
]
if not candidates:
raise RuntimeError(f"Could not locate the vocabulary projection layer of {type(model).__name__}.")
return candidates


def _anyascii_nearest_map(vocab: str, allowed: set[str]) -> dict[str, str]:
"""Map each forbidden character to the visually closest allowed one via transliteration.

Check notice on line 273 in doctr/models/utils/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/utils/pytorch.py#L273

Multi-line docstring summary should start at the second line (D213)

Uses ``anyascii`` to fold characters to their ASCII form (e.g. ``ä -> a``, ``ł -> l``,
Cyrillic ``а -> a``); a forbidden character is mapped to an allowed character sharing the
same ASCII form. Forbidden characters without such a match are left unmapped (they fall
back to plain masking).
"""
by_translit: dict[str, str] = {}
for char in vocab:
if char not in allowed:
continue
key = anyascii(char)
current = by_translit.get(key)
# Prefer a pure-ASCII allowed character as the canonical target for a given form.
if current is None or (char == anyascii(char) and current != anyascii(current)):
by_translit[key] = char

mapping: dict[str, str] = {}
for char in vocab:
if char in allowed:
continue
form = anyascii(char)
target = by_translit.get(form) or by_translit.get(form.lower()) or by_translit.get(form[:1])
if target is not None:
mapping[char] = target
return mapping


def _weights_nearest_map(vocab: str, allowed: set[str], projection: nn.Linear) -> dict[str, str]:
"""Map each forbidden character to the allowed one whose projection weights are most similar.

Check notice on line 302 in doctr/models/utils/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/utils/pytorch.py#L302

Multi-line docstring summary should start at the second line (D213)

This uses the model's own learned representation: the nearest allowed character is the one
the model most confuses the forbidden character with (cosine similarity of the projection
weight rows).
"""
vocab_size = len(vocab)
rows = nn.functional.normalize(projection.weight.detach()[:vocab_size], dim=1)
allowed_idx = [i for i, char in enumerate(vocab) if char in allowed]
forbidden_idx = [i for i, char in enumerate(vocab) if char not in allowed]
if not allowed_idx or not forbidden_idx:
return {}
similarity = rows[forbidden_idx] @ rows[allowed_idx].t()
nearest = similarity.argmax(dim=1)
return {vocab[forbidden_idx[k]]: vocab[allowed_idx[int(nearest[k])]] for k in range(len(forbidden_idx))}


def _keep_and_reassign(
vocab: str, allowed: set[str], out_features: int, char_map: dict[str, str]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Build the keep mask and the (forbidden -> allowed) index tensors for one projection."""
vocab_size = len(vocab)
keep = torch.zeros(out_features, dtype=torch.bool)
for idx, char in enumerate(vocab):
keep[idx] = char in allowed
keep[vocab_size] = True # sequence terminator (CTC blank / attention <eos>)

position = {char: idx for idx, char in enumerate(vocab)}
src, dst = [], []
for forbidden_char, allowed_char in char_map.items():
src_idx, dst_idx = position.get(forbidden_char), position.get(allowed_char)
# only reassign genuinely-forbidden characters onto genuinely-allowed ones
if src_idx is not None and dst_idx is not None and not keep[src_idx] and allowed_char in allowed:
src.append(src_idx)
dst.append(dst_idx)
return keep, torch.tensor(src, dtype=torch.long), torch.tensor(dst, dtype=torch.long)


def add_whitelist(
model: nn.Module,
vocabs: str | Iterable[str],
*,
strategy: str = "mask",
mapping: str | dict[str, str] | None = None,
verbose: bool = False,
) -> WhitelistHandle:
"""Restrict a recognition model so it can only predict a subset of its vocabulary.

Check notice on line 348 in doctr/models/utils/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/utils/pytorch.py#L348

Missing blank line after last section ('Returns') (D413)

Check notice on line 348 in doctr/models/utils/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/utils/pytorch.py#L348

Missing dashed underline after section ('Returns') (D407)

Check notice on line 348 in doctr/models/utils/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/utils/pytorch.py#L348

Multi-line docstring summary should start at the second line (D213)

Check notice on line 348 in doctr/models/utils/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/utils/pytorch.py#L348

Section name should end with a newline ('Returns', not 'Returns:') (D406)

The whitelist is enforced at the model's final projection layer, before the decoding
``argmax``. Because the projection is the single point every logit flows through, the
constraint also applies inside the autoregressive decoding loop of SAR, MASTER and PARSeq,
so a forbidden character can never be produced -- not even fed back mid-word. The sequence
terminator (CTC ``blank`` / attention ``<eos>``) is always kept so decoding still
terminates. It works with every recognition architecture and with any predictor wrapping
one (`ocr_predictor`, `kie_predictor`, `recognition_predictor`).

Two strategies are available:

* ``"mask"`` (default): the logits of forbidden characters are set to ``-inf``, so decoding
falls back to the highest-scoring allowed character.
* ``"nearest"``: the score of each forbidden character is first reassigned to the closest
allowed character (so e.g. ``ä`` folds onto ``a``), then forbidden logits are masked.
Forbidden characters without a mapping fall back to masking.

A whitelist can only restrict a model to characters it already knows: characters that are
not part of the model's own vocabulary are silently ignored.

>>> from doctr.datasets import VOCABS
>>> from doctr.models import ocr_predictor
>>> from doctr.models.utils import add_whitelist
>>> predictor = ocr_predictor(pretrained=True)
>>> handle = add_whitelist(predictor, [VOCABS["polish"], VOCABS["german"]])
>>> # ... run the predictor; only Polish/German characters can be predicted ...
>>> handle.remove() # restore the original, unconstrained decoding

Args:
model: an `ocr_predictor`, `kie_predictor`, `recognition_predictor`, or a recognition model.
vocabs: a vocabulary string (e.g. ``VOCABS["german"]``) or an iterable of vocabulary
strings (e.g. ``[VOCABS["polish"], VOCABS["german"]]``) whose characters are allowed.
strategy: ``"mask"`` (default) to drop forbidden characters, or ``"nearest"`` to fold
them onto the closest allowed character.
mapping: only used when ``strategy="nearest"``. ``None`` or ``"anyascii"`` builds the
forbidden-to-allowed map by transliteration (the default); ``"weights"`` derives it
from the projection weights (the model's own confusions); a ``dict`` of
``{forbidden_char: allowed_char}`` overrides specific characters on top of the
transliteration map.
verbose: if True, log how many characters were kept, forbidden and reassigned per model.

Returns:
a :class:`WhitelistHandle`; call its :meth:`~WhitelistHandle.remove` method to restore
the original, unconstrained decoding.
"""
if strategy not in {"mask", "nearest"}:
raise ValueError(f"Unknown strategy {strategy!r}; expected 'mask' or 'nearest'.")
if strategy == "mask" and mapping is not None:
raise ValueError("The 'mapping' argument is only used with strategy='nearest'.")
if isinstance(mapping, str) and mapping not in {"anyascii", "weights"}:
raise ValueError(f"Unknown mapping {mapping!r}; expected 'anyascii', 'weights', a dict or None.")
if mapping is not None and not isinstance(mapping, (str, dict)):
raise ValueError("The 'mapping' argument must be None, 'anyascii', 'weights' or a dict.")

allowed = set(vocabs) if isinstance(vocabs, str) else {char for vocab in vocabs for char in vocab}

handles: list[torch.utils.hooks.RemovableHandle] = []
reco_model = _recognition_model(model)
vocab: str = reco_model.vocab # type: ignore[assignment]
vocab_size = len(vocab)
if not any(char in allowed for char in vocab):
raise ValueError(
"The whitelist shares no character with the model's vocabulary; the model would "
"be unable to predict anything."
)

# A vocab-level character map (shared by every projection); the weight-based map is
# derived per projection further down.
base_map: dict[str, str] = {}
if strategy == "nearest" and mapping != "weights":
base_map = _anyascii_nearest_map(vocab, allowed)
if isinstance(mapping, dict):
base_map = {**base_map, **mapping}

reassigned = 0
for projection in _vocab_projections(reco_model, vocab_size):
char_map = (
_weights_nearest_map(vocab, allowed, projection)
if strategy == "nearest" and mapping == "weights"
else base_map
)
keep, src, dst = _keep_and_reassign(vocab, allowed, projection.out_features, char_map)
reassigned = max(reassigned, src.numel())

def _constrain_logits(
_module: nn.Module,
_inputs: Any,
output: torch.Tensor,
keep: torch.Tensor = keep,
src: torch.Tensor = src,
dst: torch.Tensor = dst,
):
output = output.clone()
if src.numel():
# move each forbidden character's score onto its nearest allowed character
values = output[..., src.to(output.device)]
index = dst.to(output.device).view(*([1] * (output.dim() - 1)), -1).expand(*output.shape[:-1], -1)
output.scatter_reduce_(-1, index, values, reduce="amax", include_self=True)
output[..., ~keep.to(output.device)] = float("-inf")
return output

handles.append(projection.register_forward_hook(_constrain_logits))

if verbose: # pragma: no cover
kept = sum(char in allowed for char in vocab)
logging.info(
f"add_whitelist: {type(reco_model).__name__} - kept {kept}/{vocab_size} vocabulary "
f"characters, forbade {vocab_size - kept}"
+ (f", reassigned {reassigned} to a nearest allowed character." if strategy == "nearest" else ".")
)

return WhitelistHandle(handles)
4 changes: 3 additions & 1 deletion references/layout/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,11 @@ def main(args):
# construct DDP model
model = DDP(model, device_ids=[rank])

backbone_lr = args.lr * 0.1 if args.pretrained or args.resume is not None else args.lr
param_groups = build_param_groups(
model,
lr=args.lr,
backbone_lr=args.lr if not args.pretrained else args.lr * 0.1,
backbone_lr=backbone_lr,
weight_decay=args.weight_decay or 1e-4,
)

Expand Down Expand Up @@ -556,6 +557,7 @@ def main(args):
if rank == 0:
config = {
"learning_rate": args.lr,
"backbone_learning_rate": backbone_lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
Expand Down
Loading
Loading