diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py
index 45d57a326d..57839f5646 100644
--- a/doctr/models/factory/hub.py
+++ b/doctr/models/factory/hub.py
@@ -9,7 +9,6 @@
import logging
import subprocess
import tempfile
-import textwrap
from pathlib import Path
from typing import Any
@@ -101,61 +100,62 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
raise ValueError("task must be one of classification, detection, recognition, layout")
# default readme
- readme = textwrap.dedent(
- f"""---
- language: en
- tags:
- - ocr
- - pytorch
- - doctr
- - {task}
- ---
+ readme = f"""---
+language: en
+tags:
+- ocr
+- pytorch
+- doctr
+- {task}
+---
-
-
-
+
+
+
- **Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**
+**Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**
- ## Task: {task}
+## Task: {task}
- https://github.com/mindee/doctr
+https://github.com/mindee/doctr
- ### Example usage:
+### Example usage:
- ```python
- >>> from doctr.io import DocumentFile
- >>> from doctr.models import ocr_predictor, from_hub
+```python
+>>> from doctr.io import DocumentFile
+>>> from doctr.models import ocr_predictor, from_hub
- >>> img = DocumentFile.from_images([''])
- >>> # Load your model from the hub
- >>> model = from_hub('mindee/my-model')
+>>> img = DocumentFile.from_images([''])
+>>> # Load your model from the hub
+>>> model = from_hub('mindee/my-model')
- >>> # Pass it to the predictor
- >>> # If your model is a recognition model:
- >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
- >>> reco_arch=model,
- >>> pretrained=True)
+>>> # Pass it to the predictor
+>>> # If your model is a recognition model:
+>>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
+>>> reco_arch=model,
+>>> pretrained=True)
- >>> # If your model is a detection model:
- >>> predictor = ocr_predictor(det_arch=model,
- >>> reco_arch='crnn_mobilenet_v3_small',
- >>> pretrained=True)
+>>> # If your model is a detection model:
+>>> predictor = ocr_predictor(det_arch=model,
+>>> reco_arch='crnn_mobilenet_v3_small',
+>>> pretrained=True)
- >>> # Get your predictions
- >>> res = predictor(img)
- ```
- """
- )
+>>> # Get your predictions
+>>> res = predictor(img)
+```
+"""
# add run configuration to readme if available
if run_config is not None:
arch = run_config.arch
- readme += textwrap.dedent(
- f"""### Run Configuration
- \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
- )
+ readme += f"""
+### Run Configuration
+
+```json
+{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}
+```
+"""
if arch not in AVAILABLE_ARCHS[task]:
raise ValueError(
diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py
index 2aa59e7d18..b15ca54099 100644
--- a/doctr/models/layout/lw_detr/base.py
+++ b/doctr/models/layout/lw_detr/base.py
@@ -143,23 +143,20 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int
results: list[tuple[list[int], np.ndarray, list[float]]] = []
for b in range(boxes.shape[0]):
- # Convert logits to probabilities and get scores and labels
- exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True))
- prob = exp / exp.sum(axis=-1, keepdims=True)
+ # Sigmoid scores (the model is trained with a sigmoid-based (IA-BCE) loss without
+ # a background class)
+ prob = 1.0 / (1.0 + np.exp(-logits[b])) # (num_queries, num_classes)
+ num_classes = prob.shape[-1]
- prob_fg = prob[:, :-1] # exclude background
- scores = prob_fg.max(axis=-1)
- labels = prob_fg.argmax(axis=-1)
+ # Keep only the topk (query, class) pairs before NMS
+ flat_prob = prob.reshape(-1)
+ topk = min(self.topk, flat_prob.size) if self.topk is not None else flat_prob.size
+ topk_idxs = np.argsort(flat_prob)[::-1][:topk]
- # Keep only topk predictions before NMS
- if self.topk is not None and len(scores) > self.topk:
- idxs = np.argsort(scores)[::-1][: self.topk]
- else:
- idxs = np.arange(len(scores))
-
- scores_b = scores[idxs]
- labels_b = labels[idxs]
- bboxes = boxes[b][idxs]
+ scores_b = flat_prob[topk_idxs]
+ labels_b = topk_idxs % num_classes
+ query_idxs = topk_idxs // num_classes
+ bboxes = boxes[b][query_idxs]
mask = scores_b > self.score_thresh
@@ -275,11 +272,9 @@ def to_quad(box: np.ndarray):
if box.shape == (4,):
x1, y1, x2, y2 = box
return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32)
- if box.shape == (8,):
- return box.reshape(4, 2)
if box.shape == (4, 2):
return box.astype(np.float32)
- raise ValueError(f"Unsupported box shape: {box.shape}")
+ raise ValueError(f"Unsupported box shape: {box.shape}") # pragma: no cover
for sample in target:
boxes_all = []
@@ -287,7 +282,7 @@ def to_quad(box: np.ndarray):
for class_name, boxes in sample.items():
if class_name not in class_to_id:
- raise ValueError(f"Unknown class name: {class_name}")
+ raise ValueError(f"Unknown class name: {class_name}") # pragma: no cover
cls_id = class_to_id[class_name]
boxes = np.asarray(boxes)
@@ -307,7 +302,7 @@ def to_quad(box: np.ndarray):
labels_all.append(cls_id)
targets.append({
- "boxes": np.asarray(boxes_all, dtype=np.float32),
+ "boxes": np.asarray(boxes_all, dtype=np.float32) if boxes_all else np.zeros((0, 6), dtype=np.float32),
"labels": np.asarray(labels_all, dtype=np.int64),
})
diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py
index 08f9ea9fb4..ed970f9aa9 100644
--- a/doctr/models/layout/lw_detr/layers/pytorch.py
+++ b/doctr/models/layout/lw_detr/layers/pytorch.py
@@ -12,7 +12,55 @@
from doctr.models.modules import ChannelLayerNorm
from doctr.models.utils import conv_sequence_pt
-__all__ = ["MultiScaleProjector", "C2fBottleneck", "LWDETRHead", "LWDETRDecoder", "LWDETRMultiscaleDeformableAttention"]
+__all__ = [
+ "MultiScaleProjector",
+ "C2fBottleneck",
+ "LWDETRHead",
+ "LWDETRDecoder",
+ "LWDETRMultiscaleDeformableAttention",
+ "refine_obb_boxes",
+]
+
+
+def refine_obb_boxes(reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
+ """Refine oriented bounding boxes by applying predicted deltas to reference points.
+ Both reference points and deltas are in the format (cx, cy, w, h, sinθ, cosθ).
+
+ cx' = cx + delta_cx * w
+ cy' = cy + delta_cy * h
+ w' = w * exp(delta_w)
+ h' = h * exp(delta_h)
+ (sinθ', cosθ') = rotation composition of (sinθ, cosθ) with the normalized delta rotation
+
+ Args:
+ reference_points: (..., 6) tensor containing the reference points
+ deltas: (..., 6) tensor containing the predicted deltas
+
+ Returns:
+ refined_boxes: (..., 6) tensor containing the refined bounding boxes
+ """
+ reference_points = reference_points.to(deltas.device)
+ # center
+ cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2]
+ # size: clamp deltas to prevent exp() from overflowing during early training.
+ # NOTE: the upper bound must allow the encoder proposals (w = h = 0.05) to reach full-page boxes
+ # in a single refinement step: exp(3.5) * 0.05 ~= 1.66, so boxes up to (and beyond) the full canvas
+ # remain reachable. A tighter bound (e.g. 2.0 -> 0.05 * exp(2) ~= 0.37) silently caps the size of the
+ # encoder proposals and creates an irreducible loss floor for large layout items (tables, pictures).
+ wh = torch.clamp(deltas[..., 2:4], min=-5.0, max=3.5).exp() * reference_points[..., 2:4]
+ # rotation (eps avoids division-by-zero NaN creation)
+ delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6)
+ sin_delta = delta_rot[..., 0:1]
+ cos_delta = delta_rot[..., 1:2]
+ sin_ref = reference_points[..., 4:5]
+ cos_ref = reference_points[..., 5:6]
+
+ # compose rotations
+ sin_new = sin_ref * cos_delta + cos_ref * sin_delta
+ cos_new = cos_ref * cos_delta - sin_ref * sin_delta
+ rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6)
+
+ return torch.cat((cxcy, wh, rot), dim=-1)
class LWDETRHead(nn.Module):
@@ -109,7 +157,7 @@ def forward(
# Crash prevention: ensure seq_len is perfectly divisible
assert seq_len % self.group_detr == 0, (
f"Seq len {seq_len} must be divisible by group_detr {self.group_detr}"
- ) # noqa: E501
+ )
hidden_states_original = hidden_states
if position_embeddings is not None:
@@ -249,8 +297,8 @@ def forward(
value = self.value_proj(encoder_hidden_states)
if attention_mask is not None:
- # we invert the attention_mask
- value = value.masked_fill(~attention_mask[..., None], float(0))
+ # attention_mask contains True on padded positions -> zero-out the padded values
+ value = value.masked_fill(attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
@@ -487,38 +535,39 @@ def __init__(
)
def get_reference(
- self, reference_points: torch.Tensor, valid_ratios: torch.Tensor
+ self,
+ reference_points: torch.Tensor,
+ num_levels: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""This function computes the reference point inputs and positional embeddings for the decoder layers.
Args:
reference_points: (batch_size, num_queries, 6)
tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ)
- valid_ratios: (batch_size, num_levels, 2)
- tensor containing the valid ratios for each level of the input feature maps
+ num_levels: number of feature levels used in the decoder (1 for LW-DETR small and medium)
Returns:
- reference_points_inputs: (batch_size, num_queries, 1, num_levels, 4)
- tensor containing the reference point inputs for the decoder layers,
- which are the normalized center coordinates,
- width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps
+ reference_points_inputs: (batch_size, num_queries, num_levels, 6)
+ tensor containing the reference point inputs for the decoder layers
query_pos: (batch_size, num_queries, d_model)
tensor containing the positional embeddings for the decoder layers,
which are computed from the reference points using sine and cosine functions and a linear projection
"""
- obj_center = reference_points[..., :4]
- spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
- # Extract angles
- angle = reference_points[..., 4:6] # (sin, cos)
- angle_expanded = angle[:, :, None]
+ ref_xywh = reference_points[..., :4]
+ angle = reference_points[..., 4:6]
+
+ spatial_inputs = ref_xywh[:, :, None].expand(-1, -1, num_levels, -1)
+ angle_expanded = angle[:, :, None].expand(-1, -1, num_levels, -1)
+
reference_points_inputs = torch.cat([spatial_inputs, angle_expanded], dim=-1)
- # DETR positional encoding
- query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model)
+ # generate sine positional embeddings from the reference points and
+ # project them to get the final query positional embeddings
+ query_sine_embed = gen_sine_position_embeddings(ref_xywh, self.d_model)
base_query_pos = self.ref_point_head(query_sine_embed)
- # Angle embedding
+ # Angle embedding: use the same sine/cosine encoding scheme as for the spatial coordinates,
+ # but with a separate linear projection.
sin_t = angle[..., 0:1]
cos_t = angle[..., 1:2]
-
angle_feat = torch.cat(
[
sin_t,
@@ -528,51 +577,45 @@ def get_reference(
],
dim=-1,
)
-
angle_emb = self.angle_proj(angle_feat)
# Combine
query_pos = base_query_pos + angle_emb
return reference_points_inputs, query_pos
- def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
- reference_points = reference_points.to(deltas.device)
- cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2]
-
- # Clamp deltas to prevent exp() from shooting to Infinity during early training
- wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4]
-
- # Add eps=1e-6 to avoid division-by-zero NaN creation
- delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6)
- sin_delta = delta_rot[..., 0:1]
- cos_delta = delta_rot[..., 1:2]
- sin_ref = reference_points[..., 4:5]
- cos_ref = reference_points[..., 5:6]
-
- sin_new = sin_ref * cos_delta + cos_ref * sin_delta
- cos_new = cos_ref * cos_delta - sin_ref * sin_delta
-
- # Add eps=1e-6 here too
- rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6)
-
- return torch.cat((cxcy, wh, rot), dim=-1)
-
def forward(
self,
inputs_embeds: torch.Tensor | None,
reference_points: torch.Tensor,
- spatial_shapes_list: torch.Tensor,
- valid_ratios: torch.Tensor,
+ spatial_shapes_list: list[tuple[int, int]],
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor | None = None,
):
+ """Forward pass of the decoder with iterative (per-layer) box refinement.
+
+ Returns:
+ last_hidden_state: (B, Q, d_model) normalized hidden states of the last decoder layer
+ intermediate: (num_layers, B, Q, d_model) normalized hidden states of every decoder layer
+ intermediate_reference_points: (num_layers, B, Q, 6) the reference points used as INPUT to each
+ decoder layer. Box predictions for layer i must be computed as
+ `refine_obb_boxes(intermediate_reference_points[i], bbox_embed(intermediate[i]))`.
+ Following LW-DETR, the refined reference points are detached before being fed to the next
+ layer, while the undetached versions are kept in the returned stack so that gradients can
+ flow one extra refinement step when computing the auxiliary losses.
+ """
intermediate: list[torch.Tensor] = []
+ # reference points fed to each decoder layer (input of layer i at index i)
intermediate_reference_points: list[torch.Tensor] = [reference_points]
if inputs_embeds is not None:
hidden_states = inputs_embeds
- reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios)
+ num_levels = len(spatial_shapes_list)
+
+ reference_points_inputs, query_pos = self.get_reference(
+ reference_points,
+ num_levels=num_levels,
+ )
for lid, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
@@ -585,25 +628,23 @@ def forward(
)
hidden_states_norm = self.layernorm(hidden_states)
+ intermediate.append(hidden_states_norm)
# iterative refinement
- if self.bbox_embed is not None:
+ if self.bbox_embed is not None and lid < len(self.layers) - 1:
delta = self.bbox_embed(hidden_states_norm)
+ new_reference_points = refine_obb_boxes(reference_points, delta)
- reference_points = self.refine_boxes(
- reference_points.squeeze(2),
- delta,
- )
-
- intermediate_reference_points.append(reference_points)
+ # keep the undetached version for the auxiliary losses ("look forward" supervision)
+ intermediate_reference_points.append(new_reference_points)
+ # the next layer consumes detached reference points
+ reference_points = new_reference_points.detach()
reference_points_inputs, query_pos = self.get_reference(
reference_points,
- valid_ratios,
+ num_levels=num_levels,
)
- intermediate.append(hidden_states_norm)
-
intermediate_stack = torch.stack(intermediate)
last_hidden_state = intermediate_stack[-1]
diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py
index 848d44280e..96f2edd205 100644
--- a/doctr/models/layout/lw_detr/pytorch.py
+++ b/doctr/models/layout/lw_detr/pytorch.py
@@ -10,6 +10,7 @@
import numpy as np
import torch
+from scipy.optimize import linear_sum_assignment
from torch import nn
from torch.nn import functional as F
@@ -17,7 +18,12 @@
from ...utils import load_pretrained_params
from .base import _LWDETR, LWDETRPostProcessor
-from .layers import LWDETRDecoder, LWDETRHead, LWDETRMultiscaleDeformableAttention, MultiScaleProjector
+from .layers import (
+ LWDETRDecoder,
+ LWDETRHead,
+ MultiScaleProjector,
+ refine_obb_boxes,
+)
__all__ = ["LWDETR", "lw_detr_s", "lw_detr_m"]
@@ -68,6 +74,82 @@
}
+def _obb_covariance_components(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute the components (a, b, c) of the Gaussian covariance matrix [[a, c], [c, b]] associated with
+ oriented boxes in (cx, cy, w, h, sinθ, cosθ) format, following ProbIoU (https://arxiv.org/abs/2106.06072).
+
+ Args:
+ boxes: (..., 6) tensor of oriented boxes
+
+ Returns:
+ a, b, c: (...,) tensors of covariance components
+ """
+ w = boxes[..., 2].clamp(min=1e-6)
+ h = boxes[..., 3].clamp(min=1e-6)
+ rot = F.normalize(boxes[..., 4:6], dim=-1, eps=1e-6)
+ sin, cos = rot[..., 0], rot[..., 1]
+
+ var_w = w.pow(2) / 12.0
+ var_h = h.pow(2) / 12.0
+ cos2 = cos.pow(2)
+ sin2 = sin.pow(2)
+
+ a = var_w * cos2 + var_h * sin2
+ b = var_w * sin2 + var_h * cos2
+ c = (var_w - var_h) * cos * sin
+ return a, b, c
+
+
+def _probiou(
+ boxes1: torch.Tensor,
+ boxes2: torch.Tensor,
+ pairwise: bool = False,
+ scale: float = 1.0,
+ eps: float = 1e-7,
+) -> torch.Tensor:
+ """Compute the probabilistic IoU between oriented boxes in (cx, cy, w, h, sinθ, cosθ) format,
+ as described in `"Gaussian Bounding Boxes and Probabilistic IoU" `_.
+
+ Args:
+ boxes1: (N, 6) tensor of oriented boxes
+ boxes2: (M, 6) tensor of oriented boxes (M = N if `pairwise` is False)
+ pairwise: if True, return the (N, M) matrix of IoUs, otherwise the element-wise (N,) IoUs
+ scale: factor applied to the spatial components (cx, cy, w, h) before computing the IoU.
+ ProbIoU is mathematically scale-invariant, but the stabilizing `eps` terms are not: in
+ normalized [0, 1] coordinates the covariance terms of small boxes (e.g. checkboxes) fall
+ below `eps` and the result degenerates (non-overlapping small boxes can score ~1).
+ eps: small value for numerical stability
+
+ Returns:
+ probiou: (N,) or (N, M) tensor of IoU-like similarities in [0, 1]
+ """
+ if scale != 1.0:
+ boxes1 = torch.cat([boxes1[..., :4] * scale, boxes1[..., 4:]], dim=-1)
+ boxes2 = torch.cat([boxes2[..., :4] * scale, boxes2[..., 4:]], dim=-1)
+
+ x1, y1 = boxes1[..., 0], boxes1[..., 1]
+ x2, y2 = boxes2[..., 0], boxes2[..., 1]
+ a1, b1, c1 = _obb_covariance_components(boxes1)
+ a2, b2, c2 = _obb_covariance_components(boxes2)
+
+ if pairwise:
+ x1, y1, a1, b1, c1 = (t.unsqueeze(-1) for t in (x1, y1, a1, b1, c1)) # (N, 1)
+ x2, y2, a2, b2, c2 = (t.unsqueeze(-2) for t in (x2, y2, a2, b2, c2)) # (1, M)
+
+ denom = (a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps
+ t1 = ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / denom * 0.25
+ t2 = ((c1 + c2) * (x2 - x1) * (y1 - y2)) / denom * 0.5
+ t3 = (
+ ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)).clamp(min=eps)
+ / (4 * ((a1 * b1 - c1.pow(2)).clamp(min=eps) * (a2 * b2 - c2.pow(2)).clamp(min=eps)).sqrt() + eps)
+ + eps
+ ).log() * 0.5
+
+ bhattacharyya = (t1 + t2 + t3).clamp(min=eps, max=100.0)
+ hellinger = (1.0 - (-bhattacharyya).exp() + eps).sqrt()
+ return 1.0 - hellinger
+
+
class LWDETRBackbone(nn.Module):
"""Backbone of LW-DETR, based on a ViT Det architecture. The backbone is used as feature extractor.
@@ -102,6 +184,45 @@ def __init__(
num_blocks=num_blocks,
)
+ def _resize_padding_mask(self, mask: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
+ """Resize padding mask to feature-map size
+
+ Args:
+ mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+ size: the target size (H', W') for the resized mask
+
+ Returns:
+ resized_mask: a binary mask of shape [batch_size x H' x W'], containing 1 on padded pixels
+ """
+ if mask.dtype != torch.bool:
+ mask = mask.bool()
+
+ valid = (~mask).float().unsqueeze(1) # True/1 = valid pixels
+
+ # Data-dependent sanity checks
+ if self.training: # pragma: no cover
+ if (valid.flatten(1).sum(dim=1) == 0).any():
+ bad = torch.where(valid.flatten(1).sum(dim=1) == 0)[0].tolist()
+ raise RuntimeError(f"Input masks are fully padded before resizing: {bad}")
+
+ # Use max pooling to resize the valid mask:
+ # a pixel in the resized mask is valid if at least one pixel
+ # in the corresponding window in the input mask is valid
+ h_in, w_in = int(mask.shape[-2]), int(mask.shape[-1])
+ h_out, w_out = int(size[0]), int(size[1])
+ kh, kw = h_in // h_out, w_in // w_out
+ valid_resized = F.max_pool2d(valid, kernel_size=(kh, kw), stride=(kh, kw)) > 0
+
+ resized_mask = ~valid_resized.squeeze(1)
+
+ # Data-dependent sanity checks
+ if self.training: # pragma: no cover
+ if resized_mask.flatten(1).all(dim=1).any():
+ bad = torch.where(resized_mask.flatten(1).all(dim=1))[0].tolist()
+ raise RuntimeError(f"Feature masks became fully padded after resizing: {bad}")
+
+ return resized_mask
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> list[tuple[torch.Tensor, torch.Tensor]]:
"""Forward pass of the backbone.
@@ -120,10 +241,7 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> list[tup
# [(B, C, H, W)]
if mask is None: # pragma: no cover
mask = torch.zeros((x.shape[0], x.shape[2], x.shape[3]), dtype=torch.bool, device=x.device)
- return [
- (feat, F.interpolate(mask.unsqueeze(1).float(), size=feat.shape[-2:], mode="nearest").squeeze(1).bool())
- for feat in feats
- ]
+ return [(feat, self._resize_padding_mask(mask, feat.shape[2:])) for feat in feats]
class LWDETR(nn.Module, _LWDETR):
@@ -153,11 +271,11 @@ def __init__(
self,
feat_extractor: LWDETRBackbone,
class_names: list[str],
- score_thresh: float = 0.3,
+ score_thresh: float = 0.25,
iou_thresh: float = 0.5,
d_model: int = 256,
- num_queries: int = 130,
- group_detr: int = 1,
+ num_queries: int = 195,
+ group_detr: int = 13,
dec_layers: int = 3,
sa_num_heads: int = 8,
ca_num_heads: int = 16,
@@ -171,7 +289,8 @@ def __init__(
super().__init__()
self.class_names: list[str] = class_names
- self.num_classes = len(self.class_names) + 1 # +1 for background class
+ # No background class: the model is trained with a sigmoid-based (IA-BCE) loss
+ self.num_classes = len(self.class_names)
self.cfg = cfg
self.exportable = exportable
self.assume_straight_pages = assume_straight_pages
@@ -183,10 +302,6 @@ def __init__(
self.d_model = d_model
self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6)
- # Initialize angle to (sin=0, cos=1)
- with torch.no_grad():
- self.reference_point_embed.weight[:, 4] = 0.0 # sinθ
- self.reference_point_embed.weight[:, 5] = 1.0 # cosθ
self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model)
@@ -242,44 +357,25 @@ def __init__(
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, std=0.02)
- elif isinstance(m, LWDETRMultiscaleDeformableAttention):
- nn.init.constant_(m.sampling_offsets.weight, 0.0)
-
- thetas = torch.arange(m.n_heads, dtype=torch.float32) * (2.0 * math.pi / m.n_heads)
- grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
- grid_init = (
- (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
- .view(m.n_heads, 1, 1, 2)
- .repeat(1, m.n_levels, m.n_points, 1)
- )
-
- for i in range(m.n_points):
- grid_init[:, :, i, :] *= i + 1
-
- with torch.no_grad():
- m.sampling_offsets.bias.copy_(grid_init.view(-1))
-
- nn.init.constant_(m.attention_weights.weight, 0.0)
- nn.init.constant_(m.attention_weights.bias, 0.0)
-
- nn.init.xavier_uniform_(m.value_proj.weight)
- nn.init.zeros_(m.value_proj.bias)
-
- nn.init.xavier_uniform_(m.output_proj.weight)
- nn.init.zeros_(m.output_proj.bias)
if isinstance(m, nn.Linear) and m.out_features == self.num_classes:
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
if m.bias is not None:
nn.init.constant_(m.bias, bias_value)
- if isinstance(m, LWDETRHead):
- last = m.layers[-1]
- if isinstance(last, nn.Linear):
- nn.init.zeros_(last.weight)
- nn.init.zeros_(last.bias)
- if last.bias.shape[0] == 6:
- nn.init.constant_(last.bias[5], 1.0)
+
+ # Initialize the iterative refinement heads to predict zero deltas (i.e. identity refinement)
+ # at the start of training, to stabilize training in the early stages when the encoder proposals are still noisy
+ with torch.no_grad():
+ for head in [self.bbox_embed, *self.enc_out_bbox_embed]:
+ last = head.layers[-1]
+ last.weight.zero_()
+ last.bias.zero_()
+ last.bias[5] = 1.0 # cosθ of the rotation delta -> identity rotation
+
+ # The reference point embedding acts as a learned delta composed with the encoder proposals
+ self.reference_point_embed.weight.zero_()
+ self.reference_point_embed.weight[:, 5] = 1.0 # cosθ
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
"""Load pretrained parameters onto the model
@@ -290,62 +386,6 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
"""
load_pretrained_params(self, path_or_url, **kwargs)
- def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
- """Refine bounding boxes by applying the predicted deltas to the reference points.
- The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format.
- The refined boxes are computed as follows:
-
- cx' = cx + delta_cx * w
- cy' = cy + delta_cy * h
- w' = w * exp(delta_w)
- h' = h * exp(delta_h)
- sinθ' = sinθ * cosΔ + cosθ * sinΔ
- cosθ' = cosθ * cosΔ - sinθ * sinΔ
-
- Args:
- reference_points: (N, S, 6) tensor containing the reference points
- deltas: (N, S, 6) tensor containing the predicted deltas
-
- Returns:
- refined_boxes: (N, S, 6) tensor containing the refined bounding boxes
- """
- reference_points = reference_points.to(deltas.device)
- # center
- cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2]
- # size
- wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4]
- # rotation
- delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6)
- sin_delta = delta_rot[..., 0:1]
- cos_delta = delta_rot[..., 1:2]
- sin_ref = reference_points[..., 4:5]
- cos_ref = reference_points[..., 5:6]
-
- # compose rotations
- sin_new = sin_ref * cos_delta + cos_ref * sin_delta
- cos_new = cos_ref * cos_delta - sin_ref * sin_delta
- rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6)
-
- return torch.cat((cxcy, wh, rot), dim=-1)
-
- def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
- """Get the valid ratio of all feature maps.
-
- Args:
- mask: (N, H, W) binary tensor containing 1 on padded pixels
- dtype: the desired data type of the output tensor
-
- Returns:
- valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch
- """
- _, height, width = mask.shape
- valid_height = torch.sum(~mask[:, :, 0], 1)
- valid_width = torch.sum(~mask[:, 0, :], 1)
- valid_ratio_height = valid_height.to(dtype) / height
- valid_ratio_width = valid_width.to(dtype) / width
- valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
- return valid_ratio
-
def gen_encoder_output_proposals(
self, enc_output: torch.Tensor, padding_mask: torch.Tensor, spatial_shapes: list[tuple[int, int]]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -367,56 +407,48 @@ def gen_encoder_output_proposals(
batch_size = enc_output.shape[0]
proposals = []
_cur = 0
- for level, (height, width) in enumerate(spatial_shapes):
- mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
- valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
- valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+ for level, (height, width) in enumerate(spatial_shapes):
grid_y, grid_x = torch.meshgrid(
- torch.linspace(
- 0,
- height - 1,
- height,
- dtype=enc_output.dtype,
- device=enc_output.device,
- ),
- torch.linspace(
- 0,
- width - 1,
- width,
- dtype=enc_output.dtype,
- device=enc_output.device,
- ),
+ torch.arange(height, dtype=enc_output.dtype, device=enc_output.device),
+ torch.arange(width, dtype=enc_output.dtype, device=enc_output.device),
indexing="ij",
)
- grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
- scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
- grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
+ grid = torch.stack([grid_x, grid_y], dim=-1)
+ grid = grid.unsqueeze(0).expand(batch_size, -1, -1, -1)
+
+ scale = torch.tensor(
+ [width, height],
+ dtype=enc_output.dtype,
+ device=enc_output.device,
+ ).view(1, 1, 1, 2)
+
+ # Canvas-normalized center coordinates
+ grid = (grid + 0.5) / scale
width_height = torch.ones_like(grid) * 0.05 * (2.0**level)
- # add default rotation (sin=0, cos=1)
sin = torch.zeros_like(grid[..., :1])
cos = torch.ones_like(grid[..., :1])
- proposal = torch.cat((grid, width_height, sin, cos), -1).view(batch_size, -1, 6)
+ proposal = torch.cat((grid, width_height, sin, cos), dim=-1).view(batch_size, -1, 6)
proposals.append(proposal)
_cur += height * width
- output_proposals = torch.cat(proposals, 1)
- spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True)
- output_proposals_valid = spatial_valid
- invalid_mask = padding_mask | ~output_proposals_valid.squeeze(-1)
- invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid
- output_proposals = output_proposals.masked_fill(invalid_mask, float(0))
+ output_proposals = torch.cat(proposals, dim=1)
+
+ spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(
+ dim=-1, keepdim=True
+ )
+ invalid_mask = padding_mask.unsqueeze(-1) | ~spatial_valid
+
+ output_proposals = output_proposals.masked_fill(invalid_mask, 0.0)
+ object_query = enc_output.masked_fill(invalid_mask, 0.0)
- # assign each pixel as an object query
- object_query = enc_output
- object_query = object_query.masked_fill(invalid_mask, float(0))
return object_query, output_proposals, invalid_mask
def forward(
self,
input: torch.Tensor,
- masks: torch.Tensor,
+ masks: torch.Tensor | None = None,
target: list[dict[str, np.ndarray]] | None = None,
return_model_output: bool = False,
return_preds: bool = False,
@@ -455,7 +487,6 @@ def forward(
mask_flatten_list.append(mask)
source_flatten = torch.cat(source_flatten_list, 1)
mask_flatten = torch.cat(mask_flatten_list, 1)
- valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in feats_masks], 1)
tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1)
reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)
@@ -469,7 +500,7 @@ def forward(
topk_coords_logits_list: list[torch.Tensor] = []
- # encoder predictions for auxiliary losses
+ # encoder predictions on the selected top-k proposals, kept undetached for the auxiliary loss
all_group_enc_logits: list[torch.Tensor] = []
all_group_enc_coords: list[torch.Tensor] = []
@@ -478,14 +509,11 @@ def forward(
group_object_query = self.enc_output_norm[group_id](group_object_query)
group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query)
- all_group_enc_logits.append(group_enc_outputs_class)
group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf"))
group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query)
- group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox)
-
- all_group_enc_coords.append(group_enc_outputs_coord)
+ group_enc_outputs_coord = refine_obb_boxes(output_proposals, group_delta_bbox)
group_topk_proposals = torch.topk(group_enc_outputs_class_masked.max(-1)[0], topk, dim=1)[1]
@@ -494,23 +522,33 @@ def forward(
1,
group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6),
)
- group_topk_coords_logits = group_topk_coords_logits_undetach
- topk_coords_logits_list.append(group_topk_coords_logits)
+ # the auxiliary loss supervises only the selected proposals,
+ # so gather the matching class logits as well
+ group_topk_logits_undetach = torch.gather(
+ group_enc_outputs_class,
+ 1,
+ group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.num_classes),
+ )
+ all_group_enc_logits.append(group_topk_logits_undetach)
+ all_group_enc_coords.append(group_topk_coords_logits_undetach)
+
+ # the decoder consumes detached proposals as initial reference points
+ topk_coords_logits_list.append(group_topk_coords_logits_undetach.detach())
topk_coords_logits = torch.cat(topk_coords_logits_list, 1)
- reference_points = self.refine_bboxes(topk_coords_logits, reference_points)
+ reference_points = refine_obb_boxes(topk_coords_logits, reference_points)
last_hidden_states, intermediate, intermediate_reference_points = self.decoder(
inputs_embeds=tgt,
reference_points=reference_points,
spatial_shapes_list=spatial_shapes_list,
- valid_ratios=valid_ratios,
encoder_hidden_states=source_flatten,
+ encoder_attention_mask=mask_flatten,
)
logits = self.class_embed(last_hidden_states)
pred_boxes_delta = self.bbox_embed(last_hidden_states)
- pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
+ pred_boxes = refine_obb_boxes(intermediate_reference_points[-1], pred_boxes_delta)
out: dict[str, Any] = {}
@@ -534,223 +572,163 @@ def _postprocess(logits, boxes):
# Build target
processed_targets = self.build_target(target, self.class_names)
- # Main loss from final decoder layer (group DETR)
+ # ProbIoU is computed in pixel coordinates
+ box_scale = float(max(input.shape[-2], input.shape[-1]))
+
+ # Main loss from final decoder layer (group DETR: each group is matched independently)
split_logits = logits.chunk(group_detr, dim=1)
split_boxes = pred_boxes.chunk(group_detr, dim=1)
main_loss: float | torch.Tensor = 0.0
for g_logits, g_boxes in zip(split_logits, split_boxes):
- main_loss += self.compute_loss(g_logits, g_boxes, processed_targets)
+ main_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale)
loss = main_loss / group_detr
# Auxiliary losses from intermediate decoder layers
+ # (`intermediate_reference_points[i]` is the reference INPUT to decoder layer i)
for i in range(intermediate.shape[0] - 1):
aux_logits = self.class_embed(intermediate[i])
aux_boxes_delta = self.bbox_embed(intermediate[i])
- aux_boxes = self.refine_bboxes(intermediate_reference_points[i], aux_boxes_delta)
+ aux_boxes = refine_obb_boxes(intermediate_reference_points[i], aux_boxes_delta)
split_aux_logits = aux_logits.chunk(group_detr, dim=1)
split_aux_boxes = aux_boxes.chunk(group_detr, dim=1)
aux_loss: float | torch.Tensor = 0.0
for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes):
- aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets)
- loss += 0.5 * (aux_loss / group_detr)
+ aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale)
+ loss += aux_loss / group_detr
- # Auxiliary losses for encoder proposals
+ # Auxiliary losses for the selected encoder proposals
enc_loss: float | torch.Tensor = 0.0
for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords):
- enc_loss += self.compute_loss(group_logits, group_coords, processed_targets)
- loss += 0.1 * (enc_loss / group_detr)
+ enc_loss += self.compute_loss(group_logits, group_coords, processed_targets, box_scale=box_scale)
+ loss += enc_loss / group_detr
out["loss"] = loss
return out
def compute_loss(
- self, logits: torch.Tensor, pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]]
+ self,
+ logits: torch.Tensor,
+ pred_boxes: torch.Tensor,
+ targets: list[dict[str, np.ndarray]],
+ cls_loss_weight: float = 1.0,
+ l1_loss_weight: float = 5.0,
+ iou_loss_weight: float = 2.0,
+ box_scale: float = 1024.0,
) -> torch.Tensor:
- """
- Compute the loss for LW-DETR. The loss consists of three components:
- classification loss, box regression loss, and rotation loss.
- The classification loss is a cross-entropy loss between the predicted class logits and the target classes.
- The box regression loss is a Smooth L1 loss between the predicted boxes and the target boxes,
- computed only on the positive samples.
- The rotation loss is computed as 1 - cosine similarity between the predicted rotation and the target rotation,
- averaged over the positive samples.
- The positive samples are determined using a SimOTA-like assignment strategy, where for each ground truth box,
- we select the top-k queries with the lowest cost
- (combination of classification cost, box regression cost, and rotation cost).
+ """Compute the LW-DETR loss for oriented bounding boxes.
+
+ Predictions are matched one-to-one to the ground truth boxes with Hungarian matching,
+ using a cost combining a focal-style classification cost, an L1 box cost
+ and a (negated) ProbIoU cost. The loss then consists of:
+
+ - an IoU-aware binary cross-entropy (IA-BCE) classification loss, as described in the LW-DETR paper,
+ where the target of a matched (query, class) pair is `p**alpha * IoU**(1 - alpha)` and the rotated
+ ProbIoU is used as IoU measure
+ - an L1 regression loss on the normalized (cx, cy, w, h) of the matched pairs
+ - a ProbIoU loss (1 - ProbIoU) on the matched oriented boxes, computed in absolute pixel
+ coordinates (as in O2-RT-DETR). The box rotation is supervised solely through this term:
+ ProbIoU is differentiable w.r.t. the angle even for non-overlapping boxes.
+
+ All terms are normalized by the number of ground truth boxes in the batch.
Args:
logits: (B, Q, C) tensor containing the predicted class logits for each query
- pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query
- targets: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding
- to class names and values corresponding to lists of boxes in either polygon format (4, 2)
- or bounding box format (4,) (xmin, ymin, xmax, ymax)
+ pred_boxes: (B, Q, 6) tensor containing the predicted boxes (cx, cy, w, h, sinθ, cosθ) for each query
+ targets: list of B dictionaries with keys "boxes" ((N, 6) array in OBB format)
+ and "labels" ((N,) array of class indices)
+ cls_loss_weight: weight of the classification loss
+ l1_loss_weight: weight of the L1 box regression loss
+ iou_loss_weight: weight of the ProbIoU loss
+ box_scale: image size used to rescale normalized boxes to pixel coordinates for ProbIoU computation
Returns:
loss: the computed loss value
"""
-
- def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format to Gaussian distribution parameters
- (mean and covariance).
- """
- cxcy = boxes[..., :2]
-
- w = boxes[..., 2].clamp(min=1e-6)
- h = boxes[..., 3].clamp(min=1e-6)
-
- sin = boxes[..., 4]
- cos = boxes[..., 5]
-
- R = torch.stack(
- [
- torch.stack([cos, -sin], dim=-1),
- torch.stack([sin, cos], dim=-1),
- ],
- dim=-2,
- )
-
- sx = (w / 2) ** 2
- sy = (h / 2) ** 2
-
- S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device)
-
- S[..., 0, 0] = sx
- S[..., 1, 1] = sy
-
- covariance = R @ S @ R.transpose(-1, -2)
- return cxcy, covariance
-
- def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor:
- """Compute the ProbIoU loss between predicted boxes and target boxes."""
- mu1, sigma1 = rotated_boxes_to_gaussian(pred_boxes)
- mu2, sigma2 = rotated_boxes_to_gaussian(tgt_boxes)
-
- delta = (mu1 - mu2).unsqueeze(-1)
- sigma = (sigma1 + sigma2) * 0.5
-
- eps = 1e-6
- eye = torch.eye(2, device=sigma.device) * eps
- sigma_safe = sigma + eye
- sigma1_safe = sigma1 + eye
- sigma2_safe = sigma2 + eye
-
- sigma_inv = torch.linalg.inv(sigma_safe)
-
- mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1)
-
- det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps)
- det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps)
- det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps)
-
- bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2))
-
- probiou = torch.exp(-bhattacharyya)
- return 1 - probiou
-
+ alpha, gamma, eps = 0.25, 2.0, 1e-8
device = logits.device
- B, Q, C = logits.shape
+ batch_size = logits.shape[0]
- total_cls = torch.tensor(0.0, device=device)
- total_box = torch.tensor(0.0, device=device)
- total_rot = torch.tensor(0.0, device=device)
+ tgt_boxes_list: list[torch.Tensor] = []
+ tgt_labels_list: list[torch.Tensor] = []
+ for sample in targets:
+ tgt_boxes_list.append(
+ torch.as_tensor(sample["boxes"], device=device, dtype=pred_boxes.dtype).reshape(-1, 6)
+ )
+ tgt_labels_list.append(torch.as_tensor(sample["labels"], device=device, dtype=torch.long).reshape(-1))
- for b in range(B):
- pred_logits = logits[b]
- pred_boxes_b = pred_boxes[b]
+ # Number of target boxes in the batch, for loss normalization
+ num_boxes = max(sum(int(labels.numel()) for labels in tgt_labels_list), 1)
- tgt_boxes = torch.as_tensor(
- targets[b]["boxes"],
- device=device,
- dtype=pred_boxes.dtype,
- )
- tgt_cls = torch.as_tensor(
- targets[b]["labels"],
- device=device,
- dtype=torch.long,
- )
+ # Hungarian matching (one-to-one), performed independently for each sample
+ indices: list[tuple[torch.Tensor, torch.Tensor]] = []
+ with torch.no_grad():
+ prob = logits.sigmoid()
+ for b in range(batch_size):
+ tgt_boxes, tgt_labels = tgt_boxes_list[b], tgt_labels_list[b]
+ if tgt_labels.numel() == 0:
+ empty = torch.empty(0, dtype=torch.long, device=device)
+ indices.append((empty, empty))
+ continue
+
+ out_prob = prob[b]
+ out_boxes = pred_boxes[b]
+
+ # Focal-style classification cost
+ neg_cost = (1 - alpha) * out_prob.pow(gamma) * (-(1 - out_prob + eps).log())
+ pos_cost = alpha * (1 - out_prob).pow(gamma) * (-(out_prob + eps).log())
+ cost_class = pos_cost[:, tgt_labels] - neg_cost[:, tgt_labels]
+
+ # L1 cost on normalized (cx, cy, w, h)
+ cost_bbox = torch.cdist(out_boxes[:, :4], tgt_boxes[:, :4], p=1)
+
+ # Rotated IoU cost, computed in pixel coordinates
+ # this term also carries the angle signal for the matching
+ cost_iou = -_probiou(out_boxes, tgt_boxes, pairwise=True, scale=box_scale)
+
+ cost = 2.0 * cost_class + 5.0 * cost_bbox + 2.0 * cost_iou
+
+ query_idx, tgt_idx = linear_sum_assignment(cost.cpu().numpy())
+ indices.append((
+ torch.as_tensor(query_idx, dtype=torch.long, device=device),
+ torch.as_tensor(tgt_idx, dtype=torch.long, device=device),
+ ))
+
+ # Flatten the matched pairs across the batch
+ batch_idx = torch.cat([torch.full_like(src, b) for b, (src, _) in enumerate(indices)])
+ query_idx = torch.cat([src for (src, _) in indices])
+ matched_tgt_boxes = torch.cat([tgt_boxes_list[b][tgt] for b, (_, tgt) in enumerate(indices)])
+ matched_tgt_labels = torch.cat([tgt_labels_list[b][tgt] for b, (_, tgt) in enumerate(indices)])
+ matched_pred_boxes = pred_boxes[batch_idx, query_idx]
+
+ prob = logits.sigmoid()
+
+ # IoU-aware BCE classification loss (IA-BCE)
+ pos_weights = torch.zeros_like(logits)
+ neg_weights = prob.pow(gamma)
+ if len(batch_idx) > 0:
+ with torch.no_grad():
+ ious = _probiou(matched_pred_boxes, matched_tgt_boxes, scale=box_scale).clamp(min=0.0, max=1.0)
+ t = prob[batch_idx, query_idx, matched_tgt_labels].pow(alpha) * ious.pow(1 - alpha)
+ t = t.clamp(min=0.01)
+ pos_weights[batch_idx, query_idx, matched_tgt_labels] = t
+ neg_weights[batch_idx, query_idx, matched_tgt_labels] = 1 - t
- num_gt = len(tgt_cls)
+ cls_loss = -(pos_weights * (prob + eps).log() + neg_weights * (1 - prob + eps).log())
+ loss = cls_loss_weight * cls_loss.sum() / num_boxes
- pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1)
- tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1)
+ if len(batch_idx) == 0:
+ return loss
- with torch.no_grad():
- cls_prob = pred_logits.sigmoid()
- alpha = 0.25
- gamma = 2.0
-
- neg_cost = (1 - alpha) * (cls_prob**gamma) * (-(1 - cls_prob + 1e-8).log())
-
- pos_cost = alpha * ((1 - cls_prob) ** gamma) * (-(cls_prob + 1e-8).log())
-
- cost_cls = pos_cost[:, tgt_cls] - neg_cost[:, tgt_cls]
- cost_l1 = torch.cdist(
- pred_boxes_b[:, :4],
- tgt_boxes[:, :4],
- p=1,
- )
- cost_rot = 1 - (pred_rot @ tgt_rot.T).abs()
- total_cost = 5.0 * cost_cls + 2.0 * cost_l1 + 1.0 * cost_rot
- matching_matrix = torch.zeros(
- (Q, num_gt),
- dtype=torch.bool,
- device=device,
- )
-
- center_dist = torch.cdist(
- pred_boxes_b[:, :2],
- tgt_boxes[:, :2],
- p=2,
- )
-
- iou_like = torch.exp(-center_dist)
- dynamic_k = iou_like.sum(0).int().clamp(min=1, max=10)
-
- for gt_idx in range(num_gt):
- _, candidate_idx = torch.topk(-total_cost[:, gt_idx], k=int(dynamic_k[gt_idx].item()))
- matching_matrix[candidate_idx, gt_idx] = True
-
- # resolve duplicate matches
- multiple_match_mask = matching_matrix.sum(1) > 1
-
- if multiple_match_mask.any():
- duplicate_idx = multiple_match_mask.nonzero(as_tuple=False).squeeze(1)
- min_cost_idx = total_cost[duplicate_idx].argmin(dim=1)
- # Set all matches to False for the duplicate indices,
- # then set the match with the lowest cost to True
- matching_matrix[duplicate_idx] = False
- matching_matrix[duplicate_idx, min_cost_idx] = True
-
- pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True)
-
- target_classes = torch.zeros((Q,), dtype=torch.long, device=device)
-
- # background = 0
- target_classes[pos_idx] = tgt_cls[gt_indices]
-
- total_cls += F.cross_entropy(pred_logits, target_classes)
-
- if len(pos_idx) == 0:
- continue
+ # L1 loss on normalized (cx, cy, w, h)
+ l1_loss = F.l1_loss(matched_pred_boxes[:, :4], matched_tgt_boxes[:, :4], reduction="sum") / num_boxes
+ # ProbIoU loss on the whole oriented box (position, size and rotation), in pixel coordinates
+ probiou_loss = (1 - _probiou(matched_pred_boxes, matched_tgt_boxes, scale=box_scale)).sum() / num_boxes
- pred_sel = pred_boxes_b[pos_idx]
- tgt_sel = tgt_boxes[gt_indices]
- # L1 loss on (cx, cy, w, h)
- l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4])
- # ProbIoU loss on the whole box (including rotation)
- probiou_loss = _probiou_loss(pred_sel, tgt_sel).mean()
- total_box += 2.0 * l1_loss + 0.5 * probiou_loss
- # Rotation loss
- cos_sim = (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1).abs()
- rot_loss = (1 - cos_sim).mean()
- total_rot += 0.5 * rot_loss
- # Average the loss over the batch
- return (total_cls + total_box + total_rot) / B
+ return loss + l1_loss_weight * l1_loss + iou_loss_weight * probiou_loss
def _lw_detr(
@@ -772,7 +750,7 @@ def _lw_detr(
False,
include_top=False,
input_shape=default_cfgs[arch]["input_shape"],
- patch_size=kwargs.get("patch_size", (16, 16)),
+ patch_size=kwargs.pop("patch_size", (16, 16)),
)
feat_extractor = LWDETRBackbone(encoder_fn=backbone)
diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py
index a53acf038f..5b68d45b06 100644
--- a/doctr/transforms/modules/base.py
+++ b/doctr/transforms/modules/base.py
@@ -272,7 +272,14 @@ def __init__(self, scale: tuple[float, float] = (0.08, 1.0), ratio: tuple[float,
def extra_repr(self) -> str:
return f"scale={self.scale}, ratio={self.ratio}"
- def _crop_array(self, img: Any, target: np.ndarray, crop_box):
+ def _crop_image_only(self, img: Any, crop_box: tuple[float, float, float, float]) -> Any:
+ dummy_box = np.array([[0, 0, 1, 1]], dtype=np.float32)
+ cropped_img, _ = F.crop_detection(img, dummy_box, crop_box)
+ return cropped_img
+
+ def _crop_array(
+ self, img: Any, target: np.ndarray, crop_box: tuple[float, float, float, float]
+ ) -> tuple[Any, np.ndarray]:
is_polygon = target.shape[1:] == (4, 2)
if is_polygon:
@@ -339,17 +346,15 @@ def __call__(self, sample: Sample) -> Sample:
(y + crop_h) / h,
)
- r_mask = None
- if mask is not None:
- r_mask, _ = self._crop_array(mask, np.zeros((0, 4)), crop_box)
-
if target is None:
- r_img, _ = self._crop_array(img, np.zeros((0, 4)), crop_box)
+ r_img = self._crop_image_only(img, crop_box)
+ r_mask = self._crop_image_only(mask, crop_box) if mask is not None else None
return sample.replace(image=r_img, mask=r_mask)
if isinstance(target, dict):
cropped_targets = {}
cropped_img = None
+ crop_rejected = False
for cls_name, arr in target.items():
if len(arr) == 0:
@@ -358,13 +363,29 @@ def __call__(self, sample: Sample) -> Sample:
c_img, c_arr = self._crop_array(img, arr, crop_box)
+ if c_img is img:
+ crop_rejected = True
+ break
+
if cropped_img is None:
cropped_img = c_img
cropped_targets[cls_name] = c_arr
- final_img = cropped_img if cropped_img is not None else img
- return sample.replace(image=final_img, mask=r_mask, target=cropped_targets)
+ if crop_rejected or cropped_img is None:
+ return sample.replace(
+ image=img,
+ mask=mask,
+ target={key: value.copy() for key, value in target.items()},
+ )
+
+ r_mask = self._crop_image_only(mask, crop_box) if mask is not None else None
+ return sample.replace(image=cropped_img, mask=r_mask, target=cropped_targets)
c_img, c_target = self._crop_array(img, target, crop_box)
+ if c_img is img:
+ r_mask = mask
+ else:
+ r_mask = self._crop_image_only(mask, crop_box) if mask is not None else None
+
return sample.replace(image=c_img, mask=r_mask, target=c_target)
diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py
index c7245f3a82..c6b7881b56 100644
--- a/doctr/transforms/modules/pytorch.py
+++ b/doctr/transforms/modules/pytorch.py
@@ -380,7 +380,7 @@ def __call__(self, sample: Sample) -> Sample:
else:
shadowed_image = random_shadow(sample.image, self.opacity_range).clip(0, 1)
return sample.replace(image=shadowed_image)
- except ValueError:
+ except ValueError: # pragma: no cover
return sample
def extra_repr(self) -> str:
diff --git a/references/classification/train_character.py b/references/classification/train_character.py
index c266bb2f5c..22a9298ecb 100644
--- a/references/classification/train_character.py
+++ b/references/classification/train_character.py
@@ -66,7 +66,7 @@ def record_lr(
loss_recorder = []
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
for batch_idx, (images, targets) in enumerate(train_loader):
targets = torch.tensor(targets)
@@ -79,7 +79,7 @@ def record_lr(
# Forward, Backward & update
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images)
train_loss = cross_entropy(out, targets)
scaler.scale(train_loss).backward()
@@ -110,7 +110,7 @@ def record_lr(
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None):
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
model.train()
# Iterate over the batches of the dataset
@@ -126,7 +126,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images)
train_loss = cross_entropy(out, targets)
scaler.scale(train_loss).backward()
@@ -168,7 +168,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None):
targets = targets.cuda()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images)
loss = cross_entropy(out, targets)
else:
diff --git a/references/classification/train_orientation.py b/references/classification/train_orientation.py
index 86dc4c5931..0cbbf05616 100644
--- a/references/classification/train_orientation.py
+++ b/references/classification/train_orientation.py
@@ -78,7 +78,7 @@ def record_lr(
loss_recorder = []
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
for batch_idx, (images, targets) in enumerate(train_loader):
targets = torch.tensor(targets)
@@ -91,7 +91,7 @@ def record_lr(
# Forward, Backward & update
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images)
train_loss = cross_entropy(out, targets)
scaler.scale(train_loss).backward()
@@ -122,7 +122,7 @@ def record_lr(
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None):
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
model.train()
# Iterate over the batches of the dataset
@@ -138,7 +138,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images)
train_loss = cross_entropy(out, targets)
scaler.scale(train_loss).backward()
@@ -180,7 +180,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None):
targets = targets.cuda()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images)
loss = cross_entropy(out, targets)
else:
diff --git a/references/detection/evaluate.py b/references/detection/evaluate.py
index 7c5fb597aa..a8674884f4 100644
--- a/references/detection/evaluate.py
+++ b/references/detection/evaluate.py
@@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
images = batch_transforms(images)
targets = [{CLASS_NAME: t} for t in targets]
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images, targets, return_preds=True)
else:
out = model(images, targets, return_preds=True)
diff --git a/references/detection/train.py b/references/detection/train.py
index 43e4ec3a56..2070b63c3e 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -63,7 +63,7 @@ def record_lr(
loss_recorder = []
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
for batch_idx, (images, targets) in enumerate(train_loader):
if torch.cuda.is_available():
@@ -74,7 +74,7 @@ def record_lr(
# Forward, Backward & update
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
train_loss = model(images, targets)["loss"]
scaler.scale(train_loss).backward()
# Gradient clipping
@@ -107,7 +107,7 @@ def record_lr(
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0):
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
model.train()
# Iterate over the batches of the dataset
@@ -120,7 +120,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
train_loss = model(images, targets)["loss"]
scaler.scale(train_loss).backward()
# Gradient clipping
@@ -163,7 +163,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, args, amp=False, l
images = images.cuda()
images = batch_transforms(images)
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images, targets, return_preds=True)
else:
out = model(images, targets, return_preds=True)
@@ -333,7 +333,7 @@ def main(args):
[
T.RandomHorizontalFlip(0.15),
T.OneOf([
- T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
+ T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
]),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
@@ -342,7 +342,7 @@ def main(args):
else [
T.RandomHorizontalFlip(0.15),
T.OneOf([
- T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
+ T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
]),
# Rotation augmentation
@@ -363,7 +363,7 @@ def main(args):
)
if distributed:
- sampler = DistributedSampler(train_set, rank=rank, shuffle=False, drop_last=True)
+ sampler = DistributedSampler(train_set, rank=rank, shuffle=True, drop_last=True)
else:
sampler = RandomSampler(train_set)
@@ -518,6 +518,8 @@ def log_at_step(train_loss=None, val_loss=None, lr=None):
# Training loop
for epoch in range(args.epochs):
+ if distributed:
+ sampler.set_epoch(epoch)
train_loss, actual_lr = fit_one_epoch(
model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, log=log_at_step, rank=rank
)
diff --git a/references/layout/evaluate.py b/references/layout/evaluate.py
index 36fb78cf80..aa8ea788aa 100644
--- a/references/layout/evaluate.py
+++ b/references/layout/evaluate.py
@@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
padding_masks = padding_masks.cuda()
imgs = batch_transforms(imgs)
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(imgs, padding_masks, targets, return_preds=True)
else:
out = model(imgs, padding_masks, targets, return_preds=True)
diff --git a/references/layout/train.py b/references/layout/train.py
index f3f0d2c117..8f9ce372b1 100644
--- a/references/layout/train.py
+++ b/references/layout/train.py
@@ -17,7 +17,14 @@
# The following import is required for DDP
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
+from torch.optim.lr_scheduler import (
+ CosineAnnealingLR,
+ LinearLR,
+ MultiplicativeLR,
+ OneCycleLR,
+ PolynomialLR,
+ SequentialLR,
+)
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort
@@ -31,7 +38,7 @@
from doctr.datasets import LayoutDataset
from doctr.models import layout, login_to_hub, push_to_hf_hub
from doctr.utils.metrics import ObjectDetectionMetric
-from utils import EarlyStopper, convert_target, plot_recorder, plot_samples
+from utils import EarlyStopper, build_param_groups, convert_target, plot_recorder, plot_samples
def record_lr(
@@ -63,7 +70,7 @@ def record_lr(
loss_recorder = []
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
for batch_idx, (images, targets) in enumerate(train_loader):
imgs, padding_masks = images
@@ -77,19 +84,19 @@ def record_lr(
# Forward, Backward & update
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
train_loss = model(imgs, padding_masks, targets)["loss"]
scaler.scale(train_loss).backward()
# Gradient clipping
scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
# Update the params
scaler.step(optimizer)
scaler.update()
else:
train_loss = model(imgs, padding_masks, targets)["loss"]
train_loss.backward()
- torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
optimizer.step()
# Update LR
scheduler.step()
@@ -110,7 +117,7 @@ def record_lr(
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0):
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
model.train()
# Iterate over the batches of the dataset
@@ -125,19 +132,19 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
train_loss = model(imgs, padding_masks, targets)["loss"]
scaler.scale(train_loss).backward()
# Gradient clipping
scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
# Update the params
scaler.step(optimizer)
scaler.update()
else:
train_loss = model(imgs, padding_masks, targets)["loss"]
train_loss.backward()
- torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
optimizer.step()
scheduler.step()
@@ -170,7 +177,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non
padding_masks = padding_masks.cuda()
imgs = batch_transforms(imgs)
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(imgs, padding_masks, targets, return_preds=True)
else:
out = model(imgs, padding_masks, targets, return_preds=True)
@@ -383,7 +390,7 @@ def main(args):
[
T.RandomHorizontalFlip(0.15),
T.OneOf([
- T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
+ T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
]),
T.Resize(
@@ -397,7 +404,7 @@ def main(args):
else [
T.RandomHorizontalFlip(0.15),
T.OneOf([
- T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
+ T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
]),
# Rotation augmentation
@@ -424,7 +431,7 @@ def main(args):
)
if distributed:
- sampler = DistributedSampler(train_set, rank=rank, shuffle=False, drop_last=True)
+ sampler = DistributedSampler(train_set, rank=rank, shuffle=True, drop_last=True)
else:
sampler = RandomSampler(train_set)
@@ -472,24 +479,18 @@ def main(args):
# construct DDP model
model = DDP(model, device_ids=[rank])
+ param_groups = build_param_groups(
+ model,
+ lr=args.lr,
+ backbone_lr=args.lr if not args.pretrained else args.lr * 0.1,
+ weight_decay=args.weight_decay or 1e-4,
+ )
+
# Optimizer
if args.optim == "adam":
- optimizer = torch.optim.Adam(
- [p for p in model.parameters() if p.requires_grad],
- args.lr,
- betas=(0.95, 0.999),
- eps=1e-6,
- weight_decay=args.weight_decay,
- )
-
+ optimizer = torch.optim.Adam(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-8)
elif args.optim == "adamw":
- optimizer = torch.optim.AdamW(
- [p for p in model.parameters() if p.requires_grad],
- args.lr,
- betas=(0.9, 0.999),
- eps=1e-6,
- weight_decay=args.weight_decay or 1e-4,
- )
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-8)
# LR Finder
if rank == 0 and args.find_lr:
@@ -498,12 +499,55 @@ def main(args):
return
# Scheduler
+ total_steps = args.epochs * len(train_loader)
+ warmup_steps = max(1, min(2000, int(0.05 * total_steps)))
+
if args.sched == "cosine":
- scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
+ warmup = LinearLR(
+ optimizer,
+ start_factor=0.01,
+ end_factor=1.0,
+ total_iters=warmup_steps,
+ )
+ cosine = CosineAnnealingLR(
+ optimizer,
+ T_max=max(1, total_steps - warmup_steps),
+ eta_min=args.lr * 0.01,
+ )
+ scheduler = SequentialLR(
+ optimizer,
+ schedulers=[warmup, cosine],
+ milestones=[warmup_steps],
+ )
+
elif args.sched == "onecycle":
- scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
+ scheduler = OneCycleLR(
+ optimizer,
+ max_lr=[g["lr"] for g in optimizer.param_groups],
+ total_steps=total_steps,
+ pct_start=warmup_steps / total_steps,
+ div_factor=100,
+ final_div_factor=100,
+ anneal_strategy="cos",
+ )
+
elif args.sched == "poly":
- scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader))
+ warmup = LinearLR(
+ optimizer,
+ start_factor=0.01,
+ end_factor=1.0,
+ total_iters=warmup_steps,
+ )
+ poly = PolynomialLR(
+ optimizer,
+ total_iters=total_steps - warmup_steps,
+ power=1.0,
+ )
+ scheduler = SequentialLR(
+ optimizer,
+ schedulers=[warmup, poly],
+ milestones=[warmup_steps],
+ )
# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -591,6 +635,8 @@ def log_at_step(train_loss=None, val_loss=None, lr=None):
# Training loop
for epoch in range(args.epochs):
+ if distributed:
+ sampler.set_epoch(epoch)
train_loss, actual_lr = fit_one_epoch(
model,
train_loader,
@@ -690,8 +736,8 @@ def parse_args():
"--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch"
)
parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W")
- parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)")
- parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay")
+ parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for the optimizer (Adam or AdamW)")
+ parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, help="weight decay", dest="weight_decay")
parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
diff --git a/references/layout/utils.py b/references/layout/utils.py
index 6629487ab7..f7e56f049f 100644
--- a/references/layout/utils.py
+++ b/references/layout/utils.py
@@ -86,6 +86,44 @@ def plot_samples(
plt.show()
+def build_param_groups(model: Any, lr: float, backbone_lr: float, weight_decay: float):
+ """Build parameter groups for the optimizer, separating backbone and non-backbone parameters,
+ and applying weight decay only to non-bias and non-norm parameters.
+
+ Args:
+ model: the model containing the parameters
+ lr: learning rate for non-backbone parameters
+ backbone_lr: learning rate for backbone parameters
+ weight_decay: weight decay to apply to non-bias and non-norm parameters
+
+ Returns:
+ a list of parameter groups to be passed to the optimizer
+ """
+ no_decay_keys = ("bias", "norm", ".bn", "embed") # Embedding, LayerNorm, BN
+
+ def is_backbone(name: str) -> bool:
+ return name.removeprefix("module.").startswith("feat_extractor.")
+
+ groups: dict[tuple[bool, bool], list[Any]] = {
+ (False, True): [],
+ (False, False): [],
+ (True, True): [],
+ (True, False): [],
+ }
+ for n, p in model.named_parameters():
+ if not p.requires_grad:
+ continue
+ decay = not (p.ndim <= 1 or any(k in n.lower() for k in no_decay_keys))
+ groups[(is_backbone(n), decay)].append(p)
+
+ return [
+ {"params": groups[(False, True)], "lr": lr, "weight_decay": weight_decay},
+ {"params": groups[(False, False)], "lr": lr, "weight_decay": 0.0},
+ {"params": groups[(True, True)], "lr": backbone_lr, "weight_decay": weight_decay},
+ {"params": groups[(True, False)], "lr": backbone_lr, "weight_decay": 0.0},
+ ]
+
+
def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None:
"""Display the results of the LR grid search.
Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py
diff --git a/references/recognition/evaluate.py b/references/recognition/evaluate.py
index 45a6b38306..22f69fa2cd 100644
--- a/references/recognition/evaluate.py
+++ b/references/recognition/evaluate.py
@@ -38,7 +38,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
images = images.cuda()
images = batch_transforms(images)
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images, targets, return_preds=True)
else:
out = model(images, targets, return_preds=True)
diff --git a/references/recognition/train.py b/references/recognition/train.py
index dc6b7b1b24..913f2ed017 100644
--- a/references/recognition/train.py
+++ b/references/recognition/train.py
@@ -68,7 +68,7 @@ def record_lr(
loss_recorder = []
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
for batch_idx, (images, targets) in enumerate(train_loader):
if torch.cuda.is_available():
@@ -79,7 +79,7 @@ def record_lr(
# Forward, Backward & update
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
train_loss = model(images, targets)["loss"]
scaler.scale(train_loss).backward()
# Gradient clipping
@@ -112,7 +112,7 @@ def record_lr(
def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0):
if amp:
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler("cuda")
model.train()
# Iterate over the batches of the dataset
@@ -125,7 +125,7 @@ def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, sche
optimizer.zero_grad()
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
train_loss = model(images, targets)["loss"]
scaler.scale(train_loss).backward()
# Gradient clipping
@@ -167,7 +167,7 @@ def evaluate(model, device, val_loader, batch_transforms, val_metric, amp=False,
images = images.to(device)
images = batch_transforms(images)
if amp:
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast("cuda"):
out = model(images, targets, return_preds=True)
else:
out = model(images, targets, return_preds=True)
@@ -559,6 +559,8 @@ def log_at_step(train_loss=None, val_loss=None, lr=None):
early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
# Training loop
for epoch in range(args.epochs):
+ if distributed:
+ sampler.set_epoch(epoch)
train_loss, actual_lr = fit_one_epoch(
model,
device,
diff --git a/tests/pytorch/test_models_factory.py b/tests/pytorch/test_models_factory.py
index 8cf0a9fa45..db9aed1cdf 100644
--- a/tests/pytorch/test_models_factory.py
+++ b/tests/pytorch/test_models_factory.py
@@ -50,7 +50,7 @@ def test_push_to_hf_hub():
["vitstr_small", "recognition", "Felix92/doctr-dummy-torch-vitstr-small"],
["parseq", "recognition", "Felix92/doctr-dummy-torch-parseq"],
["viptr_tiny", "recognition", "Felix92/doctr-dummy-torch-viptr-tiny"],
- # ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"],
+ ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"],
],
)
def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir):
diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py
index d0c2865411..5eb3f582e3 100644
--- a/tests/pytorch/test_models_layout.py
+++ b/tests/pytorch/test_models_layout.py
@@ -27,7 +27,7 @@ def test_layout_models(arch_name, input_shape, train_mode, use_polygons):
model = model.train() if train_mode else model.eval()
assert isinstance(model, torch.nn.Module)
input_tensor = torch.rand((batch_size, *input_shape))
- input_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool)
+ input_masks = torch.zeros((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool)
class_names = model.class_names
@@ -130,7 +130,7 @@ def test_models_onnx_export(arch_name, input_shape):
batch_size = 2
model = layout.__dict__[arch_name](pretrained=True, exportable=True).eval()
dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
- dummy_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool)
+ dummy_masks = torch.zeros((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool)
pt = model(dummy_input, dummy_masks)
pt_logits = pt["logits"].detach().cpu().numpy()
pt_boxes = pt["pred_boxes"].detach().cpu().numpy()
diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py
index 2a41a712b1..353dc87625 100644
--- a/tests/pytorch/test_transforms_pt.py
+++ b/tests/pytorch/test_transforms_pt.py
@@ -133,6 +133,21 @@ def test_resize():
assert mask.shape == (64, 64)
assert mask.dtype == torch.bool
+ # Test with included mask in input sample
+ input_t = Sample(
+ image=torch.ones((3, 32, 64), dtype=torch.float32),
+ mask=torch.zeros((32, 64), dtype=torch.bool),
+ target=target_boxes,
+ )
+ data = transfo(input_t)
+ out, mask, new_target = data.image, data.mask, data.target
+
+ assert out.shape[-2:] == (64, 64)
+ assert new_target.shape == target_boxes.shape
+ assert np.all((0 <= new_target) & (new_target <= 1))
+ assert mask.shape == (64, 64)
+ assert mask.dtype == torch.bool
+
# Test with invalid target shape
input_t = torch.ones((3, 32, 64), dtype=torch.float32)
target = np.ones((2, 5)) # Invalid shape
@@ -312,6 +327,18 @@ def test_random_rotate():
assert r_targets["boxes"].shape == (0, 4)
assert r_targets["polygons"].shape == (0, 4, 2)
+ # Test with mask in input sample
+ input_m = torch.zeros((50, 50), dtype=torch.bool)
+ data = rotator(Sample(image=input_t, mask=input_m, target=boxes))
+ r_img, r_mask = data.image, data.mask
+ assert r_img.ndim == input_t.ndim
+ assert r_mask.ndim - 1 == input_m.ndim # Mask should be 2D
+
+ # Test without target
+ data = rotator(Sample(image=input_t))
+ r_img = data.image
+ assert r_img.ndim == input_t.ndim
+
# FP16 (only on GPU)
if torch.cuda.is_available():
input_t = torch.ones((3, 50, 50), dtype=torch.float16).cuda()
@@ -354,8 +381,8 @@ def test_crop_detection():
@pytest.mark.parametrize(
"target",
[
- np.array([[15, 20, 35, 30]]), # box
- np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), # polygon
+ np.array([[15, 20, 35, 30]], dtype=np.float32), # box
+ np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]], dtype=np.float32), # polygon
],
)
def test_random_crop(target):
@@ -363,23 +390,28 @@ def test_random_crop(target):
assert repr(cropper) == "RandomCrop(scale=(0.5, 1.0), ratio=(0.75, 1.33))"
input_t = torch.ones((3, 50, 50), dtype=torch.float32)
- sample = cropper(Sample(image=input_t, target=target))
- img, target = sample.image, sample.target
+ original_target = target.copy()
+
+ sample = cropper(Sample(image=input_t, target=original_target.copy()))
+ img, cropped_target = sample.image, sample.target
+
# Check the scale
assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2]
# Check aspect ratio
assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6
# Check the target
- assert np.all(target >= 0)
- if target.ndim == 2:
- assert np.all(target[:, [0, 2]] <= img.shape[-1]) and np.all(target[:, [1, 3]] <= img.shape[-2])
+ assert np.all(cropped_target >= 0)
+ if cropped_target.ndim == 2:
+ assert np.all(cropped_target[:, [0, 2]] <= img.shape[-1])
+ assert np.all(cropped_target[:, [1, 3]] <= img.shape[-2])
else:
- assert np.all(target[..., 0] <= img.shape[-1]) and np.all(target[..., 1] <= img.shape[-2])
+ assert np.all(cropped_target[..., 0] <= img.shape[-1])
+ assert np.all(cropped_target[..., 1] <= img.shape[-2])
# Test dict targets
dict_target = {
- "boxes": np.array([[15, 20, 35, 30]]),
- "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]),
+ "boxes": np.array([[15, 20, 35, 30]], dtype=np.float32),
+ "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]], dtype=np.float32),
}
sample = cropper(Sample(image=input_t, target=dict_target))
img, cropped_targets = sample.image, sample.target
@@ -401,6 +433,22 @@ def test_random_crop(target):
assert np.all(cropped_targets["polygons"][..., 0] <= img.shape[-1])
assert np.all(cropped_targets["polygons"][..., 1] <= img.shape[-2])
+ # Test with mask in input sample
+ input_m = torch.ones((50, 50), dtype=torch.bool)
+ sample = cropper(Sample(image=input_t, mask=input_m, target=original_target.copy()))
+ img, mask, cropped_target = sample.image, sample.mask, sample.target
+
+ assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2]
+ assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6
+ assert mask.shape == img.shape[-2:]
+ assert mask.dtype == torch.bool
+
+ # Test without target
+ sample = cropper(Sample(image=input_t))
+ img = sample.image
+ assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2]
+ assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6
+
@pytest.mark.parametrize(
"input_dtype, input_size",
@@ -450,6 +498,7 @@ def test_gaussian_noise(input_dtype, input_shape):
assert torch.all(transformed <= 255)
else:
assert torch.all(transformed <= 1.0)
+ assert repr(transform) == "GaussianNoise(mean=0.0, std=1.0)"
@pytest.mark.parametrize(