diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index 45d57a326d..57839f5646 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -9,7 +9,6 @@ import logging import subprocess import tempfile -import textwrap from pathlib import Path from typing import Any @@ -101,61 +100,62 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # raise ValueError("task must be one of classification, detection, recognition, layout") # default readme - readme = textwrap.dedent( - f"""--- - language: en - tags: - - ocr - - pytorch - - doctr - - {task} - --- + readme = f"""--- +language: en +tags: +- ocr +- pytorch +- doctr +- {task} +--- -

- -

+

+ +

- **Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch** +**Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch** - ## Task: {task} +## Task: {task} - https://github.com/mindee/doctr +https://github.com/mindee/doctr - ### Example usage: +### Example usage: - ```python - >>> from doctr.io import DocumentFile - >>> from doctr.models import ocr_predictor, from_hub +```python +>>> from doctr.io import DocumentFile +>>> from doctr.models import ocr_predictor, from_hub - >>> img = DocumentFile.from_images(['']) - >>> # Load your model from the hub - >>> model = from_hub('mindee/my-model') +>>> img = DocumentFile.from_images(['']) +>>> # Load your model from the hub +>>> model = from_hub('mindee/my-model') - >>> # Pass it to the predictor - >>> # If your model is a recognition model: - >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large', - >>> reco_arch=model, - >>> pretrained=True) +>>> # Pass it to the predictor +>>> # If your model is a recognition model: +>>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large', +>>> reco_arch=model, +>>> pretrained=True) - >>> # If your model is a detection model: - >>> predictor = ocr_predictor(det_arch=model, - >>> reco_arch='crnn_mobilenet_v3_small', - >>> pretrained=True) +>>> # If your model is a detection model: +>>> predictor = ocr_predictor(det_arch=model, +>>> reco_arch='crnn_mobilenet_v3_small', +>>> pretrained=True) - >>> # Get your predictions - >>> res = predictor(img) - ``` - """ - ) +>>> # Get your predictions +>>> res = predictor(img) +``` +""" # add run configuration to readme if available if run_config is not None: arch = run_config.arch - readme += textwrap.dedent( - f"""### Run Configuration - \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}""" - ) + readme += f""" +### Run Configuration + +```json +{json.dumps(vars(run_config), indent=2, ensure_ascii=False)} +``` +""" if arch not in AVAILABLE_ARCHS[task]: raise 
ValueError( diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 2aa59e7d18..b15ca54099 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -143,23 +143,20 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int results: list[tuple[list[int], np.ndarray, list[float]]] = [] for b in range(boxes.shape[0]): - # Convert logits to probabilities and get scores and labels - exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) - prob = exp / exp.sum(axis=-1, keepdims=True) + # Sigmoid scores (the model is trained with a sigmoid-based (IA-BCE) loss without + # a background class) + prob = 1.0 / (1.0 + np.exp(-logits[b])) # (num_queries, num_classes) + num_classes = prob.shape[-1] - prob_fg = prob[:, :-1] # exclude background - scores = prob_fg.max(axis=-1) - labels = prob_fg.argmax(axis=-1) + # Keep only the topk (query, class) pairs before NMS + flat_prob = prob.reshape(-1) + topk = min(self.topk, flat_prob.size) if self.topk is not None else flat_prob.size + topk_idxs = np.argsort(flat_prob)[::-1][:topk] - # Keep only topk predictions before NMS - if self.topk is not None and len(scores) > self.topk: - idxs = np.argsort(scores)[::-1][: self.topk] - else: - idxs = np.arange(len(scores)) - - scores_b = scores[idxs] - labels_b = labels[idxs] - bboxes = boxes[b][idxs] + scores_b = flat_prob[topk_idxs] + labels_b = topk_idxs % num_classes + query_idxs = topk_idxs // num_classes + bboxes = boxes[b][query_idxs] mask = scores_b > self.score_thresh @@ -275,11 +272,9 @@ def to_quad(box: np.ndarray): if box.shape == (4,): x1, y1, x2, y2 = box return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32) - if box.shape == (8,): - return box.reshape(4, 2) if box.shape == (4, 2): return box.astype(np.float32) - raise ValueError(f"Unsupported box shape: {box.shape}") + raise ValueError(f"Unsupported box shape: {box.shape}") # pragma: no cover for sample 
in target: boxes_all = [] @@ -287,7 +282,7 @@ def to_quad(box: np.ndarray): for class_name, boxes in sample.items(): if class_name not in class_to_id: - raise ValueError(f"Unknown class name: {class_name}") + raise ValueError(f"Unknown class name: {class_name}") # pragma: no cover cls_id = class_to_id[class_name] boxes = np.asarray(boxes) @@ -307,7 +302,7 @@ def to_quad(box: np.ndarray): labels_all.append(cls_id) targets.append({ - "boxes": np.asarray(boxes_all, dtype=np.float32), + "boxes": np.asarray(boxes_all, dtype=np.float32) if boxes_all else np.zeros((0, 6), dtype=np.float32), "labels": np.asarray(labels_all, dtype=np.int64), }) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 08f9ea9fb4..ed970f9aa9 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -12,7 +12,55 @@ from doctr.models.modules import ChannelLayerNorm from doctr.models.utils import conv_sequence_pt -__all__ = ["MultiScaleProjector", "C2fBottleneck", "LWDETRHead", "LWDETRDecoder", "LWDETRMultiscaleDeformableAttention"] +__all__ = [ + "MultiScaleProjector", + "C2fBottleneck", + "LWDETRHead", + "LWDETRDecoder", + "LWDETRMultiscaleDeformableAttention", + "refine_obb_boxes", +] + + +def refine_obb_boxes(reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + """Refine oriented bounding boxes by applying predicted deltas to reference points. + Both reference points and deltas are in the format (cx, cy, w, h, sinθ, cosθ). 
+ + cx' = cx + delta_cx * w + cy' = cy + delta_cy * h + w' = w * exp(delta_w) + h' = h * exp(delta_h) + (sinθ', cosθ') = rotation composition of (sinθ, cosθ) with the normalized delta rotation + + Args: + reference_points: (..., 6) tensor containing the reference points + deltas: (..., 6) tensor containing the predicted deltas + + Returns: + refined_boxes: (..., 6) tensor containing the refined bounding boxes + """ + reference_points = reference_points.to(deltas.device) + # center + cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + # size: clamp deltas to prevent exp() from overflowing during early training. + # NOTE: the upper bound must allow the encoder proposals (w = h = 0.05) to reach full-page boxes + # in a single refinement step: exp(3.5) * 0.05 ~= 1.66, so boxes up to (and beyond) the full canvas + # remain reachable. A tighter bound (e.g. 2.0 -> 0.05 * exp(2) ~= 0.37) silently caps the size of the + # encoder proposals and creates an irreducible loss floor for large layout items (tables, pictures). 
+ wh = torch.clamp(deltas[..., 2:4], min=-5.0, max=3.5).exp() * reference_points[..., 2:4] + # rotation (eps avoids division-by-zero NaN creation) + delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) + sin_delta = delta_rot[..., 0:1] + cos_delta = delta_rot[..., 1:2] + sin_ref = reference_points[..., 4:5] + cos_ref = reference_points[..., 5:6] + + # compose rotations + sin_new = sin_ref * cos_delta + cos_ref * sin_delta + cos_new = cos_ref * cos_delta - sin_ref * sin_delta + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) + + return torch.cat((cxcy, wh, rot), dim=-1) class LWDETRHead(nn.Module): @@ -109,7 +157,7 @@ def forward( # Crash prevention: ensure seq_len is perfectly divisible assert seq_len % self.group_detr == 0, ( f"Seq len {seq_len} must be divisible by group_detr {self.group_detr}" - ) # noqa: E501 + ) hidden_states_original = hidden_states if position_embeddings is not None: @@ -249,8 +297,8 @@ def forward( value = self.value_proj(encoder_hidden_states) if attention_mask is not None: - # we invert the attention_mask - value = value.masked_fill(~attention_mask[..., None], float(0)) + # attention_mask contains True on padded positions -> zero-out the padded values + value = value.masked_fill(attention_mask[..., None], float(0)) value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(hidden_states).view( batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 @@ -487,38 +535,39 @@ def __init__( ) def get_reference( - self, reference_points: torch.Tensor, valid_ratios: torch.Tensor + self, + reference_points: torch.Tensor, + num_levels: int, ) -> tuple[torch.Tensor, torch.Tensor]: """This function computes the reference point inputs and positional embeddings for the decoder layers. 
Args: reference_points: (batch_size, num_queries, 6) tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ) - valid_ratios: (batch_size, num_levels, 2) - tensor containing the valid ratios for each level of the input feature maps + num_levels: number of feature levels used in the decoder (1 for LW-DETR small and medium) Returns: - reference_points_inputs: (batch_size, num_queries, 1, num_levels, 4) - tensor containing the reference point inputs for the decoder layers, - which are the normalized center coordinates, - width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps + reference_points_inputs: (batch_size, num_queries, num_levels, 6) + tensor containing the reference point inputs for the decoder layers query_pos: (batch_size, num_queries, d_model) tensor containing the positional embeddings for the decoder layers, which are computed from the reference points using sine and cosine functions and a linear projection """ - obj_center = reference_points[..., :4] - spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] - # Extract angles - angle = reference_points[..., 4:6] # (sin, cos) - angle_expanded = angle[:, :, None] + ref_xywh = reference_points[..., :4] + angle = reference_points[..., 4:6] + + spatial_inputs = ref_xywh[:, :, None].expand(-1, -1, num_levels, -1) + angle_expanded = angle[:, :, None].expand(-1, -1, num_levels, -1) + reference_points_inputs = torch.cat([spatial_inputs, angle_expanded], dim=-1) - # DETR positional encoding - query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) + # generate sine positional embeddings from the reference points and + # project them to get the final query positional embeddings + query_sine_embed = gen_sine_position_embeddings(ref_xywh, self.d_model) base_query_pos = self.ref_point_head(query_sine_embed) - # Angle embedding + # Angle embedding: use the same sine/cosine encoding 
scheme as for the spatial coordinates, + # but with a separate linear projection. sin_t = angle[..., 0:1] cos_t = angle[..., 1:2] - angle_feat = torch.cat( [ sin_t, @@ -528,51 +577,45 @@ def get_reference( ], dim=-1, ) - angle_emb = self.angle_proj(angle_feat) # Combine query_pos = base_query_pos + angle_emb return reference_points_inputs, query_pos - def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - reference_points = reference_points.to(deltas.device) - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - - # Clamp deltas to prevent exp() from shooting to Infinity during early training - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] - - # Add eps=1e-6 to avoid division-by-zero NaN creation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta - - # Add eps=1e-6 here too - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - - return torch.cat((cxcy, wh, rot), dim=-1) - def forward( self, inputs_embeds: torch.Tensor | None, reference_points: torch.Tensor, - spatial_shapes_list: torch.Tensor, - valid_ratios: torch.Tensor, + spatial_shapes_list: list[tuple[int, int]], encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor | None = None, ): + """Forward pass of the decoder with iterative (per-layer) box refinement. + + Returns: + last_hidden_state: (B, Q, d_model) normalized hidden states of the last decoder layer + intermediate: (num_layers, B, Q, d_model) normalized hidden states of every decoder layer + intermediate_reference_points: (num_layers, B, Q, 6) the reference points used as INPUT to each + decoder layer. 
Box predictions for layer i must be computed as + `refine_obb_boxes(intermediate_reference_points[i], bbox_embed(intermediate[i]))`. + Following LW-DETR, the refined reference points are detached before being fed to the next + layer, while the undetached versions are kept in the returned stack so that gradients can + flow one extra refinement step when computing the auxiliary losses. + """ intermediate: list[torch.Tensor] = [] + # reference points fed to each decoder layer (input of layer i at index i) intermediate_reference_points: list[torch.Tensor] = [reference_points] if inputs_embeds is not None: hidden_states = inputs_embeds - reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios) + num_levels = len(spatial_shapes_list) + + reference_points_inputs, query_pos = self.get_reference( + reference_points, + num_levels=num_levels, + ) for lid, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( @@ -585,25 +628,23 @@ def forward( ) hidden_states_norm = self.layernorm(hidden_states) + intermediate.append(hidden_states_norm) # iterative refinement - if self.bbox_embed is not None: + if self.bbox_embed is not None and lid < len(self.layers) - 1: delta = self.bbox_embed(hidden_states_norm) + new_reference_points = refine_obb_boxes(reference_points, delta) - reference_points = self.refine_boxes( - reference_points.squeeze(2), - delta, - ) - - intermediate_reference_points.append(reference_points) + # keep the undetached version for the auxiliary losses ("look forward" supervision) + intermediate_reference_points.append(new_reference_points) + # the next layer consumes detached reference points + reference_points = new_reference_points.detach() reference_points_inputs, query_pos = self.get_reference( reference_points, - valid_ratios, + num_levels=num_levels, ) - intermediate.append(hidden_states_norm) - intermediate_stack = torch.stack(intermediate) last_hidden_state = intermediate_stack[-1] diff --git 
a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 848d44280e..96f2edd205 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -10,6 +10,7 @@ import numpy as np import torch +from scipy.optimize import linear_sum_assignment from torch import nn from torch.nn import functional as F @@ -17,7 +18,12 @@ from ...utils import load_pretrained_params from .base import _LWDETR, LWDETRPostProcessor -from .layers import LWDETRDecoder, LWDETRHead, LWDETRMultiscaleDeformableAttention, MultiScaleProjector +from .layers import ( + LWDETRDecoder, + LWDETRHead, + MultiScaleProjector, + refine_obb_boxes, +) __all__ = ["LWDETR", "lw_detr_s", "lw_detr_m"] @@ -68,6 +74,82 @@ } +def _obb_covariance_components(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the components (a, b, c) of the Gaussian covariance matrix [[a, c], [c, b]] associated with + oriented boxes in (cx, cy, w, h, sinθ, cosθ) format, following ProbIoU (https://arxiv.org/abs/2106.06072). + + Args: + boxes: (..., 6) tensor of oriented boxes + + Returns: + a, b, c: (...,) tensors of covariance components + """ + w = boxes[..., 2].clamp(min=1e-6) + h = boxes[..., 3].clamp(min=1e-6) + rot = F.normalize(boxes[..., 4:6], dim=-1, eps=1e-6) + sin, cos = rot[..., 0], rot[..., 1] + + var_w = w.pow(2) / 12.0 + var_h = h.pow(2) / 12.0 + cos2 = cos.pow(2) + sin2 = sin.pow(2) + + a = var_w * cos2 + var_h * sin2 + b = var_w * sin2 + var_h * cos2 + c = (var_w - var_h) * cos * sin + return a, b, c + + +def _probiou( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + pairwise: bool = False, + scale: float = 1.0, + eps: float = 1e-7, +) -> torch.Tensor: + """Compute the probabilistic IoU between oriented boxes in (cx, cy, w, h, sinθ, cosθ) format, + as described in `"Gaussian Bounding Boxes and Probabilistic IoU" <https://arxiv.org/abs/2106.06072>`_.
+ + Args: + boxes1: (N, 6) tensor of oriented boxes + boxes2: (M, 6) tensor of oriented boxes (M = N if `pairwise` is False) + pairwise: if True, return the (N, M) matrix of IoUs, otherwise the element-wise (N,) IoUs + scale: factor applied to the spatial components (cx, cy, w, h) before computing the IoU. + ProbIoU is mathematically scale-invariant, but the stabilizing `eps` terms are not: in + normalized [0, 1] coordinates the covariance terms of small boxes (e.g. checkboxes) fall + below `eps` and the result degenerates (non-overlapping small boxes can score ~1). + eps: small value for numerical stability + + Returns: + probiou: (N,) or (N, M) tensor of IoU-like similarities in [0, 1] + """ + if scale != 1.0: + boxes1 = torch.cat([boxes1[..., :4] * scale, boxes1[..., 4:]], dim=-1) + boxes2 = torch.cat([boxes2[..., :4] * scale, boxes2[..., 4:]], dim=-1) + + x1, y1 = boxes1[..., 0], boxes1[..., 1] + x2, y2 = boxes2[..., 0], boxes2[..., 1] + a1, b1, c1 = _obb_covariance_components(boxes1) + a2, b2, c2 = _obb_covariance_components(boxes2) + + if pairwise: + x1, y1, a1, b1, c1 = (t.unsqueeze(-1) for t in (x1, y1, a1, b1, c1)) # (N, 1) + x2, y2, a2, b2, c2 = (t.unsqueeze(-2) for t in (x2, y2, a2, b2, c2)) # (1, M) + + denom = (a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps + t1 = ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / denom * 0.25 + t2 = ((c1 + c2) * (x2 - x1) * (y1 - y2)) / denom * 0.5 + t3 = ( + ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)).clamp(min=eps) + / (4 * ((a1 * b1 - c1.pow(2)).clamp(min=eps) * (a2 * b2 - c2.pow(2)).clamp(min=eps)).sqrt() + eps) + + eps + ).log() * 0.5 + + bhattacharyya = (t1 + t2 + t3).clamp(min=eps, max=100.0) + hellinger = (1.0 - (-bhattacharyya).exp() + eps).sqrt() + return 1.0 - hellinger + + class LWDETRBackbone(nn.Module): """Backbone of LW-DETR, based on a ViT Det architecture. The backbone is used as feature extractor. 
@@ -102,6 +184,45 @@ def __init__( num_blocks=num_blocks, ) + def _resize_padding_mask(self, mask: torch.Tensor, size: tuple[int, int]) -> torch.Tensor: + """Resize padding mask to feature-map size + + Args: + mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + size: the target size (H', W') for the resized mask + + Returns: + resized_mask: a binary mask of shape [batch_size x H' x W'], containing 1 on padded pixels + """ + if mask.dtype != torch.bool: + mask = mask.bool() + + valid = (~mask).float().unsqueeze(1) # True/1 = valid pixels + + # Data-dependent sanity checks + if self.training: # pragma: no cover + if (valid.flatten(1).sum(dim=1) == 0).any(): + bad = torch.where(valid.flatten(1).sum(dim=1) == 0)[0].tolist() + raise RuntimeError(f"Input masks are fully padded before resizing: {bad}") + + # Use max pooling to resize the valid mask: + # a pixel in the resized mask is valid if at least one pixel + # in the corresponding window in the input mask is valid + h_in, w_in = int(mask.shape[-2]), int(mask.shape[-1]) + h_out, w_out = int(size[0]), int(size[1]) + kh, kw = h_in // h_out, w_in // w_out + valid_resized = F.max_pool2d(valid, kernel_size=(kh, kw), stride=(kh, kw)) > 0 + + resized_mask = ~valid_resized.squeeze(1) + + # Data-dependent sanity checks + if self.training: # pragma: no cover + if resized_mask.flatten(1).all(dim=1).any(): + bad = torch.where(resized_mask.flatten(1).all(dim=1))[0].tolist() + raise RuntimeError(f"Feature masks became fully padded after resizing: {bad}") + + return resized_mask + def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> list[tuple[torch.Tensor, torch.Tensor]]: """Forward pass of the backbone. 
@@ -120,10 +241,7 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> list[tup # [(B, C, H, W)] if mask is None: # pragma: no cover mask = torch.zeros((x.shape[0], x.shape[2], x.shape[3]), dtype=torch.bool, device=x.device) - return [ - (feat, F.interpolate(mask.unsqueeze(1).float(), size=feat.shape[-2:], mode="nearest").squeeze(1).bool()) - for feat in feats - ] + return [(feat, self._resize_padding_mask(mask, feat.shape[2:])) for feat in feats] class LWDETR(nn.Module, _LWDETR): @@ -153,11 +271,11 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.3, + score_thresh: float = 0.25, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 130, - group_detr: int = 1, + num_queries: int = 195, + group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, ca_num_heads: int = 16, @@ -171,7 +289,8 @@ def __init__( super().__init__() self.class_names: list[str] = class_names - self.num_classes = len(self.class_names) + 1 # +1 for background class + # No background class: the model is trained with a sigmoid-based (IA-BCE) loss + self.num_classes = len(self.class_names) self.cfg = cfg self.exportable = exportable self.assume_straight_pages = assume_straight_pages @@ -183,10 +302,6 @@ def __init__( self.d_model = d_model self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) - # Initialize angle to (sin=0, cos=1) - with torch.no_grad(): - self.reference_point_embed.weight[:, 4] = 0.0 # sinθ - self.reference_point_embed.weight[:, 5] = 1.0 # cosθ self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) @@ -242,44 +357,25 @@ def __init__( nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=0.02) - elif isinstance(m, LWDETRMultiscaleDeformableAttention): - nn.init.constant_(m.sampling_offsets.weight, 0.0) - - thetas = torch.arange(m.n_heads, dtype=torch.float32) * (2.0 * math.pi / m.n_heads) - grid_init = 
torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = ( - (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) - .view(m.n_heads, 1, 1, 2) - .repeat(1, m.n_levels, m.n_points, 1) - ) - - for i in range(m.n_points): - grid_init[:, :, i, :] *= i + 1 - - with torch.no_grad(): - m.sampling_offsets.bias.copy_(grid_init.view(-1)) - - nn.init.constant_(m.attention_weights.weight, 0.0) - nn.init.constant_(m.attention_weights.bias, 0.0) - - nn.init.xavier_uniform_(m.value_proj.weight) - nn.init.zeros_(m.value_proj.bias) - - nn.init.xavier_uniform_(m.output_proj.weight) - nn.init.zeros_(m.output_proj.bias) if isinstance(m, nn.Linear) and m.out_features == self.num_classes: prior_prob = 0.01 bias_value = -math.log((1 - prior_prob) / prior_prob) if m.bias is not None: nn.init.constant_(m.bias, bias_value) - if isinstance(m, LWDETRHead): - last = m.layers[-1] - if isinstance(last, nn.Linear): - nn.init.zeros_(last.weight) - nn.init.zeros_(last.bias) - if last.bias.shape[0] == 6: - nn.init.constant_(last.bias[5], 1.0) + + # Initialize the iterative refinement heads to predict zero deltas (i.e. 
identity refinement) + # at the start of training, to stabilize training in the early stages when the encoder proposals are still noisy + with torch.no_grad(): + for head in [self.bbox_embed, *self.enc_out_bbox_embed]: + last = head.layers[-1] + last.weight.zero_() + last.bias.zero_() + last.bias[5] = 1.0 # cosθ of the rotation delta -> identity rotation + + # The reference point embedding acts as a learned delta composed with the encoder proposals + self.reference_point_embed.weight.zero_() + self.reference_point_embed.weight[:, 5] = 1.0 # cosθ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -290,62 +386,6 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """ load_pretrained_params(self, path_or_url, **kwargs) - def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - """Refine bounding boxes by applying the predicted deltas to the reference points. - The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. 
- The refined boxes are computed as follows: - - cx' = cx + delta_cx * w - cy' = cy + delta_cy * h - w' = w * exp(delta_w) - h' = h * exp(delta_h) - sinθ' = sinθ * cosΔ + cosθ * sinΔ - cosθ' = cosθ * cosΔ - sinθ * sinΔ - - Args: - reference_points: (N, S, 6) tensor containing the reference points - deltas: (N, S, 6) tensor containing the predicted deltas - - Returns: - refined_boxes: (N, S, 6) tensor containing the refined bounding boxes - """ - reference_points = reference_points.to(deltas.device) - # center - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - # size - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] - # rotation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - - # compose rotations - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - - return torch.cat((cxcy, wh, rot), dim=-1) - - def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: - """Get the valid ratio of all feature maps. 
- - Args: - mask: (N, H, W) binary tensor containing 1 on padded pixels - dtype: the desired data type of the output tensor - - Returns: - valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch - """ - _, height, width = mask.shape - valid_height = torch.sum(~mask[:, :, 0], 1) - valid_width = torch.sum(~mask[:, 0, :], 1) - valid_ratio_height = valid_height.to(dtype) / height - valid_ratio_width = valid_width.to(dtype) / width - valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) - return valid_ratio - def gen_encoder_output_proposals( self, enc_output: torch.Tensor, padding_mask: torch.Tensor, spatial_shapes: list[tuple[int, int]] ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -367,56 +407,48 @@ def gen_encoder_output_proposals( batch_size = enc_output.shape[0] proposals = [] _cur = 0 - for level, (height, width) in enumerate(spatial_shapes): - mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1) - valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1) - valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + for level, (height, width) in enumerate(spatial_shapes): grid_y, grid_x = torch.meshgrid( - torch.linspace( - 0, - height - 1, - height, - dtype=enc_output.dtype, - device=enc_output.device, - ), - torch.linspace( - 0, - width - 1, - width, - dtype=enc_output.dtype, - device=enc_output.device, - ), + torch.arange(height, dtype=enc_output.dtype, device=enc_output.device), + torch.arange(width, dtype=enc_output.dtype, device=enc_output.device), indexing="ij", ) - grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) - scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) - grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale + grid = torch.stack([grid_x, grid_y], dim=-1) + grid = grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + + scale = torch.tensor( + 
[width, height], + dtype=enc_output.dtype, + device=enc_output.device, + ).view(1, 1, 1, 2) + + # Canvas-normalized center coordinates + grid = (grid + 0.5) / scale width_height = torch.ones_like(grid) * 0.05 * (2.0**level) - # add default rotation (sin=0, cos=1) sin = torch.zeros_like(grid[..., :1]) cos = torch.ones_like(grid[..., :1]) - proposal = torch.cat((grid, width_height, sin, cos), -1).view(batch_size, -1, 6) + proposal = torch.cat((grid, width_height, sin, cos), dim=-1).view(batch_size, -1, 6) proposals.append(proposal) _cur += height * width - output_proposals = torch.cat(proposals, 1) - spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) - output_proposals_valid = spatial_valid - invalid_mask = padding_mask | ~output_proposals_valid.squeeze(-1) - invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid - output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) + output_proposals = torch.cat(proposals, dim=1) + + spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all( + dim=-1, keepdim=True + ) + invalid_mask = padding_mask.unsqueeze(-1) | ~spatial_valid + + output_proposals = output_proposals.masked_fill(invalid_mask, 0.0) + object_query = enc_output.masked_fill(invalid_mask, 0.0) - # assign each pixel as an object query - object_query = enc_output - object_query = object_query.masked_fill(invalid_mask, float(0)) return object_query, output_proposals, invalid_mask def forward( self, input: torch.Tensor, - masks: torch.Tensor, + masks: torch.Tensor | None = None, target: list[dict[str, np.ndarray]] | None = None, return_model_output: bool = False, return_preds: bool = False, @@ -455,7 +487,6 @@ def forward( mask_flatten_list.append(mask) source_flatten = torch.cat(source_flatten_list, 1) mask_flatten = torch.cat(mask_flatten_list, 1) - valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in 
feats_masks], 1) tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1) @@ -469,7 +500,7 @@ def forward( topk_coords_logits_list: list[torch.Tensor] = [] - # encoder predictions for auxiliary losses + # encoder predictions on the selected top-k proposals, kept undetached for the auxiliary loss all_group_enc_logits: list[torch.Tensor] = [] all_group_enc_coords: list[torch.Tensor] = [] @@ -478,14 +509,11 @@ def forward( group_object_query = self.enc_output_norm[group_id](group_object_query) group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) - all_group_enc_logits.append(group_enc_outputs_class) group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) - group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) - - all_group_enc_coords.append(group_enc_outputs_coord) + group_enc_outputs_coord = refine_obb_boxes(output_proposals, group_delta_bbox) group_topk_proposals = torch.topk(group_enc_outputs_class_masked.max(-1)[0], topk, dim=1)[1] @@ -494,23 +522,33 @@ def forward( 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) - group_topk_coords_logits = group_topk_coords_logits_undetach - topk_coords_logits_list.append(group_topk_coords_logits) + # the auxiliary loss supervises only the selected proposals, + # so gather the matching class logits as well + group_topk_logits_undetach = torch.gather( + group_enc_outputs_class, + 1, + group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.num_classes), + ) + all_group_enc_logits.append(group_topk_logits_undetach) + all_group_enc_coords.append(group_topk_coords_logits_undetach) + + # the decoder consumes detached proposals as initial reference points + topk_coords_logits_list.append(group_topk_coords_logits_undetach.detach()) topk_coords_logits = torch.cat(topk_coords_logits_list, 
1) - reference_points = self.refine_bboxes(topk_coords_logits, reference_points) + reference_points = refine_obb_boxes(topk_coords_logits, reference_points) last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, reference_points=reference_points, spatial_shapes_list=spatial_shapes_list, - valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, + encoder_attention_mask=mask_flatten, ) logits = self.class_embed(last_hidden_states) pred_boxes_delta = self.bbox_embed(last_hidden_states) - pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + pred_boxes = refine_obb_boxes(intermediate_reference_points[-1], pred_boxes_delta) out: dict[str, Any] = {} @@ -534,223 +572,163 @@ def _postprocess(logits, boxes): # Build target processed_targets = self.build_target(target, self.class_names) - # Main loss from final decoder layer (group DETR) + # ProbIoU is computed in pixel coordinates + box_scale = float(max(input.shape[-2], input.shape[-1])) + + # Main loss from final decoder layer (group DETR: each group is matched independently) split_logits = logits.chunk(group_detr, dim=1) split_boxes = pred_boxes.chunk(group_detr, dim=1) main_loss: float | torch.Tensor = 0.0 for g_logits, g_boxes in zip(split_logits, split_boxes): - main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) + main_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) loss = main_loss / group_detr # Auxiliary losses from intermediate decoder layers + # (`intermediate_reference_points[i]` is the reference INPUT to decoder layer i) for i in range(intermediate.shape[0] - 1): aux_logits = self.class_embed(intermediate[i]) aux_boxes_delta = self.bbox_embed(intermediate[i]) - aux_boxes = self.refine_bboxes(intermediate_reference_points[i], aux_boxes_delta) + aux_boxes = refine_obb_boxes(intermediate_reference_points[i], aux_boxes_delta) split_aux_logits = aux_logits.chunk(group_detr, 
dim=1) split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) aux_loss: float | torch.Tensor = 0.0 for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): - aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss += 0.5 * (aux_loss / group_detr) + aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) + loss += aux_loss / group_detr - # Auxiliary losses for encoder proposals + # Auxiliary losses for the selected encoder proposals enc_loss: float | torch.Tensor = 0.0 for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): - enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) - loss += 0.1 * (enc_loss / group_detr) + enc_loss += self.compute_loss(group_logits, group_coords, processed_targets, box_scale=box_scale) + loss += enc_loss / group_detr out["loss"] = loss return out def compute_loss( - self, logits: torch.Tensor, pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]] + self, + logits: torch.Tensor, + pred_boxes: torch.Tensor, + targets: list[dict[str, np.ndarray]], + cls_loss_weight: float = 1.0, + l1_loss_weight: float = 5.0, + iou_loss_weight: float = 2.0, + box_scale: float = 1024.0, ) -> torch.Tensor: - """ - Compute the loss for LW-DETR. The loss consists of three components: - classification loss, box regression loss, and rotation loss. - The classification loss is a cross-entropy loss between the predicted class logits and the target classes. - The box regression loss is a Smooth L1 loss between the predicted boxes and the target boxes, - computed only on the positive samples. - The rotation loss is computed as 1 - cosine similarity between the predicted rotation and the target rotation, - averaged over the positive samples. 
- The positive samples are determined using a SimOTA-like assignment strategy, where for each ground truth box, - we select the top-k queries with the lowest cost - (combination of classification cost, box regression cost, and rotation cost). + """Compute the LW-DETR loss for oriented bounding boxes. + + Predictions are matched one-to-one to the ground truth boxes with Hungarian matching, + using a cost combining a focal-style classification cost, an L1 box cost + and a (negated) ProbIoU cost. The loss then consists of: + + - an IoU-aware binary cross-entropy (IA-BCE) classification loss, as described in the LW-DETR paper, + where the target of a matched (query, class) pair is `p**alpha * IoU**(1 - alpha)` and the rotated + ProbIoU is used as IoU measure + - an L1 regression loss on the normalized (cx, cy, w, h) of the matched pairs + - a ProbIoU loss (1 - ProbIoU) on the matched oriented boxes, computed in absolute pixel + coordinates (as in O2-RT-DETR). The box rotation is supervised solely through this term: + ProbIoU is differentiable w.r.t. the angle even for non-overlapping boxes. + + All terms are normalized by the number of ground truth boxes in the batch. 
Args: logits: (B, Q, C) tensor containing the predicted class logits for each query - pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query - targets: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding - to class names and values corresponding to lists of boxes in either polygon format (4, 2) - or bounding box format (4,) (xmin, ymin, xmax, ymax) + pred_boxes: (B, Q, 6) tensor containing the predicted boxes (cx, cy, w, h, sinθ, cosθ) for each query + targets: list of B dictionaries with keys "boxes" ((N, 6) array in OBB format) + and "labels" ((N,) array of class indices) + cls_loss_weight: weight of the classification loss + l1_loss_weight: weight of the L1 box regression loss + iou_loss_weight: weight of the ProbIoU loss + box_scale: image size used to rescale normalized boxes to pixel coordinates for ProbIoU computation Returns: loss: the computed loss value """ - - def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format to Gaussian distribution parameters - (mean and covariance). 
- """ - cxcy = boxes[..., :2] - - w = boxes[..., 2].clamp(min=1e-6) - h = boxes[..., 3].clamp(min=1e-6) - - sin = boxes[..., 4] - cos = boxes[..., 5] - - R = torch.stack( - [ - torch.stack([cos, -sin], dim=-1), - torch.stack([sin, cos], dim=-1), - ], - dim=-2, - ) - - sx = (w / 2) ** 2 - sy = (h / 2) ** 2 - - S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) - - S[..., 0, 0] = sx - S[..., 1, 1] = sy - - covariance = R @ S @ R.transpose(-1, -2) - return cxcy, covariance - - def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: - """Compute the ProbIoU loss between predicted boxes and target boxes.""" - mu1, sigma1 = rotated_boxes_to_gaussian(pred_boxes) - mu2, sigma2 = rotated_boxes_to_gaussian(tgt_boxes) - - delta = (mu1 - mu2).unsqueeze(-1) - sigma = (sigma1 + sigma2) * 0.5 - - eps = 1e-6 - eye = torch.eye(2, device=sigma.device) * eps - sigma_safe = sigma + eye - sigma1_safe = sigma1 + eye - sigma2_safe = sigma2 + eye - - sigma_inv = torch.linalg.inv(sigma_safe) - - mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) - - det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps) - det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps) - det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps) - - bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) - - probiou = torch.exp(-bhattacharyya) - return 1 - probiou - + alpha, gamma, eps = 0.25, 2.0, 1e-8 device = logits.device - B, Q, C = logits.shape + batch_size = logits.shape[0] - total_cls = torch.tensor(0.0, device=device) - total_box = torch.tensor(0.0, device=device) - total_rot = torch.tensor(0.0, device=device) + tgt_boxes_list: list[torch.Tensor] = [] + tgt_labels_list: list[torch.Tensor] = [] + for sample in targets: + tgt_boxes_list.append( + torch.as_tensor(sample["boxes"], device=device, dtype=pred_boxes.dtype).reshape(-1, 6) + ) + 
tgt_labels_list.append(torch.as_tensor(sample["labels"], device=device, dtype=torch.long).reshape(-1)) - for b in range(B): - pred_logits = logits[b] - pred_boxes_b = pred_boxes[b] + # Number of target boxes in the batch, for loss normalization + num_boxes = max(sum(int(labels.numel()) for labels in tgt_labels_list), 1) - tgt_boxes = torch.as_tensor( - targets[b]["boxes"], - device=device, - dtype=pred_boxes.dtype, - ) - tgt_cls = torch.as_tensor( - targets[b]["labels"], - device=device, - dtype=torch.long, - ) + # Hungarian matching (one-to-one), performed independently for each sample + indices: list[tuple[torch.Tensor, torch.Tensor]] = [] + with torch.no_grad(): + prob = logits.sigmoid() + for b in range(batch_size): + tgt_boxes, tgt_labels = tgt_boxes_list[b], tgt_labels_list[b] + if tgt_labels.numel() == 0: + empty = torch.empty(0, dtype=torch.long, device=device) + indices.append((empty, empty)) + continue + + out_prob = prob[b] + out_boxes = pred_boxes[b] + + # Focal-style classification cost + neg_cost = (1 - alpha) * out_prob.pow(gamma) * (-(1 - out_prob + eps).log()) + pos_cost = alpha * (1 - out_prob).pow(gamma) * (-(out_prob + eps).log()) + cost_class = pos_cost[:, tgt_labels] - neg_cost[:, tgt_labels] + + # L1 cost on normalized (cx, cy, w, h) + cost_bbox = torch.cdist(out_boxes[:, :4], tgt_boxes[:, :4], p=1) + + # Rotated IoU cost, computed in pixel coordinates + # this term also carries the angle signal for the matching + cost_iou = -_probiou(out_boxes, tgt_boxes, pairwise=True, scale=box_scale) + + cost = 2.0 * cost_class + 5.0 * cost_bbox + 2.0 * cost_iou + + query_idx, tgt_idx = linear_sum_assignment(cost.cpu().numpy()) + indices.append(( + torch.as_tensor(query_idx, dtype=torch.long, device=device), + torch.as_tensor(tgt_idx, dtype=torch.long, device=device), + )) + + # Flatten the matched pairs across the batch + batch_idx = torch.cat([torch.full_like(src, b) for b, (src, _) in enumerate(indices)]) + query_idx = torch.cat([src for (src, _) in 
indices]) + matched_tgt_boxes = torch.cat([tgt_boxes_list[b][tgt] for b, (_, tgt) in enumerate(indices)]) + matched_tgt_labels = torch.cat([tgt_labels_list[b][tgt] for b, (_, tgt) in enumerate(indices)]) + matched_pred_boxes = pred_boxes[batch_idx, query_idx] + + prob = logits.sigmoid() + + # IoU-aware BCE classification loss (IA-BCE) + pos_weights = torch.zeros_like(logits) + neg_weights = prob.pow(gamma) + if len(batch_idx) > 0: + with torch.no_grad(): + ious = _probiou(matched_pred_boxes, matched_tgt_boxes, scale=box_scale).clamp(min=0.0, max=1.0) + t = prob[batch_idx, query_idx, matched_tgt_labels].pow(alpha) * ious.pow(1 - alpha) + t = t.clamp(min=0.01) + pos_weights[batch_idx, query_idx, matched_tgt_labels] = t + neg_weights[batch_idx, query_idx, matched_tgt_labels] = 1 - t - num_gt = len(tgt_cls) + cls_loss = -(pos_weights * (prob + eps).log() + neg_weights * (1 - prob + eps).log()) + loss = cls_loss_weight * cls_loss.sum() / num_boxes - pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) - tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) + if len(batch_idx) == 0: + return loss - with torch.no_grad(): - cls_prob = pred_logits.sigmoid() - alpha = 0.25 - gamma = 2.0 - - neg_cost = (1 - alpha) * (cls_prob**gamma) * (-(1 - cls_prob + 1e-8).log()) - - pos_cost = alpha * ((1 - cls_prob) ** gamma) * (-(cls_prob + 1e-8).log()) - - cost_cls = pos_cost[:, tgt_cls] - neg_cost[:, tgt_cls] - cost_l1 = torch.cdist( - pred_boxes_b[:, :4], - tgt_boxes[:, :4], - p=1, - ) - cost_rot = 1 - (pred_rot @ tgt_rot.T).abs() - total_cost = 5.0 * cost_cls + 2.0 * cost_l1 + 1.0 * cost_rot - matching_matrix = torch.zeros( - (Q, num_gt), - dtype=torch.bool, - device=device, - ) - - center_dist = torch.cdist( - pred_boxes_b[:, :2], - tgt_boxes[:, :2], - p=2, - ) - - iou_like = torch.exp(-center_dist) - dynamic_k = iou_like.sum(0).int().clamp(min=1, max=10) - - for gt_idx in range(num_gt): - _, candidate_idx = torch.topk(-total_cost[:, gt_idx], k=int(dynamic_k[gt_idx].item())) - 
matching_matrix[candidate_idx, gt_idx] = True - - # resolve duplicate matches - multiple_match_mask = matching_matrix.sum(1) > 1 - - if multiple_match_mask.any(): - duplicate_idx = multiple_match_mask.nonzero(as_tuple=False).squeeze(1) - min_cost_idx = total_cost[duplicate_idx].argmin(dim=1) - # Set all matches to False for the duplicate indices, - # then set the match with the lowest cost to True - matching_matrix[duplicate_idx] = False - matching_matrix[duplicate_idx, min_cost_idx] = True - - pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True) - - target_classes = torch.zeros((Q,), dtype=torch.long, device=device) - - # background = 0 - target_classes[pos_idx] = tgt_cls[gt_indices] - - total_cls += F.cross_entropy(pred_logits, target_classes) - - if len(pos_idx) == 0: - continue + # L1 loss on normalized (cx, cy, w, h) + l1_loss = F.l1_loss(matched_pred_boxes[:, :4], matched_tgt_boxes[:, :4], reduction="sum") / num_boxes + # ProbIoU loss on the whole oriented box (position, size and rotation), in pixel coordinates + probiou_loss = (1 - _probiou(matched_pred_boxes, matched_tgt_boxes, scale=box_scale)).sum() / num_boxes - pred_sel = pred_boxes_b[pos_idx] - tgt_sel = tgt_boxes[gt_indices] - # L1 loss on (cx, cy, w, h) - l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4]) - # ProbIoU loss on the whole box (including rotation) - probiou_loss = _probiou_loss(pred_sel, tgt_sel).mean() - total_box += 2.0 * l1_loss + 0.5 * probiou_loss - # Rotation loss - cos_sim = (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1).abs() - rot_loss = (1 - cos_sim).mean() - total_rot += 0.5 * rot_loss - # Average the loss over the batch - return (total_cls + total_box + total_rot) / B + return loss + l1_loss_weight * l1_loss + iou_loss_weight * probiou_loss def _lw_detr( @@ -772,7 +750,7 @@ def _lw_detr( False, include_top=False, input_shape=default_cfgs[arch]["input_shape"], - patch_size=kwargs.get("patch_size", (16, 16)), + patch_size=kwargs.pop("patch_size", (16, 16)), ) 
feat_extractor = LWDETRBackbone(encoder_fn=backbone) diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py index a53acf038f..5b68d45b06 100644 --- a/doctr/transforms/modules/base.py +++ b/doctr/transforms/modules/base.py @@ -272,7 +272,14 @@ def __init__(self, scale: tuple[float, float] = (0.08, 1.0), ratio: tuple[float, def extra_repr(self) -> str: return f"scale={self.scale}, ratio={self.ratio}" - def _crop_array(self, img: Any, target: np.ndarray, crop_box): + def _crop_image_only(self, img: Any, crop_box: tuple[float, float, float, float]) -> Any: + dummy_box = np.array([[0, 0, 1, 1]], dtype=np.float32) + cropped_img, _ = F.crop_detection(img, dummy_box, crop_box) + return cropped_img + + def _crop_array( + self, img: Any, target: np.ndarray, crop_box: tuple[float, float, float, float] + ) -> tuple[Any, np.ndarray]: is_polygon = target.shape[1:] == (4, 2) if is_polygon: @@ -339,17 +346,15 @@ def __call__(self, sample: Sample) -> Sample: (y + crop_h) / h, ) - r_mask = None - if mask is not None: - r_mask, _ = self._crop_array(mask, np.zeros((0, 4)), crop_box) - if target is None: - r_img, _ = self._crop_array(img, np.zeros((0, 4)), crop_box) + r_img = self._crop_image_only(img, crop_box) + r_mask = self._crop_image_only(mask, crop_box) if mask is not None else None return sample.replace(image=r_img, mask=r_mask) if isinstance(target, dict): cropped_targets = {} cropped_img = None + crop_rejected = False for cls_name, arr in target.items(): if len(arr) == 0: @@ -358,13 +363,29 @@ def __call__(self, sample: Sample) -> Sample: c_img, c_arr = self._crop_array(img, arr, crop_box) + if c_img is img: + crop_rejected = True + break + if cropped_img is None: cropped_img = c_img cropped_targets[cls_name] = c_arr - final_img = cropped_img if cropped_img is not None else img - return sample.replace(image=final_img, mask=r_mask, target=cropped_targets) + if crop_rejected or cropped_img is None: + return sample.replace( + image=img, + mask=mask, + 
target={key: value.copy() for key, value in target.items()}, + ) + + r_mask = self._crop_image_only(mask, crop_box) if mask is not None else None + return sample.replace(image=cropped_img, mask=r_mask, target=cropped_targets) c_img, c_target = self._crop_array(img, target, crop_box) + if c_img is img: + r_mask = mask + else: + r_mask = self._crop_image_only(mask, crop_box) if mask is not None else None + return sample.replace(image=c_img, mask=r_mask, target=c_target) diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index c7245f3a82..c6b7881b56 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -380,7 +380,7 @@ def __call__(self, sample: Sample) -> Sample: else: shadowed_image = random_shadow(sample.image, self.opacity_range).clip(0, 1) return sample.replace(image=shadowed_image) - except ValueError: + except ValueError: # pragma: no cover return sample def extra_repr(self) -> str: diff --git a/references/classification/train_character.py b/references/classification/train_character.py index c266bb2f5c..22a9298ecb 100644 --- a/references/classification/train_character.py +++ b/references/classification/train_character.py @@ -66,7 +66,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): targets = torch.tensor(targets) @@ -79,7 +79,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -110,7 +110,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -126,7 
+126,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -168,7 +168,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): targets = targets.cuda() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) loss = cross_entropy(out, targets) else: diff --git a/references/classification/train_orientation.py b/references/classification/train_orientation.py index 86dc4c5931..0cbbf05616 100644 --- a/references/classification/train_orientation.py +++ b/references/classification/train_orientation.py @@ -78,7 +78,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): targets = torch.tensor(targets) @@ -91,7 +91,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -122,7 +122,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -138,7 +138,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -180,7 +180,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): targets = targets.cuda() if amp: - with 
torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) loss = cross_entropy(out, targets) else: diff --git a/references/detection/evaluate.py b/references/detection/evaluate.py index 7c5fb597aa..a8674884f4 100644 --- a/references/detection/evaluate.py +++ b/references/detection/evaluate.py @@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): images = batch_transforms(images) targets = [{CLASS_NAME: t} for t in targets] if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/detection/train.py b/references/detection/train.py index 43e4ec3a56..2070b63c3e 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -63,7 +63,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): if torch.cuda.is_available(): @@ -74,7 +74,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -107,7 +107,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -120,7 +120,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -163,7 +163,7 @@ def evaluate(model, val_loader, batch_transforms, 
val_metric, args, amp=False, l images = images.cuda() images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) @@ -333,7 +333,7 @@ def main(args): [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25), T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), ]), T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), @@ -342,7 +342,7 @@ def main(args): else [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25), T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), ]), # Rotation augmentation @@ -363,7 +363,7 @@ def main(args): ) if distributed: - sampler = DistributedSampler(train_set, rank=rank, shuffle=False, drop_last=True) + sampler = DistributedSampler(train_set, rank=rank, shuffle=True, drop_last=True) else: sampler = RandomSampler(train_set) @@ -518,6 +518,8 @@ def log_at_step(train_loss=None, val_loss=None, lr=None): # Training loop for epoch in range(args.epochs): + if distributed: + sampler.set_epoch(epoch) train_loss, actual_lr = fit_one_epoch( model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, log=log_at_step, rank=rank ) diff --git a/references/layout/evaluate.py b/references/layout/evaluate.py index 36fb78cf80..aa8ea788aa 100644 --- a/references/layout/evaluate.py +++ b/references/layout/evaluate.py @@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): padding_masks = padding_masks.cuda() imgs = batch_transforms(imgs) if amp: - with torch.cuda.amp.autocast(): + with 
torch.amp.autocast("cuda"): out = model(imgs, padding_masks, targets, return_preds=True) else: out = model(imgs, padding_masks, targets, return_preds=True) diff --git a/references/layout/train.py b/references/layout/train.py index f3f0d2c117..8f9ce372b1 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -17,7 +17,14 @@ # The following import is required for DDP import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR +from torch.optim.lr_scheduler import ( + CosineAnnealingLR, + LinearLR, + MultiplicativeLR, + OneCycleLR, + PolynomialLR, + SequentialLR, +) from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort @@ -31,7 +38,7 @@ from doctr.datasets import LayoutDataset from doctr.models import layout, login_to_hub, push_to_hf_hub from doctr.utils.metrics import ObjectDetectionMetric -from utils import EarlyStopper, convert_target, plot_recorder, plot_samples +from utils import EarlyStopper, build_param_groups, convert_target, plot_recorder, plot_samples def record_lr( @@ -63,7 +70,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): imgs, padding_masks = images @@ -77,19 +84,19 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(imgs, padding_masks, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) # Update the params 
scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) optimizer.step() # Update LR scheduler.step() @@ -110,7 +117,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -125,19 +132,19 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(imgs, padding_masks, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) # Update the params scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) optimizer.step() scheduler.step() @@ -170,7 +177,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non padding_masks = padding_masks.cuda() imgs = batch_transforms(imgs) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(imgs, padding_masks, targets, return_preds=True) else: out = model(imgs, padding_masks, targets, return_preds=True) @@ -383,7 +390,7 @@ def main(args): [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25), T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, 
symmetric_pad=0.5, p=0.25), ]), T.Resize( @@ -397,7 +404,7 @@ def main(args): else [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25), T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), ]), # Rotation augmentation @@ -424,7 +431,7 @@ def main(args): ) if distributed: - sampler = DistributedSampler(train_set, rank=rank, shuffle=False, drop_last=True) + sampler = DistributedSampler(train_set, rank=rank, shuffle=True, drop_last=True) else: sampler = RandomSampler(train_set) @@ -472,24 +479,18 @@ def main(args): # construct DDP model model = DDP(model, device_ids=[rank]) + param_groups = build_param_groups( + model, + lr=args.lr, + backbone_lr=args.lr if not args.pretrained else args.lr * 0.1, + weight_decay=args.weight_decay or 1e-4, + ) + # Optimizer if args.optim == "adam": - optimizer = torch.optim.Adam( - [p for p in model.parameters() if p.requires_grad], - args.lr, - betas=(0.95, 0.999), - eps=1e-6, - weight_decay=args.weight_decay, - ) - + optimizer = torch.optim.Adam(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-8) elif args.optim == "adamw": - optimizer = torch.optim.AdamW( - [p for p in model.parameters() if p.requires_grad], - args.lr, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=args.weight_decay or 1e-4, - ) + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-8) # LR Finder if rank == 0 and args.find_lr: @@ -498,12 +499,55 @@ def main(args): return # Scheduler + total_steps = args.epochs * len(train_loader) + warmup_steps = max(1, min(2000, int(0.05 * total_steps))) + if args.sched == "cosine": - scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) + warmup = LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=warmup_steps, + ) + cosine = CosineAnnealingLR( + optimizer, + 
T_max=max(1, total_steps - warmup_steps), + eta_min=args.lr * 0.01, + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup, cosine], + milestones=[warmup_steps], + ) + elif args.sched == "onecycle": - scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + scheduler = OneCycleLR( + optimizer, + max_lr=[g["lr"] for g in optimizer.param_groups], + total_steps=total_steps, + pct_start=warmup_steps / total_steps, + div_factor=100, + final_div_factor=100, + anneal_strategy="cos", + ) + elif args.sched == "poly": - scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader)) + warmup = LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=warmup_steps, + ) + poly = PolynomialLR( + optimizer, + total_iters=total_steps - warmup_steps, + power=1.0, + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup, poly], + milestones=[warmup_steps], + ) # Training monitoring current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") @@ -591,6 +635,8 @@ def log_at_step(train_loss=None, val_loss=None, lr=None): # Training loop for epoch in range(args.epochs): + if distributed: + sampler.set_epoch(epoch) train_loss, actual_lr = fit_one_epoch( model, train_loader, @@ -690,8 +736,8 @@ def parse_args(): "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" ) parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") - parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)") - parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay") + parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for the optimizer (Adam or AdamW)") + parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, help="weight decay", dest="weight_decay") parser.add_argument("-j", "--workers", type=int, default=None, 
help="number of workers used for dataloading") parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") diff --git a/references/layout/utils.py b/references/layout/utils.py index 6629487ab7..f7e56f049f 100644 --- a/references/layout/utils.py +++ b/references/layout/utils.py @@ -86,6 +86,44 @@ def plot_samples( plt.show() +def build_param_groups(model: Any, lr: float, backbone_lr: float, weight_decay: float): + """Build parameter groups for the optimizer, separating backbone and non-backbone parameters, + and applying weight decay only to non-bias and non-norm parameters. + + Args: + model: the model containing the parameters + lr: learning rate for non-backbone parameters + backbone_lr: learning rate for backbone parameters + weight_decay: weight decay to apply to non-bias and non-norm parameters + + Returns: + a list of parameter groups to be passed to the optimizer + """ + no_decay_keys = ("bias", "norm", ".bn", "embed") # Embedding, LayerNorm, BN + + def is_backbone(name: str) -> bool: + return name.removeprefix("module.").startswith("feat_extractor.") + + groups: dict[tuple[bool, bool], list[Any]] = { + (False, True): [], + (False, False): [], + (True, True): [], + (True, False): [], + } + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + decay = not (p.ndim <= 1 or any(k in n.lower() for k in no_decay_keys)) + groups[(is_backbone(n), decay)].append(p) + + return [ + {"params": groups[(False, True)], "lr": lr, "weight_decay": weight_decay}, + {"params": groups[(False, False)], "lr": lr, "weight_decay": 0.0}, + {"params": groups[(True, True)], "lr": backbone_lr, "weight_decay": weight_decay}, + {"params": groups[(True, False)], "lr": backbone_lr, "weight_decay": 0.0}, + ] + + def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: """Display the results of the LR grid 
search. Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py diff --git a/references/recognition/evaluate.py b/references/recognition/evaluate.py index 45a6b38306..22f69fa2cd 100644 --- a/references/recognition/evaluate.py +++ b/references/recognition/evaluate.py @@ -38,7 +38,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): images = images.cuda() images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/recognition/train.py b/references/recognition/train.py index dc6b7b1b24..913f2ed017 100644 --- a/references/recognition/train.py +++ b/references/recognition/train.py @@ -68,7 +68,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): if torch.cuda.is_available(): @@ -79,7 +79,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -112,7 +112,7 @@ def record_lr( def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -125,7 +125,7 @@ def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, sche optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -167,7 +167,7 @@ def evaluate(model, device, val_loader, batch_transforms, val_metric, 
amp=False, images = images.to(device) images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) @@ -559,6 +559,8 @@ def log_at_step(train_loss=None, val_loss=None, lr=None): early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta) # Training loop for epoch in range(args.epochs): + if distributed: + sampler.set_epoch(epoch) train_loss, actual_lr = fit_one_epoch( model, device, diff --git a/tests/pytorch/test_models_factory.py b/tests/pytorch/test_models_factory.py index 8cf0a9fa45..db9aed1cdf 100644 --- a/tests/pytorch/test_models_factory.py +++ b/tests/pytorch/test_models_factory.py @@ -50,7 +50,7 @@ def test_push_to_hf_hub(): ["vitstr_small", "recognition", "Felix92/doctr-dummy-torch-vitstr-small"], ["parseq", "recognition", "Felix92/doctr-dummy-torch-parseq"], ["viptr_tiny", "recognition", "Felix92/doctr-dummy-torch-viptr-tiny"], - # ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"], + ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"], ], ) def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir): diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py index d0c2865411..5eb3f582e3 100644 --- a/tests/pytorch/test_models_layout.py +++ b/tests/pytorch/test_models_layout.py @@ -27,7 +27,7 @@ def test_layout_models(arch_name, input_shape, train_mode, use_polygons): model = model.train() if train_mode else model.eval() assert isinstance(model, torch.nn.Module) input_tensor = torch.rand((batch_size, *input_shape)) - input_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) + input_masks = torch.zeros((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) class_names = model.class_names @@ -130,7 +130,7 @@ def test_models_onnx_export(arch_name, input_shape): 
batch_size = 2 model = layout.__dict__[arch_name](pretrained=True, exportable=True).eval() dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) - dummy_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) + dummy_masks = torch.zeros((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) pt = model(dummy_input, dummy_masks) pt_logits = pt["logits"].detach().cpu().numpy() pt_boxes = pt["pred_boxes"].detach().cpu().numpy() diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py index 2a41a712b1..353dc87625 100644 --- a/tests/pytorch/test_transforms_pt.py +++ b/tests/pytorch/test_transforms_pt.py @@ -133,6 +133,21 @@ def test_resize(): assert mask.shape == (64, 64) assert mask.dtype == torch.bool + # Test with included mask in input sample + input_t = Sample( + image=torch.ones((3, 32, 64), dtype=torch.float32), + mask=torch.zeros((32, 64), dtype=torch.bool), + target=target_boxes, + ) + data = transfo(input_t) + out, mask, new_target = data.image, data.mask, data.target + + assert out.shape[-2:] == (64, 64) + assert new_target.shape == target_boxes.shape + assert np.all((0 <= new_target) & (new_target <= 1)) + assert mask.shape == (64, 64) + assert mask.dtype == torch.bool + # Test with invalid target shape input_t = torch.ones((3, 32, 64), dtype=torch.float32) target = np.ones((2, 5)) # Invalid shape @@ -312,6 +327,18 @@ def test_random_rotate(): assert r_targets["boxes"].shape == (0, 4) assert r_targets["polygons"].shape == (0, 4, 2) + # Test with mask in input sample + input_m = torch.zeros((50, 50), dtype=torch.bool) + data = rotator(Sample(image=input_t, mask=input_m, target=boxes)) + r_img, r_mask = data.image, data.mask + assert r_img.ndim == input_t.ndim + assert r_mask.ndim - 1 == input_m.ndim # Mask should be 2D + + # Test without target + data = rotator(Sample(image=input_t)) + r_img = data.image + assert r_img.ndim == input_t.ndim + # FP16 (only on GPU) if 
torch.cuda.is_available(): input_t = torch.ones((3, 50, 50), dtype=torch.float16).cuda() @@ -354,8 +381,8 @@ def test_crop_detection(): @pytest.mark.parametrize( "target", [ - np.array([[15, 20, 35, 30]]), # box - np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), # polygon + np.array([[15, 20, 35, 30]], dtype=np.float32), # box + np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]], dtype=np.float32), # polygon ], ) def test_random_crop(target): @@ -363,23 +390,28 @@ def test_random_crop(target): assert repr(cropper) == "RandomCrop(scale=(0.5, 1.0), ratio=(0.75, 1.33))" input_t = torch.ones((3, 50, 50), dtype=torch.float32) - sample = cropper(Sample(image=input_t, target=target)) - img, target = sample.image, sample.target + original_target = target.copy() + + sample = cropper(Sample(image=input_t, target=original_target.copy())) + img, cropped_target = sample.image, sample.target + # Check the scale assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] # Check aspect ratio assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 # Check the target - assert np.all(target >= 0) - if target.ndim == 2: - assert np.all(target[:, [0, 2]] <= img.shape[-1]) and np.all(target[:, [1, 3]] <= img.shape[-2]) + assert np.all(cropped_target >= 0) + if cropped_target.ndim == 2: + assert np.all(cropped_target[:, [0, 2]] <= img.shape[-1]) + assert np.all(cropped_target[:, [1, 3]] <= img.shape[-2]) else: - assert np.all(target[..., 0] <= img.shape[-1]) and np.all(target[..., 1] <= img.shape[-2]) + assert np.all(cropped_target[..., 0] <= img.shape[-1]) + assert np.all(cropped_target[..., 1] <= img.shape[-2]) # Test dict targets dict_target = { - "boxes": np.array([[15, 20, 35, 30]]), - "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), + "boxes": np.array([[15, 20, 35, 30]], dtype=np.float32), + "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]], dtype=np.float32), } sample = cropper(Sample(image=input_t, target=dict_target)) 
img, cropped_targets = sample.image, sample.target @@ -401,6 +433,22 @@ def test_random_crop(target): assert np.all(cropped_targets["polygons"][..., 0] <= img.shape[-1]) assert np.all(cropped_targets["polygons"][..., 1] <= img.shape[-2]) + # Test with mask in input sample + input_m = torch.ones((50, 50), dtype=torch.bool) + sample = cropper(Sample(image=input_t, mask=input_m, target=original_target.copy())) + img, mask, cropped_target = sample.image, sample.mask, sample.target + + assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] + assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 + assert mask.shape == img.shape[-2:] + assert mask.dtype == torch.bool + + # Test without target + sample = cropper(Sample(image=input_t)) + img = sample.image + assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] + assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 + @pytest.mark.parametrize( "input_dtype, input_size", @@ -450,6 +498,7 @@ def test_gaussian_noise(input_dtype, input_shape): assert torch.all(transformed <= 255) else: assert torch.all(transformed <= 1.0) + assert repr(transform) == "GaussianNoise(mean=0.0, std=1.0)" @pytest.mark.parametrize(