diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index 18b3b5dab0..4f67f73ead 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -562,3 +562,58 @@ This will only have an effect with `assume_straight_pages=False` and/or `straigh # You can also add multiple hooks which will be executed sequentially for hook in [my_hook, my_hook, my_hook]: predictor.add_hook(hook) + + +* Restrict the recognition model to a subset of its vocabulary. + +If you only expect text from one or more known languages, you can whitelist the corresponding vocabs so the +recognition model can no longer predict any character outside of them. This works with every recognition +architecture and with any predictor wrapping one (`ocr_predictor`, `kie_predictor`, `recognition_predictor`). +A whitelist can only restrict a model to characters it already knows: characters that are not part of the +model's own vocabulary are silently ignored, so make sure the model was trained on a vocab that covers the +languages you need (e.g. a multilingual model). + +.. code:: python3 + + from doctr.datasets import VOCABS + from doctr.io import DocumentFile + from doctr.models import ocr_predictor + from doctr.models.utils import add_whitelist + + predictor = ocr_predictor(pretrained=True) + + # The recognition model can now only predict Polish/German characters + handle = add_whitelist(predictor, [VOCABS["polish"], VOCABS["german"]]) + + input_page = DocumentFile.from_images("path/to/your/image.png") + out = predictor(input_page) + + # Restore the original, unconstrained decoding + handle.remove() + +The returned handle can also be used as a context manager, in which case the whitelist is removed on exit: + +.. 
code:: python3 + + with add_whitelist(predictor, VOCABS["german"]): + out = predictor(input_page) # only German characters can be predicted here + # the whitelist is automatically removed outside of the ``with`` block + +By default forbidden characters are dropped (``strategy="mask"``), so decoding falls back to the highest-scoring +allowed character. Alternatively, ``strategy="nearest"`` folds each forbidden character onto the closest allowed +one (e.g. ``ä`` -> ``a``, ``ł`` -> ``l``), which is useful to normalize accents/diacritics onto a base alphabet. +The mapping is built by transliteration by default; pass ``mapping="weights"`` to derive it from the model's own +learned confusions, or a ``{forbidden_char: allowed_char}`` dict to override specific characters. + +.. code:: python3 + + from doctr.datasets import VOCABS + from doctr.models import ocr_predictor + from doctr.models.utils import add_whitelist + + predictor = ocr_predictor(pretrained=True) + + # Fold any non-ASCII character onto its closest ASCII letter (e.g. é -> e, ł -> l) + handle = add_whitelist(predictor, VOCABS["latin"], strategy="nearest") + out = predictor(input_page) + handle.remove() diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index 6667320cb4..50dd9205ef 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -4,10 +4,12 @@ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. 
import logging +from collections.abc import Iterable from typing import Any import torch import validators +from anyascii import anyascii from torch import nn from doctr.utils.data import download_from_url @@ -17,6 +19,7 @@ "conv_sequence_pt", "set_device_and_dtype", "export_model_to_onnx", + "add_whitelist", "_copy_tensor", "_bf16_to_float32", "_CompiledModule", @@ -196,3 +199,262 @@ def export_model_to_onnx( ) logging.info(f"Model exported to {model_name}.onnx") return f"{model_name}.onnx" + + +# Location of the final vocabulary-projection layer for each recognition architecture. +_RECOGNITION_PROJECTIONS: dict[str, str] = { + "CRNN": "linear", + "SAR": "decoder.output_dense", + "MASTER": "linear", + "ViTSTR": "head", + "PARSeq": "head", + "VIPTR": "head", +} + + +class WhitelistHandle: + """Removable registration returned by :func:`add_whitelist`. + + Call :meth:`remove` to restore the model's original, unconstrained decoding. The + handle can also be used as a context manager, in which case the whitelist is removed + on exit. 
+ """ + + def __init__(self, handles: list[torch.utils.hooks.RemovableHandle]) -> None: + self._handles = handles + + def remove(self) -> None: + """Remove the whitelist and restore the model's unconstrained decoding.""" + for handle in self._handles: + handle.remove() + self._handles = [] + + def __enter__(self) -> "WhitelistHandle": + return self + + def __exit__(self, *_: Any) -> None: + self.remove() + + +def _recognition_model(model: nn.Module) -> nn.Module: + # Accept an ocr_predictor / kie_predictor / recognition_predictor or a recognition model + if hasattr(model, "vocab") and hasattr(model, "postprocessor"): + return model + reco_predictor = getattr(model, "reco_predictor", model) + reco_model = getattr(reco_predictor, "model", None) + if reco_model is None: + raise TypeError( + "Expected an ocr_predictor, kie_predictor, recognition_predictor or a recognition " + f"model, but could not find a recognition model on {type(model).__name__}." + ) + return reco_model + + +def _vocab_projections(model: nn.Module, vocab_size: int) -> list[nn.Linear]: + path = _RECOGNITION_PROJECTIONS.get(type(model).__name__) + if path is not None: + layer: Any = model + for part in path.split("."): + layer = getattr(layer, part) + if isinstance(layer, nn.Linear): + return [layer] + # Fallback for unknown architectures: any Linear projecting to the vocab (+ up to 3 specials) + candidates = [ + module + for module in model.modules() + if isinstance(module, nn.Linear) and module.out_features in {vocab_size + 1, vocab_size + 2, vocab_size + 3} + ] + if not candidates: + raise RuntimeError(f"Could not locate the vocabulary projection layer of {type(model).__name__}.") + return candidates + + +def _anyascii_nearest_map(vocab: str, allowed: set[str]) -> dict[str, str]: + """Map each forbidden character to the visually closest allowed one via transliteration. + + Uses ``anyascii`` to fold characters to their ASCII form (e.g. 
``ä -> a``, ``ł -> l``, + Cyrillic ``а -> a``); a forbidden character is mapped to an allowed character sharing the + same ASCII form. Forbidden characters without such a match are left unmapped (they fall + back to plain masking). + """ + by_translit: dict[str, str] = {} + for char in vocab: + if char not in allowed: + continue + key = anyascii(char) + current = by_translit.get(key) + # Prefer a pure-ASCII allowed character as the canonical target for a given form. + if current is None or (char == anyascii(char) and current != anyascii(current)): + by_translit[key] = char + + mapping: dict[str, str] = {} + for char in vocab: + if char in allowed: + continue + form = anyascii(char) + target = by_translit.get(form) or by_translit.get(form.lower()) or by_translit.get(form[:1]) + if target is not None: + mapping[char] = target + return mapping + + +def _weights_nearest_map(vocab: str, allowed: set[str], projection: nn.Linear) -> dict[str, str]: + """Map each forbidden character to the allowed one whose projection weights are most similar. + + This uses the model's own learned representation: the nearest allowed character is the one + the model most confuses the forbidden character with (cosine similarity of the projection + weight rows). 
+ """ + vocab_size = len(vocab) + rows = nn.functional.normalize(projection.weight.detach()[:vocab_size], dim=1) + allowed_idx = [i for i, char in enumerate(vocab) if char in allowed] + forbidden_idx = [i for i, char in enumerate(vocab) if char not in allowed] + if not allowed_idx or not forbidden_idx: + return {} + similarity = rows[forbidden_idx] @ rows[allowed_idx].t() + nearest = similarity.argmax(dim=1) + return {vocab[forbidden_idx[k]]: vocab[allowed_idx[int(nearest[k])]] for k in range(len(forbidden_idx))} + + +def _keep_and_reassign( + vocab: str, allowed: set[str], out_features: int, char_map: dict[str, str] +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build the keep mask and the (forbidden -> allowed) index tensors for one projection.""" + vocab_size = len(vocab) + keep = torch.zeros(out_features, dtype=torch.bool) + for idx, char in enumerate(vocab): + keep[idx] = char in allowed + keep[vocab_size] = True # sequence terminator (CTC blank / attention <eos>) + + position = {char: idx for idx, char in enumerate(vocab)} + src, dst = [], [] + for forbidden_char, allowed_char in char_map.items(): + src_idx, dst_idx = position.get(forbidden_char), position.get(allowed_char) + # only reassign genuinely-forbidden characters onto genuinely-allowed ones + if src_idx is not None and dst_idx is not None and not keep[src_idx] and allowed_char in allowed: + src.append(src_idx) + dst.append(dst_idx) + return keep, torch.tensor(src, dtype=torch.long), torch.tensor(dst, dtype=torch.long) + + +def add_whitelist( + model: nn.Module, + vocabs: str | Iterable[str], + *, + strategy: str = "mask", + mapping: str | dict[str, str] | None = None, + verbose: bool = False, +) -> WhitelistHandle: + """Restrict a recognition model so it can only predict a subset of its vocabulary. + + The whitelist is enforced at the model's final projection layer, before the decoding + ``argmax``. 
Because the projection is the single point every logit flows through, the + constraint also applies inside the autoregressive decoding loop of SAR, MASTER and PARSeq, + so a forbidden character can never be produced -- not even fed back mid-word. The sequence + terminator (CTC ``blank`` / attention ``<eos>``) is always kept so decoding still + terminates. It works with every recognition architecture and with any predictor wrapping + one (`ocr_predictor`, `kie_predictor`, `recognition_predictor`). + + Two strategies are available: + + * ``"mask"`` (default): the logits of forbidden characters are set to ``-inf``, so decoding + falls back to the highest-scoring allowed character. + * ``"nearest"``: the score of each forbidden character is first reassigned to the closest + allowed character (so e.g. ``ä`` folds onto ``a``), then forbidden logits are masked. + Forbidden characters without a mapping fall back to masking. + + A whitelist can only restrict a model to characters it already knows: characters that are + not part of the model's own vocabulary are silently ignored. + + >>> from doctr.datasets import VOCABS + >>> from doctr.models import ocr_predictor + >>> from doctr.models.utils import add_whitelist + >>> predictor = ocr_predictor(pretrained=True) + >>> handle = add_whitelist(predictor, [VOCABS["polish"], VOCABS["german"]]) + >>> # ... run the predictor; only Polish/German characters can be predicted ... + >>> handle.remove() # restore the original, unconstrained decoding + + Args: + model: an `ocr_predictor`, `kie_predictor`, `recognition_predictor`, or a recognition model. + vocabs: a vocabulary string (e.g. ``VOCABS["german"]``) or an iterable of vocabulary + strings (e.g. ``[VOCABS["polish"], VOCABS["german"]]``) whose characters are allowed. + strategy: ``"mask"`` (default) to drop forbidden characters, or ``"nearest"`` to fold + them onto the closest allowed character. + mapping: only used when ``strategy="nearest"``. 
``None`` or ``"anyascii"`` builds the + forbidden-to-allowed map by transliteration (the default); ``"weights"`` derives it + from the projection weights (the model's own confusions); a ``dict`` of + ``{forbidden_char: allowed_char}`` overrides specific characters on top of the + transliteration map. + verbose: if True, log how many characters were kept, forbidden and reassigned per model. + + Returns: + a :class:`WhitelistHandle`; call its :meth:`~WhitelistHandle.remove` method to restore + the original, unconstrained decoding. + """ + if strategy not in {"mask", "nearest"}: + raise ValueError(f"Unknown strategy {strategy!r}; expected 'mask' or 'nearest'.") + if strategy == "mask" and mapping is not None: + raise ValueError("The 'mapping' argument is only used with strategy='nearest'.") + if isinstance(mapping, str) and mapping not in {"anyascii", "weights"}: + raise ValueError(f"Unknown mapping {mapping!r}; expected 'anyascii', 'weights', a dict or None.") + if mapping is not None and not isinstance(mapping, (str, dict)): + raise ValueError("The 'mapping' argument must be None, 'anyascii', 'weights' or a dict.") + + allowed = set(vocabs) if isinstance(vocabs, str) else {char for vocab in vocabs for char in vocab} + + handles: list[torch.utils.hooks.RemovableHandle] = [] + reco_model = _recognition_model(model) + vocab: str = reco_model.vocab # type: ignore[assignment] + vocab_size = len(vocab) + if not any(char in allowed for char in vocab): + raise ValueError( + "The whitelist shares no character with the model's vocabulary; the model would " + "be unable to predict anything." + ) + + # A vocab-level character map (shared by every projection); the weight-based map is + # derived per projection further down. 
+ base_map: dict[str, str] = {} + if strategy == "nearest" and mapping != "weights": + base_map = _anyascii_nearest_map(vocab, allowed) + if isinstance(mapping, dict): + base_map = {**base_map, **mapping} + + reassigned = 0 + for projection in _vocab_projections(reco_model, vocab_size): + char_map = ( + _weights_nearest_map(vocab, allowed, projection) + if strategy == "nearest" and mapping == "weights" + else base_map + ) + keep, src, dst = _keep_and_reassign(vocab, allowed, projection.out_features, char_map) + reassigned = max(reassigned, src.numel()) + + def _constrain_logits( + _module: nn.Module, + _inputs: Any, + output: torch.Tensor, + keep: torch.Tensor = keep, + src: torch.Tensor = src, + dst: torch.Tensor = dst, + ): + output = output.clone() + if src.numel(): + # move each forbidden character's score onto its nearest allowed character + values = output[..., src.to(output.device)] + index = dst.to(output.device).view(*([1] * (output.dim() - 1)), -1).expand(*output.shape[:-1], -1) + output.scatter_reduce_(-1, index, values, reduce="amax", include_self=True) + output[..., ~keep.to(output.device)] = float("-inf") + return output + + handles.append(projection.register_forward_hook(_constrain_logits)) + + if verbose: # pragma: no cover + kept = sum(char in allowed for char in vocab) + logging.info( + f"add_whitelist: {type(reco_model).__name__} - kept {kept}/{vocab_size} vocabulary " + f"characters, forbade {vocab_size - kept}" + + (f", reassigned {reassigned} to a nearest allowed character." 
if strategy == "nearest" else ".") + ) + + return WhitelistHandle(handles) diff --git a/references/layout/train.py b/references/layout/train.py index 8f9ce372b1..4e9d550f02 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -479,10 +479,11 @@ def main(args): # construct DDP model model = DDP(model, device_ids=[rank]) + backbone_lr = args.lr * 0.1 if args.pretrained or args.resume is not None else args.lr param_groups = build_param_groups( model, lr=args.lr, - backbone_lr=args.lr if not args.pretrained else args.lr * 0.1, + backbone_lr=backbone_lr, weight_decay=args.weight_decay or 1e-4, ) @@ -556,6 +557,7 @@ def main(args): if rank == 0: config = { "learning_rate": args.lr, + "backbone_learning_rate": backbone_lr, "epochs": args.epochs, "weight_decay": args.weight_decay, "batch_size": args.batch_size, diff --git a/tests/pytorch/test_models_utils_pt.py b/tests/pytorch/test_models_utils_pt.py index 75b5d40ca2..c5db89b6a9 100644 --- a/tests/pytorch/test_models_utils_pt.py +++ b/tests/pytorch/test_models_utils_pt.py @@ -1,16 +1,21 @@ import os +import numpy as np import pytest import torch from torch import nn +from doctr.datasets import VOCABS +from doctr.models import recognition, recognition_predictor from doctr.models.utils import ( _bf16_to_float32, _copy_tensor, + add_whitelist, conv_sequence_pt, load_pretrained_params, set_device_and_dtype, ) +from doctr.models.utils.pytorch import _vocab_projections def test_copy_tensor(): @@ -95,3 +100,233 @@ def test_set_device_and_dtype(): assert all(isinstance(t, torch.Tensor) for t in new_masks) assert all(t.dtype == torch.bool for t in new_masks) assert all(t.device == torch.device("cpu") for t in new_masks) + + +@pytest.mark.parametrize( + "arch_name", + [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "sar_resnet31", + "master", + "vitstr_small", + "vitstr_base", + "parseq", + "viptr_tiny", + ], +) +def test_add_whitelist(arch_name): + # A test vocab containing 
the full whitelist plus extra forbiddable characters, kept small + # enough to fit every architecture (SAR's feedback embedding caps at 512 entries). + _whitelist = [VOCABS["polish"], VOCABS["german"]] + _allowed = "".join(dict.fromkeys("".join(_whitelist))) # ordered-unique characters + _extra = "".join(c for c in VOCABS["multilingual"] if c not in set(_allowed))[:200] + _test_vocab = _allowed + _extra + + model = recognition.__dict__[arch_name](pretrained=True, vocab=_test_vocab).eval() + allowed = set(_allowed) + # every whitelisted character is part of the model vocab (nothing is dropped) + assert allowed.issubset(set(model.vocab)) + + forbidden_idx = [i for i, c in enumerate(model.vocab) if c not in allowed] + allowed_idx = [i for i, c in enumerate(model.vocab) if c in allowed] + terminator_idx = len(model.vocab) + assert len(forbidden_idx) > 0 + + samples = torch.rand(4, 3, 32, 128) + + handle = add_whitelist(model, _whitelist) + with torch.inference_mode(): + out = model(samples, return_model_output=True, return_preds=True) + logits = out["out_map"] + + # forbidden characters are masked out, while whitelisted characters and the terminator stay finite + assert torch.isneginf(logits[..., forbidden_idx]).all() + assert torch.isfinite(logits[..., allowed_idx]).all() + assert torch.isfinite(logits[..., terminator_idx]).all() + # the decoded output only contains whitelisted characters (and no leaked special tokens) + for word, _ in out["preds"]: + assert all(char in allowed for char in word) + + # remove() restores the original, unconstrained decoding + handle.remove() + with torch.inference_mode(): + restored = model(samples, return_model_output=True)["out_map"] + assert torch.isfinite(restored).all() + + # Test biased model: + # Even when the model is biased toward forbidden characters, the whitelist must win + # (this also exercises the autoregressive feedback loop). 
+ model = recognition.parseq(pretrained=True, vocab=_test_vocab).eval() + allowed = set(VOCABS["german"]) + forbidden_idx = torch.tensor([i for i, c in enumerate(model.vocab) if c not in allowed]) + + def bias_forbidden(module, inputs, output): + output = output.clone() + output[..., forbidden_idx] += 1e4 + return output + + bias_handle = model.head.register_forward_hook(bias_forbidden) # runs before the whitelist hook + samples = torch.rand(4, 3, 32, 128) + with torch.inference_mode(): + attacked = model(samples, return_preds=True)["preds"] + assert any(char not in allowed for word, _ in attacked for char in word) + + whitelist_handle = add_whitelist(model, VOCABS["german"]) # registered after -> overrides the bias + with torch.inference_mode(): + defended = model(samples, return_preds=True)["preds"] + assert all(char in allowed for word, _ in defended for char in word) + + whitelist_handle.remove() + bias_handle.remove() + + # Test as context manager and on a predictor + model = recognition.crnn_vgg16_bn(pretrained=True, vocab=_test_vocab).eval() + samples = torch.rand(2, 3, 32, 128) + with add_whitelist(model, VOCABS["german"]): + with torch.inference_mode(): + masked = model(samples, return_model_output=True)["out_map"] + assert torch.isneginf(masked).any() + # outside the context the whitelist has been removed + with torch.inference_mode(): + restored = model(samples, return_model_output=True)["out_map"] + assert torch.isfinite(restored).all() + + # test whitelist error cases + model = recognition.crnn_vgg16_bn(pretrained=True, vocab="abc123").eval() + # a whitelist disjoint from the model vocabulary is rejected + with pytest.raises(ValueError): + add_whitelist(model, "XYZ") + # an object that is not a recognition model / predictor is rejected + with pytest.raises(TypeError): + add_whitelist(nn.Linear(8, 8), "abc") + + +@pytest.mark.parametrize( + "arch_name", + [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "sar_resnet31", + 
"master", + "vitstr_small", + "vitstr_base", + "parseq", + "viptr_tiny", + ], +) +def test_end_to_end_add_whitelist(arch_name): + vocab = "abcXYZ" + allowed = set("abc") + model = recognition.__dict__[arch_name](pretrained=False, vocab=vocab).eval() + predictor = recognition_predictor(model, batch_size=2) + + forbidden_idx = model.vocab.index("X") + allowed_idx = model.vocab.index("a") + projection = _vocab_projections(model, len(model.vocab))[0] + + def bias_forbidden(_module, _inputs, output): + output = output.clone() + output[..., forbidden_idx] += 1e4 + output[..., allowed_idx] += 5e3 + return output + + bias_handle = projection.register_forward_hook(bias_forbidden) + crops = [(255 * np.random.rand(32, 128, 3)).astype(np.uint8) for _ in range(2)] + + try: + unconstrained = predictor(crops) + assert all("X" in word for word, _ in unconstrained) + with add_whitelist(predictor, "abc"): + constrained = predictor(crops) + assert all(word and all(char in allowed for char in word) for word, _ in constrained) + restored = predictor(crops) + assert all("X" in word for word, _ in restored) + finally: + bias_handle.remove() + + +def test_vocab_projections_fallback_candidates(): + from doctr.models.utils.pytorch import _vocab_projections + + vocab_size = 6 + + class UnknownRecognitionModel(nn.Module): + def __init__(self): + super().__init__() + self.hidden = nn.Linear(4, vocab_size + 4) + self.projection = nn.Linear(vocab_size + 4, vocab_size + 1) + self.aux_projection = nn.Linear(vocab_size + 4, vocab_size + 3) + self.unrelated = nn.Linear(vocab_size + 4, vocab_size + 4) + + model = UnknownRecognitionModel() + assert _vocab_projections(model, vocab_size) == [model.projection, model.aux_projection] + + with pytest.raises(RuntimeError, match="Could not locate the vocabulary projection layer"): + _vocab_projections(nn.Linear(4, 4), vocab_size) + + +# Forbidden characters whose visual base (via anyascii) is part of the German whitelist. 
+_NEAREST_VOCAB = "".join(dict.fromkeys(VOCABS["german"] + "ąóńł")) +_NEAREST_FOLDS = {"ą": "a", "ó": "o", "ń": "n", "ł": "l"} + + +def _force_and_decode(model, target_char, **whitelist_kwargs): + """Bias the model to prefer ``target_char``, apply the whitelist, return the decoded word.""" + from doctr.models.utils.pytorch import _vocab_projections, add_whitelist + + forbidden_idx = model.vocab.index(target_char) + projection = _vocab_projections(model, len(model.vocab))[0] + + def bias(module, inputs, output, idx=forbidden_idx): + output = output.clone() + output[..., idx] += 1e4 + return output + + bias_handle = projection.register_forward_hook(bias) # runs before the whitelist hook + whitelist_handle = add_whitelist(model, VOCABS["german"], **whitelist_kwargs) + with torch.inference_mode(): + word = model(torch.rand(2, 3, 32, 128), return_preds=True)["preds"][0][0] + whitelist_handle.remove() + bias_handle.remove() + return word + + +@pytest.mark.parametrize("arch_name", ["crnn_vgg16_bn", "sar_resnet31", "master", "parseq", "viptr_tiny"]) +def test_add_whitelist_nearest_folds_to_base(arch_name): + model = recognition.__dict__[arch_name](pretrained=True, vocab=_NEAREST_VOCAB).eval() + for forbidden_char, base_char in _NEAREST_FOLDS.items(): + word = _force_and_decode(model, forbidden_char, strategy="nearest") + # the forbidden character is folded onto its allowed visual base (CTC collapses repeats) + assert word and set(word) == {base_char} + + +def test_add_whitelist_nearest_custom_mapping(): + model = recognition.parseq(pretrained=True, vocab=_NEAREST_VOCAB).eval() + # an explicit mapping overrides the default transliteration (ą would otherwise fold to "a") + word = _force_and_decode(model, "ą", strategy="nearest", mapping={"ą": "z"}) + assert word and set(word) == {"z"} + + +def test_add_whitelist_nearest_weights_stays_within_whitelist(): + model = recognition.crnn_vgg16_bn(pretrained=True, vocab=_NEAREST_VOCAB).eval() + allowed = set(VOCABS["german"]) + 
handle = add_whitelist(model, VOCABS["german"], strategy="nearest", mapping="weights") + with torch.inference_mode(): + preds = model(torch.rand(3, 3, 32, 128), return_preds=True)["preds"] + handle.remove() + assert all(char in allowed for word, _ in preds for char in word) + + +def test_add_whitelist_strategy_errors(): + model = recognition.crnn_vgg16_bn(pretrained=True, vocab=_NEAREST_VOCAB).eval() + with pytest.raises(ValueError): # mapping is meaningless without strategy="nearest" + add_whitelist(model, VOCABS["german"], mapping={"ą": "a"}) + with pytest.raises(ValueError): # unknown strategy + add_whitelist(model, VOCABS["german"], strategy="drop") + with pytest.raises(ValueError): # unknown mapping keyword + add_whitelist(model, VOCABS["german"], strategy="nearest", mapping="closest") + with pytest.raises(ValueError): # unsupported mapping type + add_whitelist(model, VOCABS["german"], strategy="nearest", mapping=123)