Skip to content
Merged
82 changes: 41 additions & 41 deletions doctr/models/factory/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import logging
import subprocess
import tempfile
import textwrap
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -101,61 +100,62 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
raise ValueError("task must be one of classification, detection, recognition, layout")

# default readme
readme = textwrap.dedent(
f"""---
language: en
tags:
- ocr
- pytorch
- doctr
- {task}
---
readme = f"""---
language: en
tags:
- ocr
- pytorch
- doctr
- {task}
---


<p align="center">
<img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
</p>
<p align="center">
<img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
</p>

**Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**
**Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**

## Task: {task}
## Task: {task}

https://github.com/mindee/doctr
https://github.com/mindee/doctr

### Example usage:
### Example usage:

```python
>>> from doctr.io import DocumentFile
>>> from doctr.models import ocr_predictor, from_hub
```python
>>> from doctr.io import DocumentFile
>>> from doctr.models import ocr_predictor, from_hub

>>> img = DocumentFile.from_images(['<image_path>'])
>>> # Load your model from the hub
>>> model = from_hub('mindee/my-model')
>>> img = DocumentFile.from_images(['<image_path>'])
>>> # Load your model from the hub
>>> model = from_hub('mindee/my-model')

>>> # Pass it to the predictor
>>> # If your model is a recognition model:
>>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
>>> reco_arch=model,
>>> pretrained=True)
>>> # Pass it to the predictor
>>> # If your model is a recognition model:
>>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
>>> reco_arch=model,
>>> pretrained=True)

>>> # If your model is a detection model:
>>> predictor = ocr_predictor(det_arch=model,
>>> reco_arch='crnn_mobilenet_v3_small',
>>> pretrained=True)
>>> # If your model is a detection model:
>>> predictor = ocr_predictor(det_arch=model,
>>> reco_arch='crnn_mobilenet_v3_small',
>>> pretrained=True)

>>> # Get your predictions
>>> res = predictor(img)
```
"""
)
>>> # Get your predictions
>>> res = predictor(img)
```
"""

# add run configuration to readme if available
if run_config is not None:
arch = run_config.arch
readme += textwrap.dedent(
f"""### Run Configuration
\n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
)
readme += f"""
### Run Configuration

```json
{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}
```
"""

if arch not in AVAILABLE_ARCHS[task]:
raise ValueError(
Expand Down
35 changes: 15 additions & 20 deletions doctr/models/layout/lw_detr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,23 +143,20 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int
results: list[tuple[list[int], np.ndarray, list[float]]] = []

for b in range(boxes.shape[0]):
# Convert logits to probabilities and get scores and labels
exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True))
prob = exp / exp.sum(axis=-1, keepdims=True)
# Sigmoid scores (the model is trained with a sigmoid-based (IA-BCE) loss without
# a background class)
prob = 1.0 / (1.0 + np.exp(-logits[b])) # (num_queries, num_classes)
num_classes = prob.shape[-1]

prob_fg = prob[:, :-1] # exclude background
scores = prob_fg.max(axis=-1)
labels = prob_fg.argmax(axis=-1)
# Keep only the topk (query, class) pairs before NMS
flat_prob = prob.reshape(-1)
topk = min(self.topk, flat_prob.size) if self.topk is not None else flat_prob.size
topk_idxs = np.argsort(flat_prob)[::-1][:topk]

# Keep only topk predictions before NMS
if self.topk is not None and len(scores) > self.topk:
idxs = np.argsort(scores)[::-1][: self.topk]
else:
idxs = np.arange(len(scores))

scores_b = scores[idxs]
labels_b = labels[idxs]
bboxes = boxes[b][idxs]
scores_b = flat_prob[topk_idxs]
labels_b = topk_idxs % num_classes
query_idxs = topk_idxs // num_classes
bboxes = boxes[b][query_idxs]

mask = scores_b > self.score_thresh

Expand Down Expand Up @@ -275,19 +272,17 @@ def to_quad(box: np.ndarray):
if box.shape == (4,):
x1, y1, x2, y2 = box
return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32)
if box.shape == (8,):
return box.reshape(4, 2)
if box.shape == (4, 2):
return box.astype(np.float32)
raise ValueError(f"Unsupported box shape: {box.shape}")
raise ValueError(f"Unsupported box shape: {box.shape}") # pragma: no cover

for sample in target:
boxes_all = []
labels_all = []

for class_name, boxes in sample.items():
if class_name not in class_to_id:
raise ValueError(f"Unknown class name: {class_name}")
raise ValueError(f"Unknown class name: {class_name}") # pragma: no cover

cls_id = class_to_id[class_name]
boxes = np.asarray(boxes)
Expand All @@ -307,7 +302,7 @@ def to_quad(box: np.ndarray):
labels_all.append(cls_id)

targets.append({
"boxes": np.asarray(boxes_all, dtype=np.float32),
"boxes": np.asarray(boxes_all, dtype=np.float32) if boxes_all else np.zeros((0, 6), dtype=np.float32),
"labels": np.asarray(labels_all, dtype=np.int64),
})

Expand Down
Loading
Loading