From 03c0166ef81896331eb2d193d0b6b76c7d05e96c Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 27 May 2026 14:24:47 +0200 Subject: [PATCH 01/15] straight check --- doctr/models/layout/lw_detr/base.py | 164 +++--- doctr/models/layout/lw_detr/layers/pytorch.py | 219 +++----- doctr/models/layout/lw_detr/loss.py | 527 ++++++++++++++++++ doctr/models/layout/lw_detr/pytorch.py | 475 ++++------------ 4 files changed, 781 insertions(+), 604 deletions(-) create mode 100644 doctr/models/layout/lw_detr/loss.py diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 2aa59e7d18..f7c926b13b 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -9,7 +9,6 @@ import numpy as np from doctr.models.core import BaseModel -from doctr.utils import order_points __all__ = ["_LWDETR", "LWDETRPostProcessor"] @@ -39,29 +38,36 @@ def __init__( self.topk = topk self.assume_straight_pages = assume_straight_pages - def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """Decode the predicted boxes from OBB format to polygon format - - Args: - boxes: array of predicted boxes in OBB format (N, 6) (cx, cy, w, h, sin(theta), cos(theta)) - - Returns: - tuple of (polys, angles) where polys is an array of decoded polygons (N, 4, 2) - and angles is an array of angles in radians (N,) + def _decode_boxes(self, boxes: np.ndarray) -> np.ndarray: """ - cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] - sin, cos = boxes[:, 4], boxes[:, 5] - - angles = np.arctan2(sin, cos) + Decode cxcywh -> polygons (axis-aligned rectangles) + """ + cx = boxes[:, 0] + cy = boxes[:, 1] + w = boxes[:, 2] + h = boxes[:, 3] polys = [] + for i in range(len(boxes)): - rect = ((float(cx[i]), float(cy[i])), (float(w[i]), float(h[i])), float(np.degrees(angles[i]))) + x1 = cx[i] - w[i] / 2 + y1 = cy[i] - h[i] / 2 + x2 = cx[i] + w[i] / 2 + y2 = cy[i] + h[i] / 2 + + poly = np.array( + [ + [x1, y1], + [x2, y1], + [x2, y2], + [x1, y2], + ], + dtype=np.float32, + ) - poly = order_points(cv2.boxPoints(rect)) polys.append(poly) - return np.asarray(polys, dtype=np.float32), angles + return np.asarray(polys, dtype=np.float32) def _iou(self, poly1: np.ndarray, poly2: np.ndarray) -> float: """Compute the IoU between two polygons @@ -136,22 +142,21 @@ def _nms(self, polys: np.ndarray, scores: np.ndarray, labels: np.ndarray) -> lis suppressed[j] = True return keep - def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int], np.ndarray, list[float]]]: + def __call__(self, logits: np.ndarray, boxes: np.ndarray): + logits = np.asarray(logits) boxes = np.asarray(boxes) - results: list[tuple[list[int], np.ndarray, list[float]]] = [] + results = [] for b in range(boxes.shape[0]): - # Convert logits to probabilities and get scores and labels exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) prob = exp / exp.sum(axis=-1, keepdims=True) - prob_fg = prob[:, :-1] # exclude background + prob_fg = prob[:, :-1] scores = prob_fg.max(axis=-1) labels = prob_fg.argmax(axis=-1) - # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: idxs = np.argsort(scores)[::-1][: self.topk] else: @@ -167,44 +172,36 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int scores_b = scores_b[mask] labels_b = labels_b[mask] - polys, _ = ( - self._decode_boxes(bboxes) - if len(bboxes) > 0 - else ( - np.zeros((0, 4, 2), dtype=np.float32), - np.zeros((0,), dtype=np.float32), - ) - ) + polys = self._decode_boxes(bboxes) if len(bboxes) > 0 else np.zeros((0, 4, 2), dtype=np.float32) keep = self._nms(polys, scores_b, labels_b) if len(polys) > 0 else [] - final_labels = [] final_boxes = [] + final_labels = [] final_scores = [] for idx in keep: - poly = polys[idx].reshape(-1).tolist() + poly = polys[idx] + if self.assume_straight_pages: - x_coords = poly[0::2] - y_coords = poly[1::2] - xmin, xmax = min(x_coords), max(x_coords) - ymin, ymax = min(y_coords), max(y_coords) + # 👉 COCO-style axis aligned box from polygon + xmin = float(np.min(poly[:, 0])) + xmax = float(np.max(poly[:, 0])) + ymin = float(np.min(poly[:, 1])) + ymax = float(np.max(poly[:, 1])) + final_boxes.append([xmin, ymin, xmax, ymax]) else: - final_boxes.append(poly) + final_boxes.append(poly.reshape(-1).tolist()) final_labels.append(int(labels_b[idx])) final_scores.append(float(scores_b[idx])) - final_boxes_arr = ( - np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4, 2) - if not self.assume_straight_pages - else np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4) - ) + final_boxes_arr = np.asarray(final_boxes, dtype=np.float32) results.append(( final_labels, - final_boxes_arr, + final_boxes_arr, # <- NOW ALWAYS CLEAN FORMAT final_scores, )) @@ -221,55 +218,12 @@ def build_target( target: list[dict[str, np.ndarray]], class_names: list[str], ) -> list[dict[str, Any]]: - """Build the target for LW-DETR training - - Args: - target: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding - to class names and values corresponding to lists of boxes in either polygon format (4, 2) - or bounding box format (4,) (xmin, ymin, xmax, ymax) - class_names: list of class names - - Returns: - list of dictionaries with keys "boxes" and "labels" where "boxes" is an array of shape (num_boxes, 6) - containing the box parameters in OBB format (cx, cy, w, h, sin(theta), cos(theta)) - and "labels" is an array of shape (num_boxes,) containing the class labels + """ + Build targets in COCO format: [xmin, ymin, w, h] """ targets = [] class_to_id = {name: i for i, name in enumerate(class_names)} - def _quad_to_obb(poly: np.ndarray): - poly = np.asarray(poly, dtype=np.float32) - - # Center point is simply the average of the relative vertices - cx, cy = np.mean(poly, axis=0) - - edges = np.stack([ - poly[1] - poly[0], - poly[2] - poly[1], - poly[3] - poly[2], - poly[0] - poly[3], - ]) - - lengths = np.linalg.norm(edges, axis=1) - i = np.argmax(lengths) - dx, dy = edges[i] - - theta = np.arctan2(dy, dx) - - # Width and height remain cleanly in relative coordinate space [0, 1] - w = np.mean([lengths[i], lengths[(i + 2) % 4]]) - h = np.mean([lengths[(i + 1) % 4], lengths[(i + 3) % 4]]) - - # Enforce strict unit-length normal vectors for rotation - sin_t = np.sin(theta) - cos_t = np.cos(theta) - norm = np.sqrt(sin_t**2 + cos_t**2) + 1e-8 - - return np.array( - [cx, cy, w, h, sin_t / norm, cos_t / norm], - dtype=np.float32, - ) - def to_quad(box: np.ndarray): box = np.asarray(box, dtype=np.float32) if box.shape == (4,): @@ -281,6 +235,19 @@ def to_quad(box: np.ndarray): return box.astype(np.float32) raise ValueError(f"Unsupported box shape: {box.shape}") + def quad_to_coco(poly: np.ndarray) -> np.ndarray: + xmin = float(np.min(poly[:, 0])) + xmax = float(np.max(poly[:, 0])) + ymin = float(np.min(poly[:, 1])) + ymax = float(np.max(poly[:, 1])) + + w = xmax - xmin + h = ymax - ymin + cx = xmin + w / 2.0 + cy = ymin + h / 2.0 + + return np.array([cx, cy, w, h], dtype=np.float32) + for sample in target: boxes_all = [] labels_all = [] @@ -295,20 +262,29 @@ def to_quad(box: np.ndarray): if boxes.ndim == 1: boxes = boxes[None, :] + # sanity check normalized coords + flat = boxes.ravel() + coord_vals = flat[flat > 0] + if len(coord_vals) > 0 and coord_vals.max() > 1.5: + raise ValueError("build_target expects normalized [0,1] coordinates.") + for box in boxes: poly = to_quad(box) - obb = _quad_to_obb(poly) + coco_box = quad_to_coco(poly) - # filter out degenerate boxes - if obb[2] <= 1e-5 or obb[3] <= 1e-5: + if coco_box[2] <= 1e-5 or coco_box[3] <= 1e-5: continue - boxes_all.append(obb) + boxes_all.append(coco_box) labels_all.append(cls_id) + if len(boxes_all) == 0: + boxes_all = np.zeros((0, 4), dtype=np.float32) + labels_all = np.zeros((0,), dtype=np.int64) + targets.append({ - "boxes": np.asarray(boxes_all, dtype=np.float32), - "labels": np.asarray(labels_all, dtype=np.int64), + "boxes": np.asarray(boxes_all, dtype=np.float32), # (N, 4) + "class_labels": np.asarray(labels_all, dtype=np.int64), }) return targets diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 08f9ea9fb4..9f6d75ab30 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -113,7 +113,7 @@ def forward( hidden_states_original = hidden_states if position_embeddings is not None: - hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + hidden_states = hidden_states + position_embeddings if self.training: # at training, we use group detr technique to @@ -238,6 +238,7 @@ def forward( encoder_hidden_states=None, position_embeddings: torch.Tensor | None = None, reference_points=None, + spatial_shapes=None, spatial_shapes_list=None, ) -> tuple[torch.Tensor, torch.Tensor]: # add position embeddings to the hidden states before projecting to queries and keys @@ -263,35 +264,19 @@ def forward( ) # batch_size, num_queries, n_heads, n_levels, n_points, 2 num_coordinates = reference_points.shape[-1] - - if num_coordinates == 4: + if num_coordinates == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif num_coordinates == 4: sampling_locations = ( reference_points[:, :, None, :, None, :2] + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 ) - elif num_coordinates == 6: - ref = reference_points[:, :, None, :, None, :] # (..., 6) - - center = ref[..., :2] # (cx, cy) - wh = ref[..., 2:4] # (w, h) - sin = ref[..., 4:5] # sinθ - cos = ref[..., 5:6] # cosθ - - # normalize offsets - offsets = sampling_offsets / self.n_points * wh * 0.5 - - dx = offsets[..., 0:1] - dy = offsets[..., 1:2] - - # rotate offsets - dx_rot = dx * cos - dy * sin - dy_rot = dx * sin + dy * cos - - rotated_offsets = torch.cat([dx_rot, dy_rot], dim=-1) - - sampling_locations = center + rotated_offsets else: - raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}") + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") output = self.attn( value, @@ -361,6 +346,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: torch.Tensor | None = None, reference_points: torch.Tensor | None = None, + spatial_shapes: torch.Tensor | None = None, spatial_shapes_list: list[tuple] | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, @@ -379,6 +365,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, position_embeddings=position_embeddings, reference_points=reference_points, + spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, ) cross_attention_output = F.dropout(cross_attention_output, p=self.dropout, training=self.training) @@ -393,43 +380,40 @@ def forward( # function to generate sine positional embedding for 4d coordinates # Borrowed from: https://github.com/Atten4Vis/LW-DETR/blob/main/models/transformer.py -def gen_sine_position_embeddings(pos_tensor: torch.Tensor, hidden_size: int = 256) -> torch.Tensor: - """ - This function computes position embeddings using sine and cosine functions from the input positional tensor, - which has a shape of (batch_size, num_queries, 4). - The last dimension of `pos_tensor` represents the following coordinates: - - 0: x-coord - - 1: y-coord - - 2: width - - 3: height - - The output shape is (batch_size, num_queries, 512), - where final dim (hidden_size*2 = 512) is the total embedding dimension - achieved by concatenating the sine and cosine values for each coordinate. +def encode_sinusoidal_position_embedding( + pos_tensor: torch.Tensor, + num_pos_feats: int = 128, + temperature: int = 10000, +) -> torch.Tensor: + """Sinusoidal position embeddings from normalized anchor coordinates. + + Each coordinate in `pos_tensor` is independently encoded with ``num_pos_feats`` + interleaved sin/cos components; per-coordinate embeddings are concatenated. + Handles 2-D ``(x, y)`` and N-D ``(x, y, w, h)`` inputs. For 2-D+ inputs the + x and y embeddings are swapped to follow the DETR ``[pos_y, pos_x, ...]`` convention. + + Args: + pos_tensor: Normalized coordinates in ``[0, 1]``, shape ``(..., n_coords)``. + num_pos_feats: Embedding dimension per coordinate. + temperature: Base for the frequency decay. + + Returns: + Tensor of shape ``(..., n_coords * num_pos_feats)``, same dtype as input. """ scale = 2 * math.pi - dim = hidden_size // 2 - dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) - dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) - x_embed = pos_tensor[:, :, 0] * scale - y_embed = pos_tensor[:, :, 1] * scale - pos_x = x_embed[:, :, None] / dim_t - pos_y = y_embed[:, :, None] / dim_t - pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) - pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) - if pos_tensor.size(-1) == 4: - w_embed = pos_tensor[:, :, 2] * scale - pos_w = w_embed[:, :, None] / dim_t - pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) - - h_embed = pos_tensor[:, :, 3] * scale - pos_h = h_embed[:, :, None] / dim_t - pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) - - pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) - else: - raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") - return pos.to(pos_tensor.dtype) + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) + + coords = pos_tensor.unbind(-1) # list of (...,) tensors + embeddings = [coord[..., None] * scale / dim_t for coord in coords] # each (..., num_pos_feats) + embeddings = [ + torch.stack((e[..., 0::2].sin(), e[..., 1::2].cos()), dim=-1).flatten(-2) for e in embeddings + ] # each (..., num_pos_feats) + + if len(embeddings) >= 2: + embeddings[0], embeddings[1] = embeddings[1], embeddings[0] + + return torch.cat(embeddings, dim=-1).to(pos_tensor.dtype) class LWDETRDecoder(nn.Module): @@ -458,7 +442,6 @@ def __init__( dec_n_points: int = 2, group_detr: int = 13, dropout_prob: float = 0.0, - bbox_embed: nn.Module | None = None, ): super().__init__() self.dropout_prob = dropout_prob @@ -477,89 +460,30 @@ def __init__( for i in range(num_layers) ]) self.layernorm = nn.LayerNorm(self.d_model) - self.bbox_embed = bbox_embed self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) - self.angle_proj = nn.Sequential( - nn.Linear(4, self.d_model), - nn.ReLU(), - nn.Linear(self.d_model, self.d_model), - ) - def get_reference( - self, reference_points: torch.Tensor, valid_ratios: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """This function computes the reference point inputs and positional embeddings for the decoder layers. - - Args: - reference_points: (batch_size, num_queries, 6) - tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ) - valid_ratios: (batch_size, num_levels, 2) - tensor containing the valid ratios for each level of the input feature maps - - Returns: - reference_points_inputs: (batch_size, num_queries, 1, num_levels, 4) - tensor containing the reference point inputs for the decoder layers, - which are the normalized center coordinates, - width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps - query_pos: (batch_size, num_queries, d_model) - tensor containing the positional embeddings for the decoder layers, - which are computed from the reference points using sine and cosine functions and a linear projection - """ + def get_reference(self, reference_points, valid_ratios): + # batch_size, num_queries, batch_size, 4 obj_center = reference_points[..., :4] - spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] - # Extract angles - angle = reference_points[..., 4:6] # (sin, cos) - angle_expanded = angle[:, :, None] - reference_points_inputs = torch.cat([spatial_inputs, angle_expanded], dim=-1) - # DETR positional encoding - query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) - base_query_pos = self.ref_point_head(query_sine_embed) - # Angle embedding - sin_t = angle[..., 0:1] - cos_t = angle[..., 1:2] - - angle_feat = torch.cat( - [ - sin_t, - cos_t, - 2 * sin_t * cos_t, - cos_t**2 - sin_t**2, - ], - dim=-1, - ) - - angle_emb = self.angle_proj(angle_feat) - # Combine - query_pos = base_query_pos + angle_emb - return reference_points_inputs, query_pos - - def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - reference_points = reference_points.to(deltas.device) - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - - # Clamp deltas to prevent exp() from shooting to Infinity during early training - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] - - # Add eps=1e-6 to avoid division-by-zero NaN creation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta + # batch_size, num_queries, num_levels, 4 + reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] - # Add eps=1e-6 here too - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) + # batch_size, num_queries, d_model * 2 + query_sine_embed = encode_sinusoidal_position_embedding( + reference_points_inputs[:, :, 0, :], num_pos_feats=self.d_model // 2 + ) - return torch.cat((cxcy, wh, rot), dim=-1) + # batch_size, num_queries, d_model + query_pos = self.ref_point_head(query_sine_embed) + return reference_points_inputs, query_pos def forward( self, inputs_embeds: torch.Tensor | None, reference_points: torch.Tensor, + spatial_shapes: torch.Tensor, spatial_shapes_list: torch.Tensor, valid_ratios: torch.Tensor, encoder_hidden_states: torch.Tensor, @@ -581,35 +505,18 @@ def forward( encoder_attention_mask=encoder_attention_mask, position_embeddings=query_pos, reference_points=reference_points_inputs, + spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, ) - hidden_states_norm = self.layernorm(hidden_states) - - # iterative refinement - if self.bbox_embed is not None: - delta = self.bbox_embed(hidden_states_norm) - - reference_points = self.refine_boxes( - reference_points.squeeze(2), - delta, - ) - - intermediate_reference_points.append(reference_points) - - reference_points_inputs, query_pos = self.get_reference( - reference_points, - valid_ratios, - ) - - intermediate.append(hidden_states_norm) - - intermediate_stack = torch.stack(intermediate) - last_hidden_state = intermediate_stack[-1] + intermediate_hidden_states = self.layernorm(hidden_states) + intermediate.append(intermediate_hidden_states) - intermediate_reference_points_stack = torch.stack(intermediate_reference_points) + intermediate = torch.stack(intermediate) + last_hidden_state = intermediate[-1] + intermediate_reference_points = torch.stack(intermediate_reference_points) - return last_hidden_state, intermediate_stack, intermediate_reference_points_stack + return last_hidden_state, intermediate, intermediate_reference_points class MultiScaleProjector(nn.Module): diff --git a/doctr/models/layout/lw_detr/loss.py b/doctr/models/layout/lw_detr/loss.py new file mode 100644 index 0000000000..ba4364764f --- /dev/null +++ b/doctr/models/layout/lw_detr/loss.py @@ -0,0 +1,527 @@ +import numpy as np +import torch +import torch.nn as nn +from scipy.optimize import linear_sum_assignment +from torch import Tensor + +__all__ = ["lw_detr_for_object_detection_loss"] + + +def center_to_corners_format(bboxes_center: "torch.Tensor") -> "torch.Tensor": + center_x, center_y, width, height = bboxes_center.unbind(-1) + bbox_corners = torch.stack( + # top left x, top left y, bottom right x, bottom right y + [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)], + dim=-1, + ) + return bbox_corners + + +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs (0 for the negative class and 1 for the positive + class). + num_boxes: Normalization factor, typically the number of target boxes in the batch. This is used to scale the + loss to an absolute value, and is used in the original implementation of DETR and LW-DETR. + It doesn't have to be + exactly the number of target boxes, but it should be correlated to it for the loss to be meaningful. + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://huggingface.co/papers/1708.02002. + + Args: + inputs (`torch.FloatTensor` of arbitrary shape): + The predictions for each example. + targets (`torch.FloatTensor` with the same shape as `inputs`): + A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class + and 1 for the positive class). + num_boxes (`int`): + Normalization factor, typically the number of target boxes in the batch. This is used to scale the loss + to an absolute value, and is used in the original implementation of DETR and LW-DETR. It doesn't have to be + exactly the number of target boxes, but it should be correlated to it for the loss to be meaningful. + alpha (`float`, *optional*, defaults to `0.25`): + Optional weighting factor in the range (0,1) to balance positive vs. negative examples. + gamma (`int`, *optional*, defaults to `2`): + Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. + + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + # add modulating factor + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +def box_area(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. + + Returns: + `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): + raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") + if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): + raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") + iou, union = box_iou(boxes1, boxes2) + + top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] + area = width_height[:, :, 0] * width_height[:, :, 1] + + return iou - (area - union) / area + + +# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306 +def _max_by_axis(the_list): + # type: (list[list[int]]) -> list[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor: + def __init__(self, tensors, mask: Tensor | None): + self.tensors = tensors + self.mask = mask + + def to(self, device): + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: list[Tensor]): + if tensor_list[0].ndim == 3: + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + batch_shape = [len(tensor_list)] + max_size + batch_size, num_channels, height, width = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("Only 3-dimensional tensors are supported") + return NestedTensor(tensor, mask) + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +def _set_aux_loss(outputs_class, outputs_coord): + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class LwDetrHungarianMatcher(nn.Module): + def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1): + super().__init__() + + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + if class_cost == 0 and bbox_cost == 0 and giou_cost == 0: + raise ValueError("All costs of the Matcher can't be 0") + + @torch.no_grad() + def forward(self, outputs, targets, group_detr): + """ + Differences: + - out_prob = outputs["logits"].flatten(0, 1).sigmoid() instead of softmax + - class_cost uses alpha and gamma + """ + batch_size, num_queries = outputs["logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + target_ids = torch.cat([torch.as_tensor(v["class_labels"], dtype=torch.int64) for v in targets]).to( + out_prob.device + ) + target_bbox = torch.cat([torch.as_tensor(v["boxes"], dtype=torch.float32) for v in targets]).to(out_bbox.device) + + # Compute the classification cost. + alpha = 0.25 + gamma = 2.0 + neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids] + + # Compute the L1 cost between boxes, cdist only supports float32 + dtype = out_bbox.dtype + out_bbox = out_bbox.to(torch.float32) + target_bbox = target_bbox.to(torch.float32) + bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) + bbox_cost = bbox_cost.to(dtype) + + # Compute the giou cost between boxes + giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) + + # Final cost matrix + cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost + cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [] + group_num_queries = num_queries // group_detr + cost_matrix_list = cost_matrix.split(group_num_queries, dim=1) + for group_id in range(group_detr): + group_cost_matrix = cost_matrix_list[group_id] + group_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(group_cost_matrix.split(sizes, -1))] + if group_id == 0: + indices = group_indices + else: + indices = [ + ( + np.concatenate([indice1[0], indice2[0] + group_num_queries * group_id]), + np.concatenate([indice1[1], indice2[1]]), + ) + for indice1, indice2 in zip(indices, group_indices) + ] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +class LwDetrImageLoss(nn.Module): + def __init__(self, matcher, num_classes, focal_alpha, losses, group_detr): + super().__init__() + self.matcher = matcher + self.num_classes = num_classes + self.focal_alpha = focal_alpha + self.losses = losses + self.group_detr = group_detr + + # removed logging parameter, which was part of the original implementation + def loss_labels(self, outputs, targets, indices, num_boxes): + if "logits" not in outputs: + raise KeyError("No logits were found in the outputs") + source_logits = outputs["logits"] + dtype = source_logits.dtype + + idx = self._get_source_permutation_idx(indices) + target_classes_o = torch.cat([ + torch.as_tensor(np.atleast_1d(t["class_labels"][J]), dtype=torch.int64) + for t, (_, J) in zip(targets, indices) + ]).to(source_logits.device) + alpha = self.focal_alpha + gamma = 2 + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat( + [torch.as_tensor(np.atleast_2d(t["boxes"][i]), dtype=torch.float32) for t, (_, i) in zip(targets, indices)], + dim=0, + ).to(src_boxes.device) + iou_targets = torch.diag( + box_iou(center_to_corners_format(src_boxes.detach()), center_to_corners_format(target_boxes))[0] + ) + # Convert to the same dtype as the source logits as box_iou upcasts to float32 + iou_targets = iou_targets.to(dtype) + pos_ious = iou_targets.clone().detach() + prob = source_logits.sigmoid() + # init positive weights and negative weights + pos_weights = torch.zeros_like(source_logits) + # pow promotes to float32 under float16 CUDA autocast; cast back to preserve original dtype + neg_weights = prob.pow(gamma).to(dtype) + pos_ind = idx + (target_classes_o,) + + pos_quality = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha) + pos_quality = torch.clamp(pos_quality, 0.01).detach().to(dtype) + + pos_weights[pos_ind] = pos_quality + neg_weights[pos_ind] = 1 - pos_quality + loss_ce = -pos_weights * prob.log() - neg_weights * (1 - prob).log() + loss_ce = loss_ce.sum() / num_boxes + losses = {"loss_ce": loss_ce} + + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ + Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. + + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. + """ + logits = outputs["logits"] + device = logits.device + target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (sigmoid > 0.5 threshold) + card_pred = (logits.sigmoid().max(-1).values > 0.5).sum(1) + card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) + losses = {"cardinality_error": card_err} + return losses + + # Copied from loss.loss_for_object_detection.ImageLoss.loss_boxes + def loss_boxes(self, outputs, targets, indices, num_boxes): + """ + Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. + + Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes + are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + if "pred_boxes" not in outputs: + raise KeyError("No predicted boxes found in outputs") + idx = self._get_source_permutation_idx(indices) + source_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat( + [torch.as_tensor(np.atleast_2d(t["boxes"][i]), dtype=torch.float32) for t, (_, i) in zip(targets, indices)], + dim=0, + ).to(source_boxes.device) + + loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none") + + losses = {} + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag( + generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)) + ) + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses + + # Copied from loss.loss_for_object_detection.ImageLoss.loss_masks + def loss_masks(self, outputs, targets, indices, num_boxes): + """ + Compute the losses related to the masks: the focal loss and the dice loss. + + Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]. + """ + if "pred_masks" not in outputs: + raise KeyError("No predicted masks found in outputs") + + source_idx = self._get_source_permutation_idx(indices) + target_idx = self._get_target_permutation_idx(indices) + source_masks = outputs["pred_masks"] + source_masks = source_masks[source_idx] + masks = [t["masks"] for t in targets] + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(source_masks) + target_masks = target_masks[target_idx] + + # upsample predictions to the target size + source_masks = nn.functional.interpolate( + source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False + ) + source_masks = source_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(source_masks.shape) + losses = { + "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes), + "loss_dice": dice_loss(source_masks, target_masks, num_boxes), + } + return losses + + # Copied from loss.loss_for_object_detection.ImageLoss._get_source_permutation_idx + def _get_source_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)]) + source_idx = torch.cat([source for (source, _) in indices]) + return batch_idx, source_idx + + # Copied from loss.loss_for_object_detection.ImageLoss._get_target_permutation_idx + def _get_target_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)]) + target_idx = torch.cat([target for (_, target) in indices]) + return batch_idx, target_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes): + loss_map = { + "labels": self.loss_labels, + "cardinality": self.loss_cardinality, + "boxes": self.loss_boxes, + "masks": self.loss_masks, + } + if loss not in loss_map: + raise ValueError(f"Loss {loss} not supported") + return loss_map[loss](outputs, targets, indices, num_boxes) + + def forward(self, outputs, targets): + """ + This performs the loss computation. + + Args: + outputs (`dict`, *optional*): + Dictionary of tensors, see the output specification of the model for the format. + targets (`list[dict]`, *optional*): + List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the + losses applied, see each loss' doc. + """ + group_detr = self.group_detr if self.training else 1 + outputs_without_aux_and_enc = { + k: v for k, v in outputs.items() if k != "enc_outputs" and k != "auxiliary_outputs" + } + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux_and_enc, targets, group_detr) + + # Compute the average number of target boxes across all nodes, for normalization purposes + num_boxes = sum(len(t["class_labels"]) for t in targets) + num_boxes = num_boxes * group_detr + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + world_size = 1 + num_boxes = torch.clamp(num_boxes / world_size, min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if "auxiliary_outputs" in outputs: + for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): + indices = self.matcher(auxiliary_outputs, targets, group_detr) + for loss in self.losses: + if loss == "masks": + # Intermediate masks losses are too costly to compute, we ignore them. + continue + l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes) + l_dict = {k + f"_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + if "enc_outputs" in outputs: + enc_outputs = outputs["enc_outputs"] + indices = self.matcher(enc_outputs, targets, group_detr=group_detr) + for loss in self.losses: + l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes) + l_dict = {k + "_enc": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +def lw_detr_for_object_detection_loss( + logits, + labels, + device, + pred_boxes, + outputs_class=None, + outputs_coord=None, + enc_outputs_class=None, + enc_outputs_coord=None, + use_aux_loss=False, + group_detr=1, + num_labels=None, + num_decoder_layers=None, + **kwargs, +): + """Loss computation for LW-DETR for object detection.""" + # First: create the matcher + matcher = LwDetrHungarianMatcher(class_cost=2.0, bbox_cost=5, giou_cost=2) + # Second: create the criterion + losses = ["labels", "boxes", "cardinality"] + criterion = LwDetrImageLoss( + matcher=matcher, + num_classes=num_labels, + focal_alpha=0.1, + losses=losses, + group_detr=group_detr, + ) + criterion.to(device) + # Third: compute the losses, based on outputs and labels + outputs_loss = {} + auxiliary_outputs = None + outputs_loss["logits"] = logits + outputs_loss["pred_boxes"] = pred_boxes + outputs_loss["enc_outputs"] = { + "logits": enc_outputs_class, + "pred_boxes": enc_outputs_coord, + } + if use_aux_loss: + auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord) + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + loss_dict = criterion(outputs_loss, labels) + # Fourth: compute total loss, as a weighted sum of the various losses + weight_dict = {"loss_ce": 1, "loss_bbox": 5} + weight_dict["loss_giou"] = 2 + if use_aux_loss: + aux_weight_dict = {} + for i in range(num_decoder_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()} + weight_dict.update(enc_weight_dict) + loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict) + return loss, loss_dict, auxiliary_outputs diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 848d44280e..2ca0ecee4a 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -18,6 +18,7 @@ from ...utils import load_pretrained_params from .base import _LWDETR, LWDETRPostProcessor from .layers import LWDETRDecoder, LWDETRHead, LWDETRMultiscaleDeformableAttention, MultiScaleProjector +from .loss import lw_detr_for_object_detection_loss __all__ = ["LWDETR", "lw_detr_s", "lw_detr_m"] @@ -153,8 +154,8 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.3, - iou_thresh: float = 0.5, + score_thresh: float = 0.0, + iou_thresh: float = 0.1, d_model: int = 256, num_queries: int = 130, group_detr: int = 1, @@ -171,7 +172,7 @@ def __init__( super().__init__() self.class_names: list[str] = class_names - self.num_classes = len(self.class_names) + 1 # +1 for background class + self.num_classes = len(self.class_names) + 1 # +1 for background class (NO OBJECT) self.cfg = cfg self.exportable = exportable self.assume_straight_pages = assume_straight_pages @@ -181,20 +182,13 @@ def __init__( self.group_detr = group_detr self.num_queries = num_queries self.d_model = d_model + self.dec_layers = dec_layers - self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) - # Initialize angle to (sin=0, cos=1) - with torch.no_grad(): - self.reference_point_embed.weight[:, 4] = 0.0 # sinθ - self.reference_point_embed.weight[:, 5] = 1.0 # cosθ - + self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4) self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) - self.class_embed = nn.Linear(self.d_model, self.num_classes) - self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) - self.decoder = LWDETRDecoder( - num_layers=dec_layers, + num_layers=self.dec_layers, d_model=d_model, sa_num_heads=sa_num_heads, ca_num_heads=ca_num_heads, @@ -202,19 +196,21 @@ def __init__( dec_n_points=dec_n_points, group_detr=group_detr, dropout_prob=dropout_prob, - bbox_embed=self.bbox_embed, ) self.enc_output = nn.ModuleList([nn.Linear(self.d_model, self.d_model) for _ in range(self.group_detr)]) self.enc_output_norm = nn.ModuleList([nn.LayerNorm(self.d_model) for _ in range(self.group_detr)]) self.enc_out_bbox_embed = nn.ModuleList([ - LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) for _ in range(self.group_detr) + LWDETRHead(self.d_model, self.d_model, 4, num_layers=3) for _ in range(self.group_detr) ]) self.enc_out_class_embed = nn.ModuleList([ nn.Linear(self.d_model, self.num_classes) for _ in range(self.group_detr) ]) + self.class_embed = nn.Linear(self.d_model, self.num_classes) + self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 4, num_layers=3) + self.postprocessor = LWDETRPostProcessor( num_classes=self.num_classes, score_thresh=score_thresh, @@ -226,60 +222,34 @@ def __init__( # Don't override the initialization of the backbone if n.startswith("feat_extractor."): continue - - if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): - if hasattr(m, "weight") and m.weight is not None: - nn.init.ones_(m.weight) - if hasattr(m, "bias") and m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Embedding): - nn.init.normal_(m.weight, std=0.02) - elif isinstance(m, LWDETRMultiscaleDeformableAttention): + if isinstance(m, LWDETRMultiscaleDeformableAttention): nn.init.constant_(m.sampling_offsets.weight, 0.0) - - thetas = torch.arange(m.n_heads, dtype=torch.float32) * (2.0 * math.pi / m.n_heads) + thetas = torch.arange(m.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / m.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) .view(m.n_heads, 1, 1, 2) .repeat(1, m.n_levels, m.n_points, 1) ) - for i in range(m.n_points): grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): m.sampling_offsets.bias.copy_(grid_init.view(-1)) - nn.init.constant_(m.attention_weights.weight, 0.0) nn.init.constant_(m.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(m.value_proj.weight) - nn.init.zeros_(m.value_proj.bias) - + nn.init.constant_(m.value_proj.bias, 0.0) nn.init.xavier_uniform_(m.output_proj.weight) - nn.init.zeros_(m.output_proj.bias) - - if isinstance(m, nn.Linear) and m.out_features == self.num_classes: + nn.init.constant_(m.output_proj.bias, 0.0) + if hasattr(m, "refpoint_embed") and m.refpoint_embed is not None: + nn.init.constant_(m.refpoint_embed.weight, 0) + if hasattr(m, "class_embed") and m.class_embed is not None: prior_prob = 0.01 bias_value = -math.log((1 - prior_prob) / prior_prob) - if m.bias is not None: - nn.init.constant_(m.bias, bias_value) - if isinstance(m, LWDETRHead): - last = m.layers[-1] - if isinstance(last, nn.Linear): - nn.init.zeros_(last.weight) - nn.init.zeros_(last.bias) - if last.bias.shape[0] == 6: - nn.init.constant_(last.bias[5], 1.0) + nn.init.constant_(m.class_embed.bias, bias_value) + if hasattr(m, "bbox_embed") and m.bbox_embed is not None: + nn.init.constant_(m.bbox_embed.layers[-1].weight, 0) + nn.init.constant_(m.bbox_embed.layers[-1].bias, 0) def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -290,79 +260,39 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """ load_pretrained_params(self, path_or_url, **kwargs) - def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - """Refine bounding boxes by applying the predicted deltas to the reference points. - The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. - The refined boxes are computed as follows: - - cx' = cx + delta_cx * w - cy' = cy + delta_cy * h - w' = w * exp(delta_w) - h' = h * exp(delta_h) - sinθ' = sinθ * cosΔ + cosθ * sinΔ - cosθ' = cosθ * cosΔ - sinθ * sinΔ - - Args: - reference_points: (N, S, 6) tensor containing the reference points - deltas: (N, S, 6) tensor containing the predicted deltas - - Returns: - refined_boxes: (N, S, 6) tensor containing the refined bounding boxes - """ + def refine_bboxes(self, reference_points, deltas): reference_points = reference_points.to(deltas.device) - # center - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - # size - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] - # rotation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - - # compose rotations - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - - return torch.cat((cxcy, wh, rot), dim=-1) - - def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: - """Get the valid ratio of all feature maps. - - Args: - mask: (N, H, W) binary tensor containing 1 on padded pixels - dtype: the desired data type of the output tensor + new_reference_points_cxcy = deltas[..., :2] * reference_points[..., 2:] + reference_points[..., :2] + new_reference_points_wh = deltas[..., 2:].exp() * reference_points[..., 2:] + new_reference_points = torch.cat((new_reference_points_cxcy, new_reference_points_wh), -1) + return new_reference_points - Returns: - valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch - """ + def get_valid_ratio(self, mask, dtype=torch.float32): + """Get the valid ratio of all feature maps.""" _, height, width = mask.shape - valid_height = torch.sum(~mask[:, :, 0], 1) - valid_width = torch.sum(~mask[:, 0, :], 1) + valid_height = torch.sum(mask[:, :, 0], 1) + valid_width = torch.sum(mask[:, 0, :], 1) valid_ratio_height = valid_height.to(dtype) / height valid_ratio_width = valid_width.to(dtype) / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) return valid_ratio - def gen_encoder_output_proposals( - self, enc_output: torch.Tensor, padding_mask: torch.Tensor, spatial_shapes: list[tuple[int, int]] - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes): """Generate the encoder output proposals from encoded enc_output. Args: - enc_output: Output of the encoder - padding_mask: Padding mask for `enc_output` - spatial_shapes: Spatial shapes of the feature maps + enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder. + padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`. + spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps. Returns: - A tuple of feature map and bbox prediction. - - object_query: Object query features. Later used to directly predict a bounding box. - - output_proposals: Normalized proposals in [0, 1] space. - Invalid positions (padding or out-of-bounds) are filled with 0. - - invalid_mask: Boolean mask that is True for invalid positions - (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)). + `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction. + - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to + directly predict a bounding box. (without the need of a decoder) + - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals in [0, 1] space. + Invalid positions (padding or out-of-bounds) are filled with 0. + - invalid_mask (Tensor[batch_size, sequence_length, 1]): Boolean mask that is True for invalid positions + (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)). """ batch_size = enc_output.shape[0] proposals = [] @@ -394,17 +324,11 @@ def gen_encoder_output_proposals( scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale width_height = torch.ones_like(grid) * 0.05 * (2.0**level) - # add default rotation (sin=0, cos=1) - sin = torch.zeros_like(grid[..., :1]) - cos = torch.ones_like(grid[..., :1]) - proposal = torch.cat((grid, width_height, sin, cos), -1).view(batch_size, -1, 6) + proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4) proposals.append(proposal) _cur += height * width output_proposals = torch.cat(proposals, 1) - - spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) - output_proposals_valid = spatial_valid - invalid_mask = padding_mask | ~output_proposals_valid.squeeze(-1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) @@ -455,6 +379,7 @@ def forward( mask_flatten_list.append(mask) source_flatten = torch.cat(source_flatten_list, 1) mask_flatten = torch.cat(mask_flatten_list, 1) + spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in feats_masks], 1) tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) @@ -467,51 +392,73 @@ def forward( group_detr = self.group_detr if self.training else 1 topk = self.num_queries - topk_coords_logits_list: list[torch.Tensor] = [] - - # encoder predictions for auxiliary losses - all_group_enc_logits: list[torch.Tensor] = [] - all_group_enc_coords: list[torch.Tensor] = [] + topk_coords_logits = [] + topk_coords_logits_undetach = [] + object_query_undetach = [] for group_id in range(group_detr): group_object_query = self.enc_output[group_id](object_query_embedding) group_object_query = self.enc_output_norm[group_id](group_object_query) group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) - all_group_enc_logits.append(group_enc_outputs_class) - - group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) - + group_enc_outputs_class = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) - all_group_enc_coords.append(group_enc_outputs_coord) - - group_topk_proposals = torch.topk(group_enc_outputs_class_masked.max(-1)[0], topk, dim=1)[1] - + group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1] group_topk_coords_logits_undetach = torch.gather( group_enc_outputs_coord, 1, - group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), + group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4), ) - group_topk_coords_logits = group_topk_coords_logits_undetach - topk_coords_logits_list.append(group_topk_coords_logits) + group_topk_coords_logits = group_topk_coords_logits_undetach.detach() + group_object_query_undetach = torch.gather( + group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) + ) + + topk_coords_logits.append(group_topk_coords_logits) + topk_coords_logits_undetach.append(group_topk_coords_logits_undetach) + object_query_undetach.append(group_object_query_undetach) + + topk_coords_logits = torch.cat(topk_coords_logits, 1) + topk_coords_logits_undetach = torch.cat(topk_coords_logits_undetach, 1) + object_query_undetach = torch.cat(object_query_undetach, 1) + + enc_outputs_class_logits = object_query_undetach + enc_outputs_boxes_logits = topk_coords_logits_undetach - topk_coords_logits = torch.cat(topk_coords_logits_list, 1) reference_points = self.refine_bboxes(topk_coords_logits, reference_points) - last_hidden_states, intermediate, intermediate_reference_points = self.decoder( + init_reference_points = reference_points + last_hidden_state, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, reference_points=reference_points, + spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, + encoder_attention_mask=mask_flatten, ) - logits = self.class_embed(last_hidden_states) - pred_boxes_delta = self.bbox_embed(last_hidden_states) + logits = self.class_embed(last_hidden_state) + pred_boxes_delta = self.bbox_embed(last_hidden_state) pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.num_queries, dim=1) + pred_class = [] + group_detr = self.group_detr if self.training else 1 + for group_index in range(group_detr): + group_pred_class = self.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index]) + pred_class.append(group_pred_class) + enc_outputs_class_logits = torch.cat(pred_class, dim=1) + + if target is not None: + outputs_class, outputs_coord = None, None + intermediate_hidden_states = intermediate + outputs_coord_delta = self.bbox_embed(intermediate_hidden_states) + outputs_coord = self.refine_bboxes(intermediate_reference_points, outputs_coord_delta) + outputs_class = self.class_embed(intermediate_hidden_states) + out: dict[str, Any] = {} if self.exportable: @@ -533,224 +480,44 @@ def _postprocess(logits, boxes): if target is not None: # Build target processed_targets = self.build_target(target, self.class_names) - - # Main loss from final decoder layer (group DETR) - split_logits = logits.chunk(group_detr, dim=1) - split_boxes = pred_boxes.chunk(group_detr, dim=1) - - main_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_logits, split_boxes): - main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss = main_loss / group_detr - - # Auxiliary losses from intermediate decoder layers - for i in range(intermediate.shape[0] - 1): - aux_logits = self.class_embed(intermediate[i]) - aux_boxes_delta = self.bbox_embed(intermediate[i]) - aux_boxes = self.refine_bboxes(intermediate_reference_points[i], aux_boxes_delta) - - split_aux_logits = aux_logits.chunk(group_detr, dim=1) - split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) - - aux_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): - aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss += 0.5 * (aux_loss / group_detr) - - # Auxiliary losses for encoder proposals - enc_loss: float | torch.Tensor = 0.0 - for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): - enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) - loss += 0.1 * (enc_loss / group_detr) - - out["loss"] = loss + out["loss"] = self.compute_loss( + logits, + processed_targets, + pred_boxes, + outputs_class, + outputs_coord, + enc_outputs_class_logits, + enc_outputs_boxes_logits, + ) return out def compute_loss( - self, logits: torch.Tensor, pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]] - ) -> torch.Tensor: - """ - Compute the loss for LW-DETR. The loss consists of three components: - classification loss, box regression loss, and rotation loss. - The classification loss is a cross-entropy loss between the predicted class logits and the target classes. - The box regression loss is a Smooth L1 loss between the predicted boxes and the target boxes, - computed only on the positive samples. - The rotation loss is computed as 1 - cosine similarity between the predicted rotation and the target rotation, - averaged over the positive samples. - The positive samples are determined using a SimOTA-like assignment strategy, where for each ground truth box, - we select the top-k queries with the lowest cost - (combination of classification cost, box regression cost, and rotation cost). - - Args: - logits: (B, Q, C) tensor containing the predicted class logits for each query - pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query - targets: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding - to class names and values corresponding to lists of boxes in either polygon format (4, 2) - or bounding box format (4,) (xmin, ymin, xmax, ymax) - - Returns: - loss: the computed loss value - """ - - def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format to Gaussian distribution parameters - (mean and covariance). - """ - cxcy = boxes[..., :2] - - w = boxes[..., 2].clamp(min=1e-6) - h = boxes[..., 3].clamp(min=1e-6) - - sin = boxes[..., 4] - cos = boxes[..., 5] - - R = torch.stack( - [ - torch.stack([cos, -sin], dim=-1), - torch.stack([sin, cos], dim=-1), - ], - dim=-2, - ) - - sx = (w / 2) ** 2 - sy = (h / 2) ** 2 - - S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) - - S[..., 0, 0] = sx - S[..., 1, 1] = sy - - covariance = R @ S @ R.transpose(-1, -2) - return cxcy, covariance - - def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: - """Compute the ProbIoU loss between predicted boxes and target boxes.""" - mu1, sigma1 = rotated_boxes_to_gaussian(pred_boxes) - mu2, sigma2 = rotated_boxes_to_gaussian(tgt_boxes) - - delta = (mu1 - mu2).unsqueeze(-1) - sigma = (sigma1 + sigma2) * 0.5 - - eps = 1e-6 - eye = torch.eye(2, device=sigma.device) * eps - sigma_safe = sigma + eye - sigma1_safe = sigma1 + eye - sigma2_safe = sigma2 + eye - - sigma_inv = torch.linalg.inv(sigma_safe) - - mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) - - det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps) - det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps) - det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps) - - bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) - - probiou = torch.exp(-bhattacharyya) - return 1 - probiou - - device = logits.device - B, Q, C = logits.shape - - total_cls = torch.tensor(0.0, device=device) - total_box = torch.tensor(0.0, device=device) - total_rot = torch.tensor(0.0, device=device) - - for b in range(B): - pred_logits = logits[b] - pred_boxes_b = pred_boxes[b] - - tgt_boxes = torch.as_tensor( - targets[b]["boxes"], - device=device, - dtype=pred_boxes.dtype, - ) - tgt_cls = torch.as_tensor( - targets[b]["labels"], - device=device, - dtype=torch.long, - ) - - num_gt = len(tgt_cls) - - pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) - tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) - - with torch.no_grad(): - cls_prob = pred_logits.sigmoid() - alpha = 0.25 - gamma = 2.0 - - neg_cost = (1 - alpha) * (cls_prob**gamma) * (-(1 - cls_prob + 1e-8).log()) - - pos_cost = alpha * ((1 - cls_prob) ** gamma) * (-(cls_prob + 1e-8).log()) - - cost_cls = pos_cost[:, tgt_cls] - neg_cost[:, tgt_cls] - cost_l1 = torch.cdist( - pred_boxes_b[:, :4], - tgt_boxes[:, :4], - p=1, - ) - cost_rot = 1 - (pred_rot @ tgt_rot.T).abs() - total_cost = 5.0 * cost_cls + 2.0 * cost_l1 + 1.0 * cost_rot - matching_matrix = torch.zeros( - (Q, num_gt), - dtype=torch.bool, - device=device, - ) - - center_dist = torch.cdist( - pred_boxes_b[:, :2], - tgt_boxes[:, :2], - p=2, - ) - - iou_like = torch.exp(-center_dist) - dynamic_k = iou_like.sum(0).int().clamp(min=1, max=10) - - for gt_idx in range(num_gt): - _, candidate_idx = torch.topk(-total_cost[:, gt_idx], k=int(dynamic_k[gt_idx].item())) - matching_matrix[candidate_idx, gt_idx] = True - - # resolve duplicate matches - multiple_match_mask = matching_matrix.sum(1) > 1 - - if multiple_match_mask.any(): - duplicate_idx = multiple_match_mask.nonzero(as_tuple=False).squeeze(1) - min_cost_idx = total_cost[duplicate_idx].argmin(dim=1) - # Set all matches to False for the duplicate indices, - # then set the match with the lowest cost to True - matching_matrix[duplicate_idx] = False - matching_matrix[duplicate_idx, min_cost_idx] = True - - pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True) - - target_classes = torch.zeros((Q,), dtype=torch.long, device=device) - - # background = 0 - target_classes[pos_idx] = tgt_cls[gt_indices] - - total_cls += F.cross_entropy(pred_logits, target_classes) - - if len(pos_idx) == 0: - continue - - pred_sel = pred_boxes_b[pos_idx] - tgt_sel = tgt_boxes[gt_indices] - # L1 loss on (cx, cy, w, h) - l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4]) - # ProbIoU loss on the whole box (including rotation) - probiou_loss = _probiou_loss(pred_sel, tgt_sel).mean() - total_box += 2.0 * l1_loss + 0.5 * probiou_loss - # Rotation loss - cos_sim = (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1).abs() - rot_loss = (1 - cos_sim).mean() - total_rot += 0.5 * rot_loss - # Average the loss over the batch - return (total_cls + total_box + total_rot) / B + self, + logits, + targets, + pred_boxes, + outputs_class, + outputs_coord, + enc_outputs_class_logits, + enc_outputs_boxes_logits, + ): + + loss_calc = lw_detr_for_object_detection_loss( + logits=logits, + device=logits.device, + labels=targets, + pred_boxes=pred_boxes, + outputs_class=outputs_class, + outputs_coord=outputs_coord, + enc_outputs_class=enc_outputs_class_logits, + enc_outputs_coord=enc_outputs_boxes_logits, + use_aux_loss=True, + group_detr=self.group_detr, + num_decoder_layers=self.dec_layers, + num_labels=self.num_classes, + ) + return loss_calc[0] def _lw_detr( From 4c5629c73d34843b096d885479b7656507a01405 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 28 May 2026 11:24:14 +0200 Subject: [PATCH 02/15] rot-check --- doctr/models/layout/lw_detr/base.py | 179 +++--- doctr/models/layout/lw_detr/layers/pytorch.py | 235 +++++--- doctr/models/layout/lw_detr/loss.py | 527 ------------------ doctr/models/layout/lw_detr/pytorch.py | 466 ++++++++++++---- 4 files changed, 626 insertions(+), 781 deletions(-) delete mode 100644 doctr/models/layout/lw_detr/loss.py diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index f7c926b13b..1a7cb452cf 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -9,6 +9,7 @@ import numpy as np from doctr.models.core import BaseModel +from doctr.utils import order_points __all__ = ["_LWDETR", "LWDETRPostProcessor"] @@ -38,36 +39,29 @@ def __init__( self.topk = topk self.assume_straight_pages = assume_straight_pages - def _decode_boxes(self, boxes: np.ndarray) -> np.ndarray: - """ - Decode cxcywh -> polygons (axis-aligned rectangles) + def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Decode the predicted boxes from OBB format to polygon format + + Args: + boxes: array of predicted boxes in OBB format (N, 6) (cx, cy, w, h, sin(theta), cos(theta)) + + Returns: + tuple of (polys, angles) where polys is an array of decoded polygons (N, 4, 2) + and angles is an array of angles in radians (N,) """ - cx = boxes[:, 0] - cy = boxes[:, 1] - w = boxes[:, 2] - h = boxes[:, 3] + cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] + sin, cos = boxes[:, 4], boxes[:, 5] - polys = [] + angles = np.arctan2(sin, cos) + polys = [] for i in range(len(boxes)): - x1 = cx[i] - w[i] / 2 - y1 = cy[i] - h[i] / 2 - x2 = cx[i] + w[i] / 2 - y2 = cy[i] + h[i] / 2 - - poly = np.array( - [ - [x1, y1], - [x2, y1], - [x2, y2], - [x1, y2], - ], - dtype=np.float32, - ) + rect = ((float(cx[i]), float(cy[i])), (float(w[i]), float(h[i])), float(np.degrees(angles[i]))) + poly = order_points(cv2.boxPoints(rect)) polys.append(poly) - return np.asarray(polys, dtype=np.float32) + return np.asarray(polys, dtype=np.float32), angles def _iou(self, poly1: np.ndarray, poly2: np.ndarray) -> float: """Compute the IoU between two polygons @@ -142,23 +136,30 @@ def _nms(self, polys: np.ndarray, scores: np.ndarray, labels: np.ndarray) -> lis suppressed[j] = True return keep - def __call__(self, logits: np.ndarray, boxes: np.ndarray): - + def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int], np.ndarray, list[float]]]: logits = np.asarray(logits) boxes = np.asarray(boxes) - results = [] + results: list[tuple[list[int], np.ndarray, list[float]]] = [] for b in range(boxes.shape[0]): + # Convert logits to probabilities and get scores and labels exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) prob = exp / exp.sum(axis=-1, keepdims=True) - prob_fg = prob[:, :-1] - scores = prob_fg.max(axis=-1) - labels = prob_fg.argmax(axis=-1) + scores = prob.max(axis=-1) + labels = prob.argmax(axis=-1) + # treat background as invalid prediction + bg = self.num_classes - 1 + valid = labels != bg + + scores = scores * valid + + # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: - idxs = np.argsort(scores)[::-1][: self.topk] + idxs = np.argpartition(-scores, self.topk)[: self.topk] + idxs = idxs[np.argsort(-scores[idxs])] else: idxs = np.arange(len(scores)) @@ -172,36 +173,44 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray): scores_b = scores_b[mask] labels_b = labels_b[mask] - polys = self._decode_boxes(bboxes) if len(bboxes) > 0 else np.zeros((0, 4, 2), dtype=np.float32) + polys, _ = ( + self._decode_boxes(bboxes) + if len(bboxes) > 0 + else ( + np.zeros((0, 4, 2), dtype=np.float32), + np.zeros((0,), dtype=np.float32), + ) + ) keep = self._nms(polys, scores_b, labels_b) if len(polys) > 0 else [] - final_boxes = [] final_labels = [] + final_boxes = [] final_scores = [] for idx in keep: - poly = polys[idx] - + poly = polys[idx].reshape(-1).tolist() if self.assume_straight_pages: - # 👉 COCO-style axis aligned box from polygon - xmin = float(np.min(poly[:, 0])) - xmax = float(np.max(poly[:, 0])) - ymin = float(np.min(poly[:, 1])) - ymax = float(np.max(poly[:, 1])) - + x_coords = poly[0::2] + y_coords = poly[1::2] + xmin, xmax = min(x_coords), max(x_coords) + ymin, ymax = min(y_coords), max(y_coords) final_boxes.append([xmin, ymin, xmax, ymax]) else: - final_boxes.append(poly.reshape(-1).tolist()) + final_boxes.append(poly) final_labels.append(int(labels_b[idx])) final_scores.append(float(scores_b[idx])) - final_boxes_arr = np.asarray(final_boxes, dtype=np.float32) + final_boxes_arr = ( + np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4, 2) + if not self.assume_straight_pages + else np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4) + ) results.append(( final_labels, - final_boxes_arr, # <- NOW ALWAYS CLEAN FORMAT + final_boxes_arr, final_scores, )) @@ -218,12 +227,55 @@ def build_target( target: list[dict[str, np.ndarray]], class_names: list[str], ) -> list[dict[str, Any]]: - """ - Build targets in COCO format: [xmin, ymin, w, h] + """Build the target for LW-DETR training + + Args: + target: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding + to class names and values corresponding to lists of boxes in either polygon format (4, 2) + or bounding box format (4,) (xmin, ymin, xmax, ymax) + class_names: list of class names + + Returns: + list of dictionaries with keys "boxes" and "labels" where "boxes" is an array of shape (num_boxes, 6) + containing the box parameters in OBB format (cx, cy, w, h, sin(theta), cos(theta)) + and "labels" is an array of shape (num_boxes,) containing the class labels """ targets = [] class_to_id = {name: i for i, name in enumerate(class_names)} + def _quad_to_obb(poly: np.ndarray): + poly = np.asarray(poly, dtype=np.float32) + + # Center point is simply the average of the relative vertices + cx, cy = np.mean(poly, axis=0) + + edges = np.stack([ + poly[1] - poly[0], + poly[2] - poly[1], + poly[3] - poly[2], + poly[0] - poly[3], + ]) + + lengths = np.linalg.norm(edges, axis=1) + i = np.argmax(lengths) + dx, dy = edges[i] + + theta = np.arctan2(dy, dx) + + # Width and height remain cleanly in relative coordinate space [0, 1] + w = np.mean([lengths[i], lengths[(i + 2) % 4]]) + h = np.mean([lengths[(i + 1) % 4], lengths[(i + 3) % 4]]) + + # Enforce strict unit-length normal vectors for rotation + sin_t = np.sin(theta) + cos_t = np.cos(theta) + norm = np.sqrt(sin_t**2 + cos_t**2) + 1e-8 + + return np.array( + [cx, cy, w, h, sin_t / norm, cos_t / norm], + dtype=np.float32, + ) + def to_quad(box: np.ndarray): box = np.asarray(box, dtype=np.float32) if box.shape == (4,): @@ -235,19 +287,6 @@ def to_quad(box: np.ndarray): return box.astype(np.float32) raise ValueError(f"Unsupported box shape: {box.shape}") - def quad_to_coco(poly: np.ndarray) -> np.ndarray: - xmin = float(np.min(poly[:, 0])) - xmax = float(np.max(poly[:, 0])) - ymin = float(np.min(poly[:, 1])) - ymax = float(np.max(poly[:, 1])) - - w = xmax - xmin - h = ymax - ymin - cx = xmin + w / 2.0 - cy = ymin + h / 2.0 - - return np.array([cx, cy, w, h], dtype=np.float32) - for sample in target: boxes_all = [] labels_all = [] @@ -262,29 +301,31 @@ def quad_to_coco(poly: np.ndarray) -> np.ndarray: if boxes.ndim == 1: boxes = boxes[None, :] - # sanity check normalized coords + # Sanity check: coordinates must be in [0, 1] normalized space. + # Values > 1.5 almost certainly indicate pixel coordinates were passed in. flat = boxes.ravel() coord_vals = flat[flat > 0] if len(coord_vals) > 0 and coord_vals.max() > 1.5: - raise ValueError("build_target expects normalized [0,1] coordinates.") + raise ValueError( + f"build_target expects normalized [0, 1] box coordinates, " + f"but found values up to {coord_vals.max():.1f} for class '{class_name}'. " + f"Divide your coordinates by image width/height before calling build_target." + ) for box in boxes: poly = to_quad(box) - coco_box = quad_to_coco(poly) + obb = _quad_to_obb(poly) - if coco_box[2] <= 1e-5 or coco_box[3] <= 1e-5: + # filter out degenerate boxes + if obb[2] <= 1e-5 or obb[3] <= 1e-5: continue - boxes_all.append(coco_box) + boxes_all.append(obb) labels_all.append(cls_id) - if len(boxes_all) == 0: - boxes_all = np.zeros((0, 4), dtype=np.float32) - labels_all = np.zeros((0,), dtype=np.int64) - targets.append({ - "boxes": np.asarray(boxes_all, dtype=np.float32), # (N, 4) - "class_labels": np.asarray(labels_all, dtype=np.int64), + "boxes": np.asarray(boxes_all, dtype=np.float32), + "labels": np.asarray(labels_all, dtype=np.int64), }) return targets diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 9f6d75ab30..bb6b503e97 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -113,7 +113,7 @@ def forward( hidden_states_original = hidden_states if position_embeddings is not None: - hidden_states = hidden_states + position_embeddings + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings if self.training: # at training, we use group detr technique to @@ -238,7 +238,6 @@ def forward( encoder_hidden_states=None, position_embeddings: torch.Tensor | None = None, reference_points=None, - spatial_shapes=None, spatial_shapes_list=None, ) -> tuple[torch.Tensor, torch.Tensor]: # add position embeddings to the hidden states before projecting to queries and keys @@ -251,7 +250,7 @@ def forward( value = self.value_proj(encoder_hidden_states) if attention_mask is not None: # we invert the attention_mask - value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.masked_fill(attention_mask[..., None], float(0)) value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(hidden_states).view( batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 @@ -264,19 +263,35 @@ def forward( ) # batch_size, num_queries, n_heads, n_levels, n_points, 2 num_coordinates = reference_points.shape[-1] - if num_coordinates == 2: - offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) - sampling_locations = ( - reference_points[:, :, None, :, None, :] - + sampling_offsets / offset_normalizer[None, None, None, :, None, :] - ) - elif num_coordinates == 4: + + if num_coordinates == 4: sampling_locations = ( reference_points[:, :, None, :, None, :2] + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 ) + elif num_coordinates == 6: + ref = reference_points[:, :, None, :, None, :] # (..., 6) + + center = ref[..., :2] # (cx, cy) + wh = ref[..., 2:4] # (w, h) + sin = ref[..., 4:5] # sinθ + cos = ref[..., 5:6] # cosθ + + # normalize offsets + offsets = sampling_offsets / self.n_points * wh * 0.5 + + dx = offsets[..., 0:1] + dy = offsets[..., 1:2] + + # rotate offsets + dx_rot = dx * cos - dy * sin + dy_rot = dx * sin + dy * cos + + rotated_offsets = torch.cat([dx_rot, dy_rot], dim=-1) + + sampling_locations = center + rotated_offsets else: - raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}") output = self.attn( value, @@ -346,7 +361,6 @@ def forward( hidden_states: torch.Tensor, position_embeddings: torch.Tensor | None = None, reference_points: torch.Tensor | None = None, - spatial_shapes: torch.Tensor | None = None, spatial_shapes_list: list[tuple] | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, @@ -365,7 +379,6 @@ def forward( encoder_hidden_states=encoder_hidden_states, position_embeddings=position_embeddings, reference_points=reference_points, - spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, ) cross_attention_output = F.dropout(cross_attention_output, p=self.dropout, training=self.training) @@ -380,40 +393,45 @@ def forward( # function to generate sine positional embedding for 4d coordinates # Borrowed from: https://github.com/Atten4Vis/LW-DETR/blob/main/models/transformer.py -def encode_sinusoidal_position_embedding( - pos_tensor: torch.Tensor, - num_pos_feats: int = 128, - temperature: int = 10000, -) -> torch.Tensor: - """Sinusoidal position embeddings from normalized anchor coordinates. - - Each coordinate in `pos_tensor` is independently encoded with ``num_pos_feats`` - interleaved sin/cos components; per-coordinate embeddings are concatenated. - Handles 2-D ``(x, y)`` and N-D ``(x, y, w, h)`` inputs. For 2-D+ inputs the - x and y embeddings are swapped to follow the DETR ``[pos_y, pos_x, ...]`` convention. - - Args: - pos_tensor: Normalized coordinates in ``[0, 1]``, shape ``(..., n_coords)``. - num_pos_feats: Embedding dimension per coordinate. - temperature: Base for the frequency decay. - - Returns: - Tensor of shape ``(..., n_coords * num_pos_feats)``, same dtype as input. +def gen_sine_position_embeddings(pos_tensor: torch.Tensor, hidden_size: int = 256) -> torch.Tensor: + """ + This function computes position embeddings using sine and cosine functions from the input positional tensor, + which has a shape of (batch_size, num_queries, 4). + The last dimension of `pos_tensor` represents the following coordinates: + - 0: x-coord + - 1: y-coord + - 2: width + - 3: height + + The output shape is (batch_size, num_queries, 512), + where final dim (hidden_size*2 = 512) is the total embedding dimension + achieved by concatenating the sine and cosine values for each coordinate. """ scale = 2 * math.pi - dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device) - dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) - - coords = pos_tensor.unbind(-1) # list of (...,) tensors - embeddings = [coord[..., None] * scale / dim_t for coord in coords] # each (..., num_pos_feats) - embeddings = [ - torch.stack((e[..., 0::2].sin(), e[..., 1::2].cos()), dim=-1).flatten(-2) for e in embeddings - ] # each (..., num_pos_feats) - - if len(embeddings) >= 2: - embeddings[0], embeddings[1] = embeddings[1], embeddings[0] - - return torch.cat(embeddings, dim=-1).to(pos_tensor.dtype) + dim = hidden_size // 2 + # Keep dim_t in float32 for numerical precision; cast output to match caller dtype + dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) + x_embed = pos_tensor[:, :, 0].float() * scale + y_embed = pos_tensor[:, :, 1].float() * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2].float() * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3].float() * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") + # Cast back to the caller's dtype (supports bfloat16 / float16 AMP) + return pos.to(pos_tensor.dtype) class LWDETRDecoder(nn.Module): @@ -442,6 +460,7 @@ def __init__( dec_n_points: int = 2, group_detr: int = 13, dropout_prob: float = 0.0, + bbox_embed: nn.Module | None = None, ): super().__init__() self.dropout_prob = dropout_prob @@ -460,30 +479,102 @@ def __init__( for i in range(num_layers) ]) self.layernorm = nn.LayerNorm(self.d_model) + self.bbox_embed = bbox_embed self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) + self.angle_proj = nn.Sequential( + nn.Linear(4, self.d_model), + nn.ReLU(), + nn.Linear(self.d_model, self.d_model), + ) - def get_reference(self, reference_points, valid_ratios): - # batch_size, num_queries, batch_size, 4 + def get_reference( + self, reference_points: torch.Tensor, valid_ratios: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """This function computes the reference point inputs and positional embeddings for the decoder layers. + + Args: + reference_points: (batch_size, num_queries, 6) + tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ) + valid_ratios: (batch_size, num_levels, 2) + tensor containing the valid ratios for each level of the input feature maps + + Returns: + reference_points_inputs: (batch_size, num_queries, 1, num_levels, 6) + tensor containing the reference point inputs for the decoder layers, + which are the normalized center coordinates, + width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps + query_pos: (batch_size, num_queries, d_model) + tensor containing the positional embeddings for the decoder layers, + which are computed from the reference points using sine and cosine functions and a linear projection + """ obj_center = reference_points[..., :4] - - # batch_size, num_queries, num_levels, 4 - reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] - - # batch_size, num_queries, d_model * 2 - query_sine_embed = encode_sinusoidal_position_embedding( - reference_points_inputs[:, :, 0, :], num_pos_feats=self.d_model // 2 + spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + # Extract angles + angle = reference_points[..., 4:6] # (sin, cos) + angle_expanded = angle[:, :, None] + reference_points_inputs = torch.cat([spatial_inputs, angle_expanded], dim=-1) + # DETR positional encoding + query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) + base_query_pos = self.ref_point_head(query_sine_embed) + # Angle embedding + sin_t = angle[..., 0:1] + cos_t = angle[..., 1:2] + + angle_feat = torch.cat( + [ + sin_t, + cos_t, + 2 * sin_t * cos_t, + cos_t**2 - sin_t**2, + ], + dim=-1, ) - # batch_size, num_queries, d_model - query_pos = self.ref_point_head(query_sine_embed) + angle_emb = self.angle_proj(angle_feat) + # Combine + query_pos = base_query_pos + angle_emb return reference_points_inputs, query_pos + def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + """Refine bounding boxes by applying the predicted deltas to the reference points. + The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. + The refined boxes are computed as follows: + + cx' = cx + delta_cx * w + cy' = cy + delta_cy * h + w' = w * exp(delta_w) + h' = h * exp(delta_h) + sinθ' = sinθ * cosΔ + cosθ * sinΔ + cosθ' = cosθ * cosΔ - sinθ * sinΔ + + Args: + reference_points: (N, S, 6) tensor containing the reference points + deltas: (N, S, 6) tensor containing the predicted deltas + + Returns: + refined_boxes: (N, S, 6) tensor containing the refined bounding boxes + """ + reference_points = reference_points.to(deltas.device) + cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + # size + wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] + # rotation + delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) + sin_delta = delta_rot[..., 0:1] + cos_delta = delta_rot[..., 1:2] + sin_ref = reference_points[..., 4:5] + cos_ref = reference_points[..., 5:6] + # compose rotations + sin_new = sin_ref * cos_delta + cos_ref * sin_delta + cos_new = cos_ref * cos_delta - sin_ref * sin_delta + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) + return torch.cat((cxcy, wh, rot), dim=-1) + def forward( self, inputs_embeds: torch.Tensor | None, reference_points: torch.Tensor, - spatial_shapes: torch.Tensor, spatial_shapes_list: torch.Tensor, valid_ratios: torch.Tensor, encoder_hidden_states: torch.Tensor, @@ -505,18 +596,34 @@ def forward( encoder_attention_mask=encoder_attention_mask, position_embeddings=query_pos, reference_points=reference_points_inputs, - spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, ) - intermediate_hidden_states = self.layernorm(hidden_states) - intermediate.append(intermediate_hidden_states) + hidden_states_norm = self.layernorm(hidden_states) + + # iterative refinement + if self.bbox_embed is not None: + delta = self.bbox_embed(hidden_states_norm) + + reference_points = self.refine_boxes( + reference_points.squeeze(2), + delta, + ) + intermediate_reference_points.append(reference_points) + + reference_points_inputs, query_pos = self.get_reference( + reference_points, + valid_ratios, + ) + + intermediate.append(hidden_states_norm) + + intermediate_stack = torch.stack(intermediate) + last_hidden_state = intermediate_stack[-1] - intermediate = torch.stack(intermediate) - last_hidden_state = intermediate[-1] - intermediate_reference_points = torch.stack(intermediate_reference_points) + intermediate_reference_points_stack = torch.stack(intermediate_reference_points) - return last_hidden_state, intermediate, intermediate_reference_points + return last_hidden_state, intermediate_stack, intermediate_reference_points_stack class MultiScaleProjector(nn.Module): diff --git a/doctr/models/layout/lw_detr/loss.py b/doctr/models/layout/lw_detr/loss.py deleted file mode 100644 index ba4364764f..0000000000 --- a/doctr/models/layout/lw_detr/loss.py +++ /dev/null @@ -1,527 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from scipy.optimize import linear_sum_assignment -from torch import Tensor - -__all__ = ["lw_detr_for_object_detection_loss"] - - -def center_to_corners_format(bboxes_center: "torch.Tensor") -> "torch.Tensor": - center_x, center_y, width, height = bboxes_center.unbind(-1) - bbox_corners = torch.stack( - # top left x, top left y, bottom right x, bottom right y - [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)], - dim=-1, - ) - return bbox_corners - - -def dice_loss(inputs, targets, num_boxes): - """ - Compute the DICE loss, similar to generalized IOU for masks - - Args: - inputs: A float tensor of arbitrary shape. - The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs (0 for the negative class and 1 for the positive - class). - num_boxes: Normalization factor, typically the number of target boxes in the batch. This is used to scale the - loss to an absolute value, and is used in the original implementation of DETR and LW-DETR. - It doesn't have to be - exactly the number of target boxes, but it should be correlated to it for the loss to be meaningful. - """ - inputs = inputs.sigmoid() - inputs = inputs.flatten(1) - numerator = 2 * (inputs * targets).sum(1) - denominator = inputs.sum(-1) + targets.sum(-1) - loss = 1 - (numerator + 1) / (denominator + 1) - return loss.sum() / num_boxes - - -def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): - """ - Loss used in RetinaNet for dense detection: https://huggingface.co/papers/1708.02002. - - Args: - inputs (`torch.FloatTensor` of arbitrary shape): - The predictions for each example. - targets (`torch.FloatTensor` with the same shape as `inputs`): - A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class - and 1 for the positive class). - num_boxes (`int`): - Normalization factor, typically the number of target boxes in the batch. This is used to scale the loss - to an absolute value, and is used in the original implementation of DETR and LW-DETR. It doesn't have to be - exactly the number of target boxes, but it should be correlated to it for the loss to be meaningful. - alpha (`float`, *optional*, defaults to `0.25`): - Optional weighting factor in the range (0,1) to balance positive vs. negative examples. - gamma (`int`, *optional*, defaults to `2`): - Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. - - Returns: - Loss tensor - """ - prob = inputs.sigmoid() - ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none") - # add modulating factor - p_t = prob * targets + (1 - prob) * (1 - targets) - loss = ce_loss * ((1 - p_t) ** gamma) - - if alpha >= 0: - alpha_t = alpha * targets + (1 - alpha) * (1 - targets) - loss = alpha_t * loss - - return loss.mean(1).sum() / num_boxes - - -def _upcast(t: Tensor) -> Tensor: - # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if t.is_floating_point(): - return t if t.dtype in (torch.float32, torch.float64) else t.float() - else: - return t if t.dtype in (torch.int32, torch.int64) else t.int() - - -def box_area(boxes: Tensor) -> Tensor: - """ - Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. - - Args: - boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): - Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 - < x2` and `0 <= y1 < y2`. - - Returns: - `torch.FloatTensor`: a tensor containing the area for each box. - """ - boxes = _upcast(boxes) - return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) - - -# modified from torchvision to also return the union -def box_iou(boxes1, boxes2): - area1 = box_area(boxes1) - area2 = box_area(boxes2) - - left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] - - width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] - inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] - - union = area1[:, None] + area2 - inter - - iou = inter / union - return iou, union - - -def generalized_box_iou(boxes1, boxes2): - """ - Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. - - Returns: - `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) - """ - # degenerate boxes gives inf / nan results - # so do an early check - if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): - raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") - if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): - raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") - iou, union = box_iou(boxes1, boxes2) - - top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - - width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] - area = width_height[:, :, 0] * width_height[:, :, 1] - - return iou - (area - union) / area - - -# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306 -def _max_by_axis(the_list): - # type: (list[list[int]]) -> list[int] - maxes = the_list[0] - for sublist in the_list[1:]: - for index, item in enumerate(sublist): - maxes[index] = max(maxes[index], item) - return maxes - - -class NestedTensor: - def __init__(self, tensors, mask: Tensor | None): - self.tensors = tensors - self.mask = mask - - def to(self, device): - cast_tensor = self.tensors.to(device) - mask = self.mask - if mask is not None: - cast_mask = mask.to(device) - else: - cast_mask = None - return NestedTensor(cast_tensor, cast_mask) - - def decompose(self): - return self.tensors, self.mask - - def __repr__(self): - return str(self.tensors) - - -def nested_tensor_from_tensor_list(tensor_list: list[Tensor]): - if tensor_list[0].ndim == 3: - max_size = _max_by_axis([list(img.shape) for img in tensor_list]) - batch_shape = [len(tensor_list)] + max_size - batch_size, num_channels, height, width = batch_shape - dtype = tensor_list[0].dtype - device = tensor_list[0].device - tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask): - pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - m[: img.shape[1], : img.shape[2]] = False - else: - raise ValueError("Only 3-dimensional tensors are supported") - return NestedTensor(tensor, mask) - - -# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py -def _set_aux_loss(outputs_class, outputs_coord): - return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] - - -class LwDetrHungarianMatcher(nn.Module): - def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1): - super().__init__() - - self.class_cost = class_cost - self.bbox_cost = bbox_cost - self.giou_cost = giou_cost - if class_cost == 0 and bbox_cost == 0 and giou_cost == 0: - raise ValueError("All costs of the Matcher can't be 0") - - @torch.no_grad() - def forward(self, outputs, targets, group_detr): - """ - Differences: - - out_prob = outputs["logits"].flatten(0, 1).sigmoid() instead of softmax - - class_cost uses alpha and gamma - """ - batch_size, num_queries = outputs["logits"].shape[:2] - - # We flatten to compute the cost matrices in a batch - out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] - out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] - - # Also concat the target labels and boxes - target_ids = torch.cat([torch.as_tensor(v["class_labels"], dtype=torch.int64) for v in targets]).to( - out_prob.device - ) - target_bbox = torch.cat([torch.as_tensor(v["boxes"], dtype=torch.float32) for v in targets]).to(out_bbox.device) - - # Compute the classification cost. - alpha = 0.25 - gamma = 2.0 - neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) - pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) - class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids] - - # Compute the L1 cost between boxes, cdist only supports float32 - dtype = out_bbox.dtype - out_bbox = out_bbox.to(torch.float32) - target_bbox = target_bbox.to(torch.float32) - bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) - bbox_cost = bbox_cost.to(dtype) - - # Compute the giou cost between boxes - giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) - - # Final cost matrix - cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost - cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() - - sizes = [len(v["boxes"]) for v in targets] - indices = [] - group_num_queries = num_queries // group_detr - cost_matrix_list = cost_matrix.split(group_num_queries, dim=1) - for group_id in range(group_detr): - group_cost_matrix = cost_matrix_list[group_id] - group_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(group_cost_matrix.split(sizes, -1))] - if group_id == 0: - indices = group_indices - else: - indices = [ - ( - np.concatenate([indice1[0], indice2[0] + group_num_queries * group_id]), - np.concatenate([indice1[1], indice2[1]]), - ) - for indice1, indice2 in zip(indices, group_indices) - ] - return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] - - -class LwDetrImageLoss(nn.Module): - def __init__(self, matcher, num_classes, focal_alpha, losses, group_detr): - super().__init__() - self.matcher = matcher - self.num_classes = num_classes - self.focal_alpha = focal_alpha - self.losses = losses - self.group_detr = group_detr - - # removed logging parameter, which was part of the original implementation - def loss_labels(self, outputs, targets, indices, num_boxes): - if "logits" not in outputs: - raise KeyError("No logits were found in the outputs") - source_logits = outputs["logits"] - dtype = source_logits.dtype - - idx = self._get_source_permutation_idx(indices) - target_classes_o = torch.cat([ - torch.as_tensor(np.atleast_1d(t["class_labels"][J]), dtype=torch.int64) - for t, (_, J) in zip(targets, indices) - ]).to(source_logits.device) - alpha = self.focal_alpha - gamma = 2 - src_boxes = outputs["pred_boxes"][idx] - target_boxes = torch.cat( - [torch.as_tensor(np.atleast_2d(t["boxes"][i]), dtype=torch.float32) for t, (_, i) in zip(targets, indices)], - dim=0, - ).to(src_boxes.device) - iou_targets = torch.diag( - box_iou(center_to_corners_format(src_boxes.detach()), center_to_corners_format(target_boxes))[0] - ) - # Convert to the same dtype as the source logits as box_iou upcasts to float32 - iou_targets = iou_targets.to(dtype) - pos_ious = iou_targets.clone().detach() - prob = source_logits.sigmoid() - # init positive weights and negative weights - pos_weights = torch.zeros_like(source_logits) - # pow promotes to float32 under float16 CUDA autocast; cast back to preserve original dtype - neg_weights = prob.pow(gamma).to(dtype) - pos_ind = idx + (target_classes_o,) - - pos_quality = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha) - pos_quality = torch.clamp(pos_quality, 0.01).detach().to(dtype) - - pos_weights[pos_ind] = pos_quality - neg_weights[pos_ind] = 1 - pos_quality - loss_ce = -pos_weights * prob.log() - neg_weights * (1 - prob).log() - loss_ce = loss_ce.sum() / num_boxes - losses = {"loss_ce": loss_ce} - - return losses - - @torch.no_grad() - def loss_cardinality(self, outputs, targets, indices, num_boxes): - """ - Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. - - This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. - """ - logits = outputs["logits"] - device = logits.device - target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) - # Count the number of predictions that are NOT "no-object" (sigmoid > 0.5 threshold) - card_pred = (logits.sigmoid().max(-1).values > 0.5).sum(1) - card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) - losses = {"cardinality_error": card_err} - return losses - - # Copied from loss.loss_for_object_detection.ImageLoss.loss_boxes - def loss_boxes(self, outputs, targets, indices, num_boxes): - """ - Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. - - Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes - are expected in format (center_x, center_y, w, h), normalized by the image size. - """ - if "pred_boxes" not in outputs: - raise KeyError("No predicted boxes found in outputs") - idx = self._get_source_permutation_idx(indices) - source_boxes = outputs["pred_boxes"][idx] - target_boxes = torch.cat( - [torch.as_tensor(np.atleast_2d(t["boxes"][i]), dtype=torch.float32) for t, (_, i) in zip(targets, indices)], - dim=0, - ).to(source_boxes.device) - - loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none") - - losses = {} - losses["loss_bbox"] = loss_bbox.sum() / num_boxes - - loss_giou = 1 - torch.diag( - generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)) - ) - losses["loss_giou"] = loss_giou.sum() / num_boxes - return losses - - # Copied from loss.loss_for_object_detection.ImageLoss.loss_masks - def loss_masks(self, outputs, targets, indices, num_boxes): - """ - Compute the losses related to the masks: the focal loss and the dice loss. - - Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]. - """ - if "pred_masks" not in outputs: - raise KeyError("No predicted masks found in outputs") - - source_idx = self._get_source_permutation_idx(indices) - target_idx = self._get_target_permutation_idx(indices) - source_masks = outputs["pred_masks"] - source_masks = source_masks[source_idx] - masks = [t["masks"] for t in targets] - # TODO use valid to mask invalid areas due to padding in loss - target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() - target_masks = target_masks.to(source_masks) - target_masks = target_masks[target_idx] - - # upsample predictions to the target size - source_masks = nn.functional.interpolate( - source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False - ) - source_masks = source_masks[:, 0].flatten(1) - - target_masks = target_masks.flatten(1) - target_masks = target_masks.view(source_masks.shape) - losses = { - "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes), - "loss_dice": dice_loss(source_masks, target_masks, num_boxes), - } - return losses - - # Copied from loss.loss_for_object_detection.ImageLoss._get_source_permutation_idx - def _get_source_permutation_idx(self, indices): - # permute predictions following indices - batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)]) - source_idx = torch.cat([source for (source, _) in indices]) - return batch_idx, source_idx - - # Copied from loss.loss_for_object_detection.ImageLoss._get_target_permutation_idx - def _get_target_permutation_idx(self, indices): - # permute targets following indices - batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)]) - target_idx = torch.cat([target for (_, target) in indices]) - return batch_idx, target_idx - - def get_loss(self, loss, outputs, targets, indices, num_boxes): - loss_map = { - "labels": self.loss_labels, - "cardinality": self.loss_cardinality, - "boxes": self.loss_boxes, - "masks": self.loss_masks, - } - if loss not in loss_map: - raise ValueError(f"Loss {loss} not supported") - return loss_map[loss](outputs, targets, indices, num_boxes) - - def forward(self, outputs, targets): - """ - This performs the loss computation. - - Args: - outputs (`dict`, *optional*): - Dictionary of tensors, see the output specification of the model for the format. - targets (`list[dict]`, *optional*): - List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the - losses applied, see each loss' doc. - """ - group_detr = self.group_detr if self.training else 1 - outputs_without_aux_and_enc = { - k: v for k, v in outputs.items() if k != "enc_outputs" and k != "auxiliary_outputs" - } - - # Retrieve the matching between the outputs of the last layer and the targets - indices = self.matcher(outputs_without_aux_and_enc, targets, group_detr) - - # Compute the average number of target boxes across all nodes, for normalization purposes - num_boxes = sum(len(t["class_labels"]) for t in targets) - num_boxes = num_boxes * group_detr - num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - world_size = 1 - num_boxes = torch.clamp(num_boxes / world_size, min=1).item() - - # Compute all the requested losses - losses = {} - for loss in self.losses: - losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) - - # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. - if "auxiliary_outputs" in outputs: - for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): - indices = self.matcher(auxiliary_outputs, targets, group_detr) - for loss in self.losses: - if loss == "masks": - # Intermediate masks losses are too costly to compute, we ignore them. - continue - l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes) - l_dict = {k + f"_{i}": v for k, v in l_dict.items()} - losses.update(l_dict) - - if "enc_outputs" in outputs: - enc_outputs = outputs["enc_outputs"] - indices = self.matcher(enc_outputs, targets, group_detr=group_detr) - for loss in self.losses: - l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes) - l_dict = {k + "_enc": v for k, v in l_dict.items()} - losses.update(l_dict) - - return losses - - -def lw_detr_for_object_detection_loss( - logits, - labels, - device, - pred_boxes, - outputs_class=None, - outputs_coord=None, - enc_outputs_class=None, - enc_outputs_coord=None, - use_aux_loss=False, - group_detr=1, - num_labels=None, - num_decoder_layers=None, - **kwargs, -): - """Loss computation for LW-DETR for object detection.""" - # First: create the matcher - matcher = LwDetrHungarianMatcher(class_cost=2.0, bbox_cost=5, giou_cost=2) - # Second: create the criterion - losses = ["labels", "boxes", "cardinality"] - criterion = LwDetrImageLoss( - matcher=matcher, - num_classes=num_labels, - focal_alpha=0.1, - losses=losses, - group_detr=group_detr, - ) - criterion.to(device) - # Third: compute the losses, based on outputs and labels - outputs_loss = {} - auxiliary_outputs = None - outputs_loss["logits"] = logits - outputs_loss["pred_boxes"] = pred_boxes - outputs_loss["enc_outputs"] = { - "logits": enc_outputs_class, - "pred_boxes": enc_outputs_coord, - } - if use_aux_loss: - auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord) - outputs_loss["auxiliary_outputs"] = auxiliary_outputs - loss_dict = criterion(outputs_loss, labels) - # Fourth: compute total loss, as a weighted sum of the various losses - weight_dict = {"loss_ce": 1, "loss_bbox": 5} - weight_dict["loss_giou"] = 2 - if use_aux_loss: - aux_weight_dict = {} - for i in range(num_decoder_layers - 1): - aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) - weight_dict.update(aux_weight_dict) - enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()} - weight_dict.update(enc_weight_dict) - loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict) - return loss, loss_dict, auxiliary_outputs diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 2ca0ecee4a..c8d378d27a 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -10,6 +10,7 @@ import numpy as np import torch +from scipy.optimize import linear_sum_assignment from torch import nn from torch.nn import functional as F @@ -18,7 +19,6 @@ from ...utils import load_pretrained_params from .base import _LWDETR, LWDETRPostProcessor from .layers import LWDETRDecoder, LWDETRHead, LWDETRMultiscaleDeformableAttention, MultiScaleProjector -from .loss import lw_detr_for_object_detection_loss __all__ = ["LWDETR", "lw_detr_s", "lw_detr_m"] @@ -154,10 +154,10 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.0, - iou_thresh: float = 0.1, + score_thresh: float = 0.05, + iou_thresh: float = 0.05, d_model: int = 256, - num_queries: int = 130, + num_queries: int = 50, group_detr: int = 1, dec_layers: int = 3, sa_num_heads: int = 8, @@ -172,7 +172,7 @@ def __init__( super().__init__() self.class_names: list[str] = class_names - self.num_classes = len(self.class_names) + 1 # +1 for background class (NO OBJECT) + self.num_classes = len(self.class_names) + 1 # +1 for background class self.cfg = cfg self.exportable = exportable self.assume_straight_pages = assume_straight_pages @@ -182,13 +182,22 @@ def __init__( self.group_detr = group_detr self.num_queries = num_queries self.d_model = d_model - self.dec_layers = dec_layers - self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4) + self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) + # Initialize angle to (sin=0, cos=1) + with torch.no_grad(): + self.reference_point_embed.weight[:, 0:2].uniform_(0.05, 0.95) + self.reference_point_embed.weight[:, 2:4].fill_(0.1) + self.reference_point_embed.weight[:, 4].zero_() + self.reference_point_embed.weight[:, 5].fill_(1.0) + self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) + self.class_embed = nn.Linear(self.d_model, self.num_classes) + self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) + self.decoder = LWDETRDecoder( - num_layers=self.dec_layers, + num_layers=dec_layers, d_model=d_model, sa_num_heads=sa_num_heads, ca_num_heads=ca_num_heads, @@ -196,21 +205,19 @@ def __init__( dec_n_points=dec_n_points, group_detr=group_detr, dropout_prob=dropout_prob, + bbox_embed=self.bbox_embed, ) self.enc_output = nn.ModuleList([nn.Linear(self.d_model, self.d_model) for _ in range(self.group_detr)]) self.enc_output_norm = nn.ModuleList([nn.LayerNorm(self.d_model) for _ in range(self.group_detr)]) self.enc_out_bbox_embed = nn.ModuleList([ - LWDETRHead(self.d_model, self.d_model, 4, num_layers=3) for _ in range(self.group_detr) + LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) for _ in range(self.group_detr) ]) self.enc_out_class_embed = nn.ModuleList([ nn.Linear(self.d_model, self.num_classes) for _ in range(self.group_detr) ]) - self.class_embed = nn.Linear(self.d_model, self.num_classes) - self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 4, num_layers=3) - self.postprocessor = LWDETRPostProcessor( num_classes=self.num_classes, score_thresh=score_thresh, @@ -222,9 +229,28 @@ def __init__( # Don't override the initialization of the backbone if n.startswith("feat_extractor."): continue - if isinstance(m, LWDETRMultiscaleDeformableAttention): + + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): + if hasattr(m, "weight") and m.weight is not None: + nn.init.ones_(m.weight) + if hasattr(m, "bias") and m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + # Don't overwrite the carefully seeded reference point embedding + if m is not self.reference_point_embed: + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, LWDETRMultiscaleDeformableAttention): nn.init.constant_(m.sampling_offsets.weight, 0.0) - thetas = torch.arange(m.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / m.n_heads) + + thetas = torch.arange(m.n_heads, dtype=torch.float32) * (2.0 * math.pi / m.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) @@ -235,21 +261,29 @@ def __init__( grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): m.sampling_offsets.bias.copy_(grid_init.view(-1)) + nn.init.constant_(m.attention_weights.weight, 0.0) nn.init.constant_(m.attention_weights.bias, 0.0) nn.init.xavier_uniform_(m.value_proj.weight) - nn.init.constant_(m.value_proj.bias, 0.0) + nn.init.zeros_(m.value_proj.bias) nn.init.xavier_uniform_(m.output_proj.weight) - nn.init.constant_(m.output_proj.bias, 0.0) - if hasattr(m, "refpoint_embed") and m.refpoint_embed is not None: - nn.init.constant_(m.refpoint_embed.weight, 0) - if hasattr(m, "class_embed") and m.class_embed is not None: - prior_prob = 0.01 - bias_value = -math.log((1 - prior_prob) / prior_prob) - nn.init.constant_(m.class_embed.bias, bias_value) - if hasattr(m, "bbox_embed") and m.bbox_embed is not None: - nn.init.constant_(m.bbox_embed.layers[-1].weight, 0) - nn.init.constant_(m.bbox_embed.layers[-1].bias, 0) + nn.init.zeros_(m.output_proj.bias) + if isinstance(m, nn.Linear) and m.out_features == self.num_classes: + if m.bias is not None: + with torch.no_grad(): + # Focal-loss prior: foreground starts with low confidence (~0.01), + # preventing background from dominating gradients at the start of training. + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + nn.init.constant_(m.bias, 0.0) + m.bias[:-1].fill_(bias_value) + if isinstance(m, LWDETRHead): + last = m.layers[-1] + if isinstance(last, nn.Linear): + nn.init.zeros_(last.weight) + nn.init.zeros_(last.bias) + if last.bias.shape[0] == 6: + nn.init.constant_(last.bias[5], 1.0) def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -260,39 +294,76 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """ load_pretrained_params(self, path_or_url, **kwargs) - def refine_bboxes(self, reference_points, deltas): + def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + """Refine bounding boxes by applying the predicted deltas to the reference points. + The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. + The refined boxes are computed as follows: + + cx' = cx + delta_cx * w + cy' = cy + delta_cy * h + w' = w * exp(delta_w) + h' = h * exp(delta_h) + sinθ' = sinθ * cosΔ + cosθ * sinΔ + cosθ' = cosθ * cosΔ - sinθ * sinΔ + + Args: + reference_points: (N, S, 6) tensor containing the reference points + deltas: (N, S, 6) tensor containing the predicted deltas + + Returns: + refined_boxes: (N, S, 6) tensor containing the refined bounding boxes + """ reference_points = reference_points.to(deltas.device) - new_reference_points_cxcy = deltas[..., :2] * reference_points[..., 2:] + reference_points[..., :2] - new_reference_points_wh = deltas[..., 2:].exp() * reference_points[..., 2:] - new_reference_points = torch.cat((new_reference_points_cxcy, new_reference_points_wh), -1) - return new_reference_points + cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + # size + wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] + # rotation + delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) + sin_delta = delta_rot[..., 0:1] + cos_delta = delta_rot[..., 1:2] + sin_ref = reference_points[..., 4:5] + cos_ref = reference_points[..., 5:6] + # compose rotations + sin_new = sin_ref * cos_delta + cos_ref * sin_delta + cos_new = cos_ref * cos_delta - sin_ref * sin_delta + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) + return torch.cat((cxcy, wh, rot), dim=-1) + + def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Get the valid ratio of all feature maps. + + Args: + mask: (N, H, W) binary tensor containing 1 on padded pixels + dtype: the desired data type of the output tensor - def get_valid_ratio(self, mask, dtype=torch.float32): - """Get the valid ratio of all feature maps.""" + Returns: + valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch + """ _, height, width = mask.shape - valid_height = torch.sum(mask[:, :, 0], 1) - valid_width = torch.sum(mask[:, 0, :], 1) + valid_height = torch.sum(~mask[:, :, 0], 1) + valid_width = torch.sum(~mask[:, 0, :], 1) valid_ratio_height = valid_height.to(dtype) / height valid_ratio_width = valid_width.to(dtype) / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) return valid_ratio - def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes): + def gen_encoder_output_proposals( + self, enc_output: torch.Tensor, padding_mask: torch.Tensor, spatial_shapes: list[tuple[int, int]] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Generate the encoder output proposals from encoded enc_output. Args: - enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder. - padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`. - spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps. + enc_output: Output of the encoder + padding_mask: Padding mask for `enc_output` + spatial_shapes: Spatial shapes of the feature maps Returns: - `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction. - - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to - directly predict a bounding box. (without the need of a decoder) - - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals in [0, 1] space. - Invalid positions (padding or out-of-bounds) are filled with 0. - - invalid_mask (Tensor[batch_size, sequence_length, 1]): Boolean mask that is True for invalid positions - (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)). + A tuple of feature map and bbox prediction. + - object_query: Object query features. Later used to directly predict a bounding box. + - output_proposals: Normalized proposals in [0, 1] space. + Invalid positions (padding or out-of-bounds) are filled with 0. + - invalid_mask: Boolean mask that is True for invalid positions + (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)). """ batch_size = enc_output.shape[0] proposals = [] @@ -324,17 +395,23 @@ def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes) scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale width_height = torch.ones_like(grid) * 0.05 * (2.0**level) - proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4) + # add default rotation (sin=0, cos=1) + sin = torch.zeros_like(grid[..., :1]) + cos = torch.ones_like(grid[..., :1]) + proposal = torch.cat((grid, width_height, sin, cos), -1).view(batch_size, -1, 6) proposals.append(proposal) _cur += height * width output_proposals = torch.cat(proposals, 1) - output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + + spatial_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals_valid = spatial_valid invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) # assign each pixel as an object query object_query = enc_output - object_query = object_query.masked_fill(invalid_mask, float(0)) + object_query = object_query.masked_fill(invalid_mask, 0.0) + return object_query, output_proposals, invalid_mask def forward( @@ -379,7 +456,6 @@ def forward( mask_flatten_list.append(mask) source_flatten = torch.cat(source_flatten_list, 1) mask_flatten = torch.cat(mask_flatten_list, 1) - spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in feats_masks], 1) tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) @@ -392,73 +468,65 @@ def forward( group_detr = self.group_detr if self.training else 1 topk = self.num_queries - topk_coords_logits = [] - topk_coords_logits_undetach = [] - object_query_undetach = [] + topk_coords_logits_list: list[torch.Tensor] = [] + topk_content_list: list[torch.Tensor] = [] + + # encoder predictions for auxiliary losses + all_group_enc_logits: list[torch.Tensor] = [] + all_group_enc_coords: list[torch.Tensor] = [] for group_id in range(group_detr): group_object_query = self.enc_output[group_id](object_query_embedding) group_object_query = self.enc_output_norm[group_id](group_object_query) group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) - group_enc_outputs_class = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) + all_group_enc_logits.append(group_enc_outputs_class) + + group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) + group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) - group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1] + all_group_enc_coords.append(group_enc_outputs_coord) + + scores = group_enc_outputs_class_masked[..., :-1].max(-1).values + group_topk_proposals = torch.topk(scores, topk, dim=1)[1] + group_topk_coords_logits_undetach = torch.gather( group_enc_outputs_coord, 1, - group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4), + group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) group_topk_coords_logits = group_topk_coords_logits_undetach.detach() - group_object_query_undetach = torch.gather( - group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) + topk_coords_logits_list.append(group_topk_coords_logits) + group_topk_content = torch.gather( + group_object_query, + 1, + group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model), ) + topk_content_list.append(group_topk_content) - topk_coords_logits.append(group_topk_coords_logits) - topk_coords_logits_undetach.append(group_topk_coords_logits_undetach) - object_query_undetach.append(group_object_query_undetach) - - topk_coords_logits = torch.cat(topk_coords_logits, 1) - topk_coords_logits_undetach = torch.cat(topk_coords_logits_undetach, 1) - object_query_undetach = torch.cat(object_query_undetach, 1) + topk_coords_logits = torch.cat(topk_coords_logits_list, 1) + reference_points = topk_coords_logits - enc_outputs_class_logits = object_query_undetach - enc_outputs_boxes_logits = topk_coords_logits_undetach + topk_content = torch.cat(topk_content_list, 1).detach() + tgt = tgt + topk_content - reference_points = self.refine_bboxes(topk_coords_logits, reference_points) + encoder_attention_mask = mask_flatten - init_reference_points = reference_points - last_hidden_state, intermediate, intermediate_reference_points = self.decoder( + last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, reference_points=reference_points, - spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, - encoder_attention_mask=mask_flatten, + encoder_attention_mask=encoder_attention_mask, ) - logits = self.class_embed(last_hidden_state) - pred_boxes_delta = self.bbox_embed(last_hidden_state) + logits = self.class_embed(last_hidden_states) + pred_boxes_delta = self.bbox_embed(last_hidden_states) pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) - enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.num_queries, dim=1) - pred_class = [] - group_detr = self.group_detr if self.training else 1 - for group_index in range(group_detr): - group_pred_class = self.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index]) - pred_class.append(group_pred_class) - enc_outputs_class_logits = torch.cat(pred_class, dim=1) - - if target is not None: - outputs_class, outputs_coord = None, None - intermediate_hidden_states = intermediate - outputs_coord_delta = self.bbox_embed(intermediate_hidden_states) - outputs_coord = self.refine_bboxes(intermediate_reference_points, outputs_coord_delta) - outputs_class = self.class_embed(intermediate_hidden_states) - out: dict[str, Any] = {} if self.exportable: @@ -480,44 +548,200 @@ def _postprocess(logits, boxes): if target is not None: # Build target processed_targets = self.build_target(target, self.class_names) - out["loss"] = self.compute_loss( - logits, - processed_targets, - pred_boxes, - outputs_class, - outputs_coord, - enc_outputs_class_logits, - enc_outputs_boxes_logits, - ) + + # Main loss from final decoder layer (group DETR) + split_logits = logits.chunk(group_detr, dim=1) + split_boxes = pred_boxes.chunk(group_detr, dim=1) + + main_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_logits, split_boxes): + main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) + loss = main_loss / group_detr + + # Auxiliary losses from intermediate decoder layers (group DETR) + for i in range(intermediate.shape[0] - 1): + aux_logits = self.class_embed(intermediate[i]) + aux_boxes_delta = self.bbox_embed(intermediate[i]) + aux_boxes = self.refine_bboxes(intermediate_reference_points[i + 1], aux_boxes_delta) + + split_aux_logits = aux_logits.chunk(group_detr, dim=1) + split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) + + aux_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): + aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) + loss += aux_loss + + # Auxiliary losses for encoder proposals + enc_loss: float | torch.Tensor = 0.0 + for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): + enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) + loss += enc_loss + + out["loss"] = loss return out def compute_loss( self, - logits, - targets, - pred_boxes, - outputs_class, - outputs_coord, - enc_outputs_class_logits, - enc_outputs_boxes_logits, - ): - - loss_calc = lw_detr_for_object_detection_loss( - logits=logits, - device=logits.device, - labels=targets, - pred_boxes=pred_boxes, - outputs_class=outputs_class, - outputs_coord=outputs_coord, - enc_outputs_class=enc_outputs_class_logits, - enc_outputs_coord=enc_outputs_boxes_logits, - use_aux_loss=True, - group_detr=self.group_detr, - num_decoder_layers=self.dec_layers, - num_labels=self.num_classes, - ) - return loss_calc[0] + logits: torch.Tensor, + pred_boxes: torch.Tensor, + targets: list[dict[str, np.ndarray]], + ) -> torch.Tensor: + + def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format + to Gaussian distributions (mean and covariance). + The mean is simply (cx, cy), and the covariance is computed from the width, height, and rotation angle. + + Args: + boxes: (N, S, 6) tensor containing the rotated boxes in (cx, cy, w, h, sinθ, cosθ) format + Returns: + A tuple of (mean, covariance) where: + - mean is a (N, S, 2) tensor containing the mean (cx, cy) of the Gaussian distributions + - covariance is a (N, S, 2, 2) tensor containing the covariance matrices of the Gaussian distributions + """ + cxcy = boxes[..., :2] + + w = boxes[..., 2].clamp(min=1e-6) + h = boxes[..., 3].clamp(min=1e-6) + + sin = boxes[..., 4] + cos = boxes[..., 5] + + R = torch.stack( + [ + torch.stack([cos, -sin], dim=-1), + torch.stack([sin, cos], dim=-1), + ], + dim=-2, + ) + + # Variance for a box half-width/half-height: σ² = (w/2)² + # Using w²/12 (uniform distribution) produces ~8x smaller variance, + # which collapses Bhattacharyya distance to the clamp ceiling and kills gradients. + sx = (w / 2) ** 2 + sy = (h / 2) ** 2 + + S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) + S[..., 0, 0] = sx + S[..., 1, 1] = sy + + covariance = R @ S @ R.transpose(-1, -2) + return cxcy, covariance + + def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: + mu1, sigma1 = _rotated_boxes_to_gaussian(pred_boxes) + mu2, sigma2 = _rotated_boxes_to_gaussian(tgt_boxes) + + delta = (mu1 - mu2).unsqueeze(-1) + sigma = (sigma1 + sigma2) * 0.5 + + eps = 1e-6 + eye = torch.eye(2, device=sigma.device) * eps + + sigma_safe = sigma + eye + sigma1_safe = sigma1 + eye + sigma2_safe = sigma2 + eye + + sigma_inv = torch.linalg.inv(sigma_safe) + + mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) + + det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps) + det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps) + det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps) + + bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) + + bhattacharyya = torch.clamp(bhattacharyya, min=0.0, max=10.0) + probiou = torch.exp(-bhattacharyya) + return 1 - probiou + + device = logits.device + B, Q, C = logits.shape + + total_cls = torch.tensor(0.0, device=device) + total_box = torch.tensor(0.0, device=device) + + # FIX (issue #7): track total matched boxes across the batch for proper normalisation. + # Classification loss is still normalised by B (it covers all Q queries per image), + # but box/rotation losses are normalised by the actual number of matched pairs. + num_matched_total = 0 + + for b in range(B): + pred_logits = logits[b] + pred_boxes_b = pred_boxes[b] + + boxes = targets[b]["boxes"] + + if len(boxes) == 0: + # Penalize the model for any foreground boxes it guessed on this empty image + background_idx = self.num_classes - 1 + target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) + total_cls += F.cross_entropy(pred_logits, target_classes) + continue + + tgt_boxes = torch.as_tensor(boxes, device=device, dtype=pred_boxes.dtype) + tgt_cls = torch.as_tensor(targets[b]["labels"], device=device, dtype=torch.long) + + if tgt_boxes.ndim == 1: + tgt_boxes = tgt_boxes.unsqueeze(0) + + pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) + tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) + + with torch.no_grad(): + out_logprob = pred_logits.log_softmax(-1) + + cost_cls = -out_logprob[:, tgt_cls] # stable + cost_l1 = torch.cdist(pred_boxes_b[:, :4], tgt_boxes[:, :4], p=1) + cost_rot = 1.0 - torch.abs(pred_rot @ tgt_rot.T) + + total_cost = 2.0 * cost_cls + 5.0 * cost_l1 + 2.0 * cost_rot + + cost_np = total_cost.detach().cpu().numpy() + row_ind, col_ind = linear_sum_assignment(cost_np) + + pos_idx = torch.as_tensor(row_ind, device=device) + gt_idx = torch.as_tensor(col_ind, device=device) + + background_idx = self.num_classes - 1 + + target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) + target_classes[pos_idx] = tgt_cls[gt_idx] + + cls_weights = torch.ones(self.num_classes, device=device) + cls_weights[background_idx] = 1.0 + + total_cls += F.cross_entropy(pred_logits, target_classes, weight=cls_weights) + + if pos_idx.numel() == 0: + continue + + num_matched_total += pos_idx.numel() + + pred_sel = pred_boxes_b[pos_idx] + tgt_sel = tgt_boxes[gt_idx] + + # Smooth L1 (Huber) loss with beta=0.1: behaves like L2 for large errors early in + # training (gentle, stable gradient) and like L1 for small errors later (sharp + # localisation). Raw L1 with a large weight causes loss explosion when predictions + # are far from targets (e.g. 5 * 2.2 = 11.0 per pair vs cls ~2.0). + l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4], reduction="sum", beta=0.1) + probiou_loss = _probiou_loss(pred_sel, tgt_sel).sum() + total_box += 5.0 * l1_loss + 2.0 * probiou_loss + + # FIX (issue #7): normalise box loss by total matched boxes (min 1), not by batch size. + # This prevents images with many GT boxes from dominating over images with few. + # Normalise box loss by total matched boxes, with a floor of 1 to keep + # the box/cls loss ratio stable even when images have very few GT boxes. + num_matched_total = max(num_matched_total, 1) + + loss_cls = total_cls / B + loss_box = total_box / num_matched_total + + return loss_cls + loss_box def _lw_detr( From e2491aa667a89b62161d96e4e089ab1355f22f2d Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 28 May 2026 15:24:41 +0200 Subject: [PATCH 03/15] update --- doctr/models/layout/lw_detr/layers/pytorch.py | 5 +- doctr/models/layout/lw_detr/pytorch.py | 67 ++++++++++--------- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index bb6b503e97..e9f319e664 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -605,10 +605,7 @@ def forward( if self.bbox_embed is not None: delta = self.bbox_embed(hidden_states_norm) - reference_points = self.refine_boxes( - reference_points.squeeze(2), - delta, - ) + reference_points = self.refine_boxes(reference_points, delta) intermediate_reference_points.append(reference_points) reference_points_inputs, query_pos = self.get_reference( diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index c8d378d27a..e813becd0e 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -154,11 +154,11 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.05, - iou_thresh: float = 0.05, + score_thresh: float = 0.3, + iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 50, - group_detr: int = 1, + num_queries: int = 300, + group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, ca_num_heads: int = 16, @@ -403,7 +403,7 @@ def gen_encoder_output_proposals( _cur += height * width output_proposals = torch.cat(proposals, 1) - spatial_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) output_proposals_valid = spatial_valid invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) @@ -507,12 +507,8 @@ def forward( topk_content_list.append(group_topk_content) topk_coords_logits = torch.cat(topk_coords_logits_list, 1) - reference_points = topk_coords_logits - topk_content = torch.cat(topk_content_list, 1).detach() - tgt = tgt + topk_content - - encoder_attention_mask = mask_flatten + reference_points = self.refine_bboxes(topk_coords_logits, reference_points) last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, @@ -520,12 +516,11 @@ def forward( spatial_shapes_list=spatial_shapes_list, valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, - encoder_attention_mask=encoder_attention_mask, + encoder_attention_mask=mask_flatten, ) logits = self.class_embed(last_hidden_states) - pred_boxes_delta = self.bbox_embed(last_hidden_states) - pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + pred_boxes = intermediate_reference_points[-1] out: dict[str, Any] = {} @@ -561,8 +556,7 @@ def _postprocess(logits, boxes): # Auxiliary losses from intermediate decoder layers (group DETR) for i in range(intermediate.shape[0] - 1): aux_logits = self.class_embed(intermediate[i]) - aux_boxes_delta = self.bbox_embed(intermediate[i]) - aux_boxes = self.refine_bboxes(intermediate_reference_points[i + 1], aux_boxes_delta) + aux_boxes = intermediate_reference_points[i + 1] split_aux_logits = aux_logits.chunk(group_detr, dim=1) split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) @@ -570,13 +564,13 @@ def _postprocess(logits, boxes): aux_loss: float | torch.Tensor = 0.0 for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss += aux_loss + loss += aux_loss / group_detr # Auxiliary losses for encoder proposals enc_loss: float | torch.Tensor = 0.0 for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) - loss += enc_loss + loss += enc_loss / group_detr out["loss"] = loss @@ -588,6 +582,18 @@ def compute_loss( pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]], ) -> torch.Tensor: + """Compute the loss between predicted logits and boxes and target labels and boxes. + + Args: + logits: (N, S, C) tensor containing the predicted class logits for each query + pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format + targets: list of length N, where each element is a dict with keys "labels" and "boxes", + containing the ground truth labels and boxes for each image in the batch. + The boxes are in (cx, cy, w, h, sinθ, cosθ) format. + + Returns: + A scalar tensor containing the computed loss. + """ def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format @@ -631,6 +637,17 @@ def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch return cxcy, covariance def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: + """Compute the ProbIoU loss between predicted and target boxes, + where boxes are represented as Gaussian distributions. + The ProbIoU loss is defined as 1 - exp(-Bhattacharyya distance), + where the Bhattacharyya distance is computed between the two Gaussian distributions. + + Args: + pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format + tgt_boxes: (N, S, 6) tensor containing the target boxes in (cx, cy, w, h, sinθ, cosθ) format + Returns: + A (N, S) tensor containing the ProbIoU loss for each pair of predicted and target boxes + """ mu1, sigma1 = _rotated_boxes_to_gaussian(pred_boxes) mu2, sigma2 = _rotated_boxes_to_gaussian(tgt_boxes) @@ -664,9 +681,6 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te total_cls = torch.tensor(0.0, device=device) total_box = torch.tensor(0.0, device=device) - # FIX (issue #7): track total matched boxes across the batch for proper normalisation. - # Classification loss is still normalised by B (it covers all Q queries per image), - # but box/rotation losses are normalised by the actual number of matched pairs. num_matched_total = 0 for b in range(B): @@ -694,7 +708,7 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te with torch.no_grad(): out_logprob = pred_logits.log_softmax(-1) - cost_cls = -out_logprob[:, tgt_cls] # stable + cost_cls = -out_logprob[:, tgt_cls] cost_l1 = torch.cdist(pred_boxes_b[:, :4], tgt_boxes[:, :4], p=1) cost_rot = 1.0 - torch.abs(pred_rot @ tgt_rot.T) @@ -712,7 +726,7 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te target_classes[pos_idx] = tgt_cls[gt_idx] cls_weights = torch.ones(self.num_classes, device=device) - cls_weights[background_idx] = 1.0 + cls_weights[background_idx] = 0.1 total_cls += F.cross_entropy(pred_logits, target_classes, weight=cls_weights) @@ -724,20 +738,11 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te pred_sel = pred_boxes_b[pos_idx] tgt_sel = tgt_boxes[gt_idx] - # Smooth L1 (Huber) loss with beta=0.1: behaves like L2 for large errors early in - # training (gentle, stable gradient) and like L1 for small errors later (sharp - # localisation). Raw L1 with a large weight causes loss explosion when predictions - # are far from targets (e.g. 5 * 2.2 = 11.0 per pair vs cls ~2.0). l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4], reduction="sum", beta=0.1) probiou_loss = _probiou_loss(pred_sel, tgt_sel).sum() total_box += 5.0 * l1_loss + 2.0 * probiou_loss - # FIX (issue #7): normalise box loss by total matched boxes (min 1), not by batch size. - # This prevents images with many GT boxes from dominating over images with few. - # Normalise box loss by total matched boxes, with a floor of 1 to keep - # the box/cls loss ratio stable even when images have very few GT boxes. num_matched_total = max(num_matched_total, 1) - loss_cls = total_cls / B loss_box = total_box / num_matched_total From 4774d9fede460bcbfce0f9307cfb3cfe3b2df756 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 28 May 2026 15:28:14 +0200 Subject: [PATCH 04/15] update --- doctr/models/layout/lw_detr/base.py | 11 ----------- doctr/models/layout/lw_detr/pytorch.py | 1 - 2 files changed, 12 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 1a7cb452cf..8c461c9fc2 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -301,17 +301,6 @@ def to_quad(box: np.ndarray): if boxes.ndim == 1: boxes = boxes[None, :] - # Sanity check: coordinates must be in [0, 1] normalized space. - # Values > 1.5 almost certainly indicate pixel coordinates were passed in. - flat = boxes.ravel() - coord_vals = flat[flat > 0] - if len(coord_vals) > 0 and coord_vals.max() > 1.5: - raise ValueError( - f"build_target expects normalized [0, 1] box coordinates, " - f"but found values up to {coord_vals.max():.1f} for class '{class_name}'. " - f"Divide your coordinates by image width/height before calling build_target." - ) - for box in boxes: poly = to_quad(box) obb = _quad_to_obb(poly) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index e813becd0e..d760504378 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -244,7 +244,6 @@ def __init__( if hasattr(m, "bias") and m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): - # Don't overwrite the carefully seeded reference point embedding if m is not self.reference_point_embed: nn.init.normal_(m.weight, std=0.02) elif isinstance(m, LWDETRMultiscaleDeformableAttention): From 1fa67955691b28f49793dde2a430a6b5a0ba356c Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 09:21:27 +0200 Subject: [PATCH 05/15] loss and model fixes --- doctr/models/layout/lw_detr/base.py | 19 +- doctr/models/layout/lw_detr/layers/pytorch.py | 42 +- doctr/models/layout/lw_detr/pytorch.py | 379 +++++++++--------- 3 files changed, 218 insertions(+), 222 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 8c461c9fc2..9c6890e6b0 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -144,18 +144,11 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int for b in range(boxes.shape[0]): # Convert logits to probabilities and get scores and labels - exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) - prob = exp / exp.sum(axis=-1, keepdims=True) + prob = 1.0 / (1.0 + np.exp(-logits[b])) scores = prob.max(axis=-1) labels = prob.argmax(axis=-1) - # treat background as invalid prediction - bg = self.num_classes - 1 - valid = labels != bg - - scores = scores * valid - # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: idxs = np.argpartition(-scores, self.topk)[: self.topk] @@ -167,11 +160,11 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int labels_b = labels[idxs] bboxes = boxes[b][idxs] - mask = scores_b > self.score_thresh - - bboxes = bboxes[mask] - scores_b = scores_b[mask] - labels_b = labels_b[mask] + # Filter by score threshold + thresh_mask = scores_b >= self.score_thresh + scores_b = scores_b[thresh_mask] + labels_b = labels_b[thresh_mask] + bboxes = bboxes[thresh_mask] polys, _ = ( self._decode_boxes(bboxes) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index e9f319e664..f0bc63c865 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -293,6 +293,8 @@ def forward( else: raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}") + # clamp sampling locations to keep them within the valid range [0, 1] for grid sampling + sampling_locations = sampling_locations.clamp(0.0, 1.0) output = self.attn( value, spatial_shapes_list, @@ -482,11 +484,7 @@ def __init__( self.bbox_embed = bbox_embed self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) - self.angle_proj = nn.Sequential( - nn.Linear(4, self.d_model), - nn.ReLU(), - nn.Linear(self.d_model, self.d_model), - ) + self.angle_proj = nn.Linear(2, self.d_model) def get_reference( self, reference_points: torch.Tensor, valid_ratios: torch.Tensor @@ -517,26 +515,13 @@ def get_reference( # DETR positional encoding query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) base_query_pos = self.ref_point_head(query_sine_embed) - # Angle embedding - sin_t = angle[..., 0:1] - cos_t = angle[..., 1:2] - - angle_feat = torch.cat( - [ - sin_t, - cos_t, - 2 * sin_t * cos_t, - cos_t**2 - sin_t**2, - ], - dim=-1, - ) - angle_emb = self.angle_proj(angle_feat) + angle_emb = self.angle_proj(angle) # Combine query_pos = base_query_pos + angle_emb return reference_points_inputs, query_pos - def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: """Refine bounding boxes by applying the predicted deltas to the reference points. The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. The refined boxes are computed as follows: @@ -556,11 +541,20 @@ def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> refined_boxes: (N, S, 6) tensor containing the refined bounding boxes """ reference_points = reference_points.to(deltas.device) - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] # size - wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] + wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=4.0).exp() * reference_points[..., 2:4] + wh = wh.clamp(min=1e-4, max=1.0) + # center + raw_cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + half_wh = wh / 2 + cxcy = raw_cxcy.clamp( + min=half_wh, + max=1.0 - half_wh, + ) # rotation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) + sin_d = deltas[..., 4:5] + cos_d = deltas[..., 5:6] + 1.0 + delta_rot = F.normalize(torch.cat([sin_d, cos_d], dim=-1), dim=-1, eps=1e-6) sin_delta = delta_rot[..., 0:1] cos_delta = delta_rot[..., 1:2] sin_ref = reference_points[..., 4:5] @@ -605,7 +599,7 @@ def forward( if self.bbox_embed is not None: delta = self.bbox_embed(hidden_states_norm) - reference_points = self.refine_boxes(reference_points, delta) + reference_points = self.refine_bboxes(reference_points, delta) intermediate_reference_points.append(reference_points) reference_points_inputs, query_pos = self.get_reference( diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index d760504378..57331c8c85 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -154,17 +154,17 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.3, + score_thresh: float = 0.5, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 300, + num_queries: int = 100, group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, ca_num_heads: int = 16, ff_dim: int = 2048, dec_n_points: int = 2, - dropout_prob: float = 0.0, + dropout_prob: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: dict[str, Any] | None = None, @@ -186,10 +186,11 @@ def __init__( self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) # Initialize angle to (sin=0, cos=1) with torch.no_grad(): - self.reference_point_embed.weight[:, 0:2].uniform_(0.05, 0.95) - self.reference_point_embed.weight[:, 2:4].fill_(0.1) - self.reference_point_embed.weight[:, 4].zero_() - self.reference_point_embed.weight[:, 5].fill_(1.0) + self.reference_point_embed.weight[:, 0:2].uniform_(0.05, 0.95) # cx, cy + self.reference_point_embed.weight[:, 2].uniform_(0.1, 0.6) # w + self.reference_point_embed.weight[:, 3].uniform_(0.02, 0.3) # h + self.reference_point_embed.weight[:, 4].zero_() # sinθ + self.reference_point_embed.weight[:, 5].fill_(1.0) # cosθ self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) @@ -281,8 +282,6 @@ def __init__( if isinstance(last, nn.Linear): nn.init.zeros_(last.weight) nn.init.zeros_(last.bias) - if last.bias.shape[0] == 6: - nn.init.constant_(last.bias[5], 1.0) def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -293,41 +292,6 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """ load_pretrained_params(self, path_or_url, **kwargs) - def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - """Refine bounding boxes by applying the predicted deltas to the reference points. - The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. - The refined boxes are computed as follows: - - cx' = cx + delta_cx * w - cy' = cy + delta_cy * h - w' = w * exp(delta_w) - h' = h * exp(delta_h) - sinθ' = sinθ * cosΔ + cosθ * sinΔ - cosθ' = cosθ * cosΔ - sinθ * sinΔ - - Args: - reference_points: (N, S, 6) tensor containing the reference points - deltas: (N, S, 6) tensor containing the predicted deltas - - Returns: - refined_boxes: (N, S, 6) tensor containing the refined bounding boxes - """ - reference_points = reference_points.to(deltas.device) - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - # size - wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] - # rotation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - # compose rotations - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - return torch.cat((cxcy, wh, rot), dim=-1) - def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Get the valid ratio of all feature maps. @@ -402,7 +366,7 @@ def gen_encoder_output_proposals( _cur += height * width output_proposals = torch.cat(proposals, 1) - spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) + spatial_valid = ((output_proposals[..., :2] > 0.01) & (output_proposals[..., :2] < 0.99)).all(-1, keepdim=True) output_proposals_valid = spatial_valid invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) @@ -484,7 +448,7 @@ def forward( group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) - group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) + group_enc_outputs_coord = self.decoder.refine_bboxes(output_proposals, group_delta_bbox) all_group_enc_coords.append(group_enc_outputs_coord) @@ -507,7 +471,7 @@ def forward( topk_coords_logits = torch.cat(topk_coords_logits_list, 1) - reference_points = self.refine_bboxes(topk_coords_logits, reference_points) + reference_points = self.decoder.refine_bboxes(topk_coords_logits, reference_points) last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, @@ -543,33 +507,19 @@ def _postprocess(logits, boxes): # Build target processed_targets = self.build_target(target, self.class_names) - # Main loss from final decoder layer (group DETR) - split_logits = logits.chunk(group_detr, dim=1) - split_boxes = pred_boxes.chunk(group_detr, dim=1) - - main_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_logits, split_boxes): - main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss = main_loss / group_detr + # Main loss from final decoder layer + loss = self.compute_loss(logits, pred_boxes, processed_targets) - # Auxiliary losses from intermediate decoder layers (group DETR) + # Auxiliary losses from intermediate decoder layers for i in range(intermediate.shape[0] - 1): aux_logits = self.class_embed(intermediate[i]) aux_boxes = intermediate_reference_points[i + 1] - - split_aux_logits = aux_logits.chunk(group_detr, dim=1) - split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) - - aux_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): - aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss += aux_loss / group_detr + loss += self.compute_loss(aux_logits, aux_boxes, processed_targets) # Auxiliary losses for encoder proposals - enc_loss: float | torch.Tensor = 0.0 - for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): - enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) - loss += enc_loss / group_detr + enc_logits = torch.cat(all_group_enc_logits, dim=1) + enc_coords = torch.cat(all_group_enc_coords, dim=1) + loss += 0.2 * self.compute_loss(enc_logits, enc_coords, processed_targets) out["loss"] = loss @@ -581,36 +531,41 @@ def compute_loss( pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]], ) -> torch.Tensor: - """Compute the loss between predicted logits and boxes and target labels and boxes. + """Compute the loss using Grouped Hungarian Matching + and consistent ProbIoU semantics for rotated bounding boxes. Args: - logits: (N, S, C) tensor containing the predicted class logits for each query - pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format - targets: list of length N, where each element is a dict with keys "labels" and "boxes", + logits: (B, Q, C) tensor containing the predicted class logits for each query + pred_boxes: (B, Q, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format + targets: list of length B, where each element is a dict with keys "labels" and "boxes", containing the ground truth labels and boxes for each image in the batch. - The boxes are in (cx, cy, w, h, sinθ, cosθ) format. Returns: A scalar tensor containing the computed loss. """ + device = logits.device + dtype = logits.dtype + B, Q, C = logits.shape + + # Consistent coefficients across matcher and loss components + class_weight = 2.0 + bbox_weight = 5.0 + probiou_weight = 2.0 + rot_weight = 0.5 + + # Focal Loss Params + alpha = 0.25 + gamma = 2.0 + eps = 1e-7 + + group_detr = getattr(self, "group_detr", 1) def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format - to Gaussian distributions (mean and covariance). - The mean is simply (cx, cy), and the covariance is computed from the width, height, and rotation angle. - - Args: - boxes: (N, S, 6) tensor containing the rotated boxes in (cx, cy, w, h, sinθ, cosθ) format - Returns: - A tuple of (mean, covariance) where: - - mean is a (N, S, 2) tensor containing the mean (cx, cy) of the Gaussian distributions - - covariance is a (N, S, 2, 2) tensor containing the covariance matrices of the Gaussian distributions - """ + """Convert rotated boxes to Gaussian distributions using the true + variance of a uniform continuous rectangle (w^2 / 12).""" cxcy = boxes[..., :2] - w = boxes[..., 2].clamp(min=1e-6) h = boxes[..., 3].clamp(min=1e-6) - sin = boxes[..., 4] cos = boxes[..., 5] @@ -622,130 +577,184 @@ def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch dim=-2, ) - # Variance for a box half-width/half-height: σ² = (w/2)² - # Using w²/12 (uniform distribution) produces ~8x smaller variance, - # which collapses Bhattacharyya distance to the clamp ceiling and kills gradients. - sx = (w / 2) ** 2 - sy = (h / 2) ** 2 + sx = (w**2) / 12.0 + sy = (h**2) / 12.0 - S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) + S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device, dtype=boxes.dtype) S[..., 0, 0] = sx S[..., 1, 1] = sy covariance = R @ S @ R.transpose(-1, -2) return cxcy, covariance - def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: - """Compute the ProbIoU loss between predicted and target boxes, - where boxes are represented as Gaussian distributions. - The ProbIoU loss is defined as 1 - exp(-Bhattacharyya distance), - where the Bhattacharyya distance is computed between the two Gaussian distributions. - - Args: - pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format - tgt_boxes: (N, S, 6) tensor containing the target boxes in (cx, cy, w, h, sinθ, cosθ) format - Returns: - A (N, S) tensor containing the ProbIoU loss for each pair of predicted and target boxes - """ - mu1, sigma1 = _rotated_boxes_to_gaussian(pred_boxes) - mu2, sigma2 = _rotated_boxes_to_gaussian(tgt_boxes) - + def _bhattacharyya_distance( + mu1: torch.Tensor, sigma1: torch.Tensor, mu2: torch.Tensor, sigma2: torch.Tensor + ) -> torch.Tensor: + """Compute Bhattacharyya distance with broadcast support.""" delta = (mu1 - mu2).unsqueeze(-1) sigma = (sigma1 + sigma2) * 0.5 - eps = 1e-6 - eye = torch.eye(2, device=sigma.device) * eps - + eye = torch.eye(2, device=sigma.device, dtype=sigma.dtype) * 1e-6 sigma_safe = sigma + eye sigma1_safe = sigma1 + eye sigma2_safe = sigma2 + eye - sigma_inv = torch.linalg.inv(sigma_safe) + L = torch.linalg.cholesky(sigma_safe) + sigma_inv = torch.cholesky_inverse(L) mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) - det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps) - det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps) - det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps) + det_sigma = torch.linalg.det(sigma_safe).clamp(min=1e-6) + det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=1e-6) + det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=1e-6) bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) + return bhattacharyya.clamp(min=0.0) + + # Prepare targets for matching + target_labels = [] + target_boxes = [] + sizes = [] + for t in targets: + lbls = torch.as_tensor(t["labels"], device=device, dtype=torch.long) + bxs = torch.as_tensor(t["boxes"], device=device, dtype=pred_boxes.dtype) + if bxs.ndim == 1 and bxs.numel() > 0: + bxs = bxs.unsqueeze(0) + target_labels.append(lbls) + target_boxes.append(bxs) + sizes.append(len(lbls)) + + # Unified formulation for empty batches + if sum(sizes) == 0: + prob = logits.sigmoid() + prob_safe = prob.clamp(min=eps, max=1.0 - eps) + neg_weights = prob.pow(gamma) + loss_ce = -neg_weights * (1.0 - prob_safe).log() + return class_weight * (loss_ce.sum() / (B * Q)) + + tgt_ids = torch.cat(target_labels) + tgt_bbox = torch.cat(target_boxes) + + # Matcher: Grouped Hungarian Assignment with a balanced cost matrix + with torch.no_grad(): + out_prob = logits.flatten(0, 1).sigmoid() + out_bbox = pred_boxes.flatten(0, 1) - bhattacharyya = torch.clamp(bhattacharyya, min=0.0, max=10.0) - probiou = torch.exp(-bhattacharyya) - return 1 - probiou - - device = logits.device - B, Q, C = logits.shape - - total_cls = torch.tensor(0.0, device=device) - total_box = torch.tensor(0.0, device=device) - - num_matched_total = 0 - - for b in range(B): - pred_logits = logits[b] - pred_boxes_b = pred_boxes[b] - - boxes = targets[b]["boxes"] - - if len(boxes) == 0: - # Penalize the model for any foreground boxes it guessed on this empty image - background_idx = self.num_classes - 1 - target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) - total_cls += F.cross_entropy(pred_logits, target_classes) - continue - - tgt_boxes = torch.as_tensor(boxes, device=device, dtype=pred_boxes.dtype) - tgt_cls = torch.as_tensor(targets[b]["labels"], device=device, dtype=torch.long) - - if tgt_boxes.ndim == 1: - tgt_boxes = tgt_boxes.unsqueeze(0) - - pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) - tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) - - with torch.no_grad(): - out_logprob = pred_logits.log_softmax(-1) - - cost_cls = -out_logprob[:, tgt_cls] - cost_l1 = torch.cdist(pred_boxes_b[:, :4], tgt_boxes[:, :4], p=1) - cost_rot = 1.0 - torch.abs(pred_rot @ tgt_rot.T) - - total_cost = 2.0 * cost_cls + 5.0 * cost_l1 + 2.0 * cost_rot - - cost_np = total_cost.detach().cpu().numpy() - row_ind, col_ind = linear_sum_assignment(cost_np) - - pos_idx = torch.as_tensor(row_ind, device=device) - gt_idx = torch.as_tensor(col_ind, device=device) - - background_idx = self.num_classes - 1 - - target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) - target_classes[pos_idx] = tgt_cls[gt_idx] - - cls_weights = torch.ones(self.num_classes, device=device) - cls_weights[background_idx] = 0.1 - - total_cls += F.cross_entropy(pred_logits, target_classes, weight=cls_weights) - - if pos_idx.numel() == 0: - continue + # Classification Cost (Focal Loss based) + neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + eps).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + eps).log()) + class_cost = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] - num_matched_total += pos_idx.numel() + # Box L1 Cost + out_bbox_f = out_bbox.to(torch.float32) + tgt_bbox_f = tgt_bbox.to(torch.float32) + bbox_cost = torch.cdist(out_bbox_f[:, :4], tgt_bbox_f[:, :4], p=1).to(dtype) - pred_sel = pred_boxes_b[pos_idx] - tgt_sel = tgt_boxes[gt_idx] + # ProbIoU Cost + mu_pred, sig_pred = _rotated_boxes_to_gaussian(out_bbox_f) + mu_tgt, sig_tgt = _rotated_boxes_to_gaussian(tgt_bbox_f) - l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4], reduction="sum", beta=0.1) - probiou_loss = _probiou_loss(pred_sel, tgt_sel).sum() - total_box += 5.0 * l1_loss + 2.0 * probiou_loss + bhat_dist = _bhattacharyya_distance( + mu_pred.unsqueeze(1), sig_pred.unsqueeze(1), mu_tgt.unsqueeze(0), sig_tgt.unsqueeze(0) + ) + probiou_cost = (1.0 - torch.exp(-bhat_dist)).to(dtype) + + # Rotation Cost + pred_rot = F.normalize(out_bbox_f[:, 4:6], dim=-1) + tgt_rot = F.normalize(tgt_bbox_f[:, 4:6], dim=-1) + rot_cost = (1.0 - torch.abs(pred_rot @ tgt_rot.T)).to(dtype) + + # Total balanced Cost Matrix + cost_matrix = ( + class_weight * class_cost + + bbox_weight * bbox_cost + + probiou_weight * probiou_cost + + rot_weight * rot_cost + ) + cost_matrix = cost_matrix.view(B, Q, -1).cpu() + + # Grouped Hungarian Assignment + indices = [] + group_num_queries = Q // group_detr + cost_matrix_groups = cost_matrix.split(group_num_queries, dim=1) + + for group_id in range(group_detr): + group_cost_matrix = cost_matrix_groups[group_id] + + # Split targets per batch element + group_indices = [] + for i, c in enumerate(group_cost_matrix.split(sizes, -1)): + if sizes[i] == 0: + group_indices.append((np.array([], dtype=np.int64), np.array([], dtype=np.int64))) + else: + row_ind, col_ind = linear_sum_assignment(c[i].numpy()) + group_indices.append((row_ind, col_ind)) + + if group_id == 0: + indices = group_indices + else: + indices = [ + ( + np.concatenate([idx1[0], idx2[0] + group_num_queries * group_id]), + np.concatenate([idx1[1], idx2[1]]), + ) + for idx1, idx2 in zip(indices, group_indices) + ] + + # Image lovel loss normalization: scale by the number of matched boxes, + # and the number of active groups in group DETR + # Scale denominator by the number of active assignment groups + num_boxes = max(sum(sizes) * group_detr, 1) + + batch_idx = torch.cat([torch.full((len(src),), i, dtype=torch.long) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([torch.as_tensor(src, dtype=torch.long) for (src, _) in indices]) + + flat_tgt_idx_list = [] + offset = 0 + for i, (_, tgt) in enumerate(indices): + flat_tgt_idx_list.append(torch.as_tensor(tgt, dtype=torch.long) + offset) + offset += sizes[i] + flat_tgt_idx = torch.cat(flat_tgt_idx_list) + + target_classes_o = tgt_ids[flat_tgt_idx] + src_boxes = pred_boxes[batch_idx, src_idx] + target_boxes_matched = tgt_bbox[flat_tgt_idx] + + # Label Loss with Quality Mapping + prob = logits.sigmoid() + + mu1, sig1 = _rotated_boxes_to_gaussian(src_boxes.detach().to(torch.float32)) + mu2, sig2 = _rotated_boxes_to_gaussian(target_boxes_matched.detach().to(torch.float32)) + bhat_matched = _bhattacharyya_distance(mu1, sig1, mu2, sig2) + pos_ious = torch.exp(-bhat_matched).clamp(min=0.0, max=1.0).to(dtype) + + pos_weights = torch.zeros_like(logits) + neg_weights = prob.pow(gamma) + pos_ind = (batch_idx, src_idx, target_classes_o) + + pos_quality = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha) + pos_quality = torch.clamp(pos_quality, 0.01).detach() + + pos_weights[pos_ind] = pos_quality + neg_weights[pos_ind] = 1 - pos_quality + + # AMP safety for log computation + prob_safe = prob.clamp(min=eps, max=1.0 - eps) + loss_ce = -pos_weights * prob_safe.log() - neg_weights * (1.0 - prob_safe).log() + loss_ce = loss_ce.sum() / num_boxes + + # Bounding Box Loss + loss_bbox = ( + F.smooth_l1_loss(src_boxes[:, :4], target_boxes_matched[:, :4], reduction="sum", beta=0.1) / num_boxes + ) - num_matched_total = max(num_matched_total, 1) - loss_cls = total_cls / B - loss_box = total_box / num_matched_total + # ProbIoU Loss + mu1_l, sig1_l = _rotated_boxes_to_gaussian(src_boxes.to(torch.float32)) + mu2_l, sig2_l = _rotated_boxes_to_gaussian(target_boxes_matched.to(torch.float32)) + bhat_loss = _bhattacharyya_distance(mu1_l, sig1_l, mu2_l, sig2_l) + loss_probiou = (1.0 - torch.exp(-bhat_loss)).to(dtype).sum() / num_boxes - return loss_cls + loss_box + return class_weight * loss_ce + bbox_weight * loss_bbox + probiou_weight * loss_probiou def _lw_detr( From 408797b473fef795ab98c6cd9fc75285e9b2c0f0 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 09:22:25 +0200 Subject: [PATCH 06/15] Update train script --- references/layout/train.py | 91 +++++++++++++++++++++++++++++++------- 1 file changed, 75 insertions(+), 16 deletions(-) diff --git a/references/layout/train.py b/references/layout/train.py index f3f0d2c117..b9a3a29b03 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -17,7 +17,14 @@ # The following import is required for DDP import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR +from torch.optim.lr_scheduler import ( + CosineAnnealingLR, + LinearLR, + MultiplicativeLR, + OneCycleLR, + PolynomialLR, + SequentialLR, +) from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort @@ -82,14 +89,14 @@ def record_lr( scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) # Update the params scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) optimizer.step() # Update LR scheduler.step() @@ -130,14 +137,14 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) # Update the params scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) optimizer.step() scheduler.step() @@ -472,20 +479,29 @@ def main(args): # construct DDP model model = DDP(model, device_ids=[rank]) + backbone_params = [p for n, p in model.named_parameters() if n.startswith("feat_extractor.") and p.requires_grad] + decoder_params = [p for n, p in model.named_parameters() if not n.startswith("feat_extractor.") and p.requires_grad] + # Optimizer if args.optim == "adam": optimizer = torch.optim.Adam( - [p for p in model.parameters() if p.requires_grad], - args.lr, - betas=(0.95, 0.999), + [ + {"params": backbone_params, "lr": 1e-5, "weight_decay": args.weight_decay or 1e-4}, + {"params": decoder_params, "lr": args.lr, "weight_decay": args.weight_decay or 1e-4}, + ], + lr=args.lr, + betas=(0.9, 0.999), eps=1e-6, - weight_decay=args.weight_decay, + weight_decay=args.weight_decay or 1e-4, ) elif args.optim == "adamw": optimizer = torch.optim.AdamW( - [p for p in model.parameters() if p.requires_grad], - args.lr, + [ + {"params": backbone_params, "lr": 1e-5, "weight_decay": args.weight_decay or 1e-4}, + {"params": decoder_params, "lr": args.lr, "weight_decay": args.weight_decay or 1e-4}, + ], + lr=args.lr, betas=(0.9, 0.999), eps=1e-6, weight_decay=args.weight_decay or 1e-4, @@ -498,12 +514,55 @@ def main(args): return # Scheduler + total_steps = args.epochs * len(train_loader) + warmup_steps = min(1000, max(200, total_steps // 20)) + if args.sched == "cosine": - scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) + warmup = LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=warmup_steps, + ) + cosine = CosineAnnealingLR( + optimizer, + T_max=total_steps - warmup_steps, + eta_min=args.lr * 0.01, + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup, cosine], + milestones=[warmup_steps], + ) + elif args.sched == "onecycle": - scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + scheduler = OneCycleLR( + optimizer, + max_lr=[g["lr"] for g in optimizer.param_groups], + total_steps=total_steps, + pct_start=warmup_steps / total_steps, + div_factor=100, + final_div_factor=100, + anneal_strategy="cos", + ) + elif args.sched == "poly": - scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader)) + warmup = LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=warmup_steps, + ) + poly = PolynomialLR( + optimizer, + total_iters=total_steps - warmup_steps, + power=1.0, + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup, poly], + milestones=[warmup_steps], + ) # Training monitoring current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") @@ -690,8 +749,8 @@ def parse_args(): "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" ) parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") - parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)") - parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay") + parser.add_argument("--lr", type=float, default=4e-4, help="learning rate for the optimizer (Adam or AdamW)") + parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, help="weight decay", dest="weight_decay") parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") From d8f5d762fe76743976de46d2cc102c5ac0ac2b9f Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 11:02:03 +0200 Subject: [PATCH 07/15] Update layout post proc --- doctr/models/layout/lw_detr/base.py | 7 +++++-- tests/pytorch/test_models_layout.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 9c6890e6b0..a33a050397 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -146,8 +146,11 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int # Convert logits to probabilities and get scores and labels prob = 1.0 / (1.0 + np.exp(-logits[b])) - scores = prob.max(axis=-1) - labels = prob.argmax(axis=-1) + # Remove background class + prob_fg = prob[:, :-1] + + scores = prob_fg.max(axis=-1) + labels = prob_fg.argmax(axis=-1) # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py index d0c2865411..1094eee76f 100644 --- a/tests/pytorch/test_models_layout.py +++ b/tests/pytorch/test_models_layout.py @@ -76,7 +76,7 @@ def test_layout_models(arch_name, input_shape, train_mode, use_polygons): assert isinstance(results[1], np.ndarray) and results[1].shape == (len(results[0]), 4) assert isinstance(results[2], list) and all(isinstance(scores, float) for scores in results[2]) # Check class idxs are in the model's num_classes - assert all(0 <= idx < model.num_classes for idx in results[0]) + assert all(0 <= idx < len(model.class_names) for idx in results[0]) # Check scores are between 0 and 1 assert all(0 <= score <= 1 for score in results[2]) # Check that the number of boxes, labels and scores are the same From f2e239eb16b6e32a265091d70320d2e0dff01983 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 11:23:12 +0200 Subject: [PATCH 08/15] Update loss --- doctr/models/layout/lw_detr/pytorch.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 57331c8c85..c329ddffb3 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -754,7 +754,12 @@ def _bhattacharyya_distance( bhat_loss = _bhattacharyya_distance(mu1_l, sig1_l, mu2_l, sig2_l) loss_probiou = (1.0 - torch.exp(-bhat_loss)).to(dtype).sum() / num_boxes - return class_weight * loss_ce + bbox_weight * loss_bbox + probiou_weight * loss_probiou + # Rotation Loss + pred_rot = F.normalize(src_boxes[:, 4:6], dim=-1, eps=1e-6) + tgt_rot = F.normalize(target_boxes_matched[:, 4:6], dim=-1, eps=1e-6) + loss_rot = (1.0 - torch.abs((pred_rot * tgt_rot).sum(dim=-1))).sum() / num_boxes + + return class_weight * loss_ce + bbox_weight * loss_bbox + probiou_weight * loss_probiou + rot_weight * loss_rot def _lw_detr( From ad1693c802d7dfc2d64a1f1d6589c3267b88b381 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 12:23:29 +0200 Subject: [PATCH 09/15] amp --- doctr/models/layout/lw_detr/pytorch.py | 31 ++++++++++--------- references/classification/train_character.py | 10 +++--- .../classification/train_orientation.py | 10 +++--- references/detection/evaluate.py | 2 +- references/detection/train.py | 10 +++--- references/layout/evaluate.py | 2 +- references/layout/train.py | 10 +++--- references/recognition/evaluate.py | 2 +- references/recognition/train.py | 10 +++--- 9 files changed, 45 insertions(+), 42 deletions(-) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index c329ddffb3..fdb59fb2c6 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -157,7 +157,7 @@ def __init__( score_thresh: float = 0.5, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 100, + num_queries: int = 195, # This is different from the paper which uses 300 queries, but 195 queries is sufficient for document layout analysis) # noqa: E501 group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, @@ -507,21 +507,25 @@ def _postprocess(logits, boxes): # Build target processed_targets = self.build_target(target, self.class_names) - # Main loss from final decoder layer - loss = self.compute_loss(logits, pred_boxes, processed_targets) + # Disable mixed precision for loss computation to ensure numerical stability, + # especially for the Bhattacharyya distance which involves + # logarithms and determinants of covariance matrices. + with torch.autocast(device_type=logits.device.type, enabled=False): + # Main loss from final decoder layer + loss = self.compute_loss(logits.float(), pred_boxes.float(), processed_targets) - # Auxiliary losses from intermediate decoder layers - for i in range(intermediate.shape[0] - 1): - aux_logits = self.class_embed(intermediate[i]) - aux_boxes = intermediate_reference_points[i + 1] - loss += self.compute_loss(aux_logits, aux_boxes, processed_targets) + # Auxiliary losses from intermediate decoder layers + for i in range(intermediate.shape[0] - 1): + aux_logits = self.class_embed(intermediate[i]).float() + aux_boxes = intermediate_reference_points[i + 1].float() + loss += self.compute_loss(aux_logits, aux_boxes, processed_targets) - # Auxiliary losses for encoder proposals - enc_logits = torch.cat(all_group_enc_logits, dim=1) - enc_coords = torch.cat(all_group_enc_coords, dim=1) - loss += 0.2 * self.compute_loss(enc_logits, enc_coords, processed_targets) + # Auxiliary losses for encoder proposals + enc_logits = torch.cat(all_group_enc_logits, dim=1).float() + enc_coords = torch.cat(all_group_enc_coords, dim=1).float() + loss += 0.2 * self.compute_loss(enc_logits, enc_coords, processed_targets) - out["loss"] = loss + out["loss"] = loss return out @@ -738,7 +742,6 @@ def _bhattacharyya_distance( pos_weights[pos_ind] = pos_quality neg_weights[pos_ind] = 1 - pos_quality - # AMP safety for log computation prob_safe = prob.clamp(min=eps, max=1.0 - eps) loss_ce = -pos_weights * prob_safe.log() - neg_weights * (1.0 - prob_safe).log() loss_ce = loss_ce.sum() / num_boxes diff --git a/references/classification/train_character.py b/references/classification/train_character.py index c266bb2f5c..22a9298ecb 100644 --- a/references/classification/train_character.py +++ b/references/classification/train_character.py @@ -66,7 +66,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): targets = torch.tensor(targets) @@ -79,7 +79,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -110,7 +110,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -126,7 +126,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -168,7 +168,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): targets = targets.cuda() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) loss = cross_entropy(out, targets) else: diff --git a/references/classification/train_orientation.py b/references/classification/train_orientation.py index 86dc4c5931..0cbbf05616 100644 --- a/references/classification/train_orientation.py +++ b/references/classification/train_orientation.py @@ -78,7 +78,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): targets = torch.tensor(targets) @@ -91,7 +91,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -122,7 +122,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -138,7 +138,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -180,7 +180,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): targets = targets.cuda() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) loss = cross_entropy(out, targets) else: diff --git a/references/detection/evaluate.py b/references/detection/evaluate.py index 7c5fb597aa..a8674884f4 100644 --- a/references/detection/evaluate.py +++ b/references/detection/evaluate.py @@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): images = batch_transforms(images) targets = [{CLASS_NAME: t} for t in targets] if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/detection/train.py b/references/detection/train.py index 43e4ec3a56..37b439a6e5 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -63,7 +63,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): if torch.cuda.is_available(): @@ -74,7 +74,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -107,7 +107,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -120,7 +120,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -163,7 +163,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, args, amp=False, l images = images.cuda() images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/layout/evaluate.py b/references/layout/evaluate.py index 36fb78cf80..aa8ea788aa 100644 --- a/references/layout/evaluate.py +++ b/references/layout/evaluate.py @@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): padding_masks = padding_masks.cuda() imgs = batch_transforms(imgs) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(imgs, padding_masks, targets, return_preds=True) else: out = model(imgs, padding_masks, targets, return_preds=True) diff --git a/references/layout/train.py b/references/layout/train.py index b9a3a29b03..bf8a3a600d 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -70,7 +70,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): imgs, padding_masks = images @@ -84,7 +84,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(imgs, padding_masks, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -117,7 +117,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -132,7 +132,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(imgs, padding_masks, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -177,7 +177,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non padding_masks = padding_masks.cuda() imgs = batch_transforms(imgs) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(imgs, padding_masks, targets, return_preds=True) else: out = model(imgs, padding_masks, targets, return_preds=True) diff --git a/references/recognition/evaluate.py b/references/recognition/evaluate.py index 45a6b38306..22f69fa2cd 100644 --- a/references/recognition/evaluate.py +++ b/references/recognition/evaluate.py @@ -38,7 +38,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): images = images.cuda() images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/recognition/train.py b/references/recognition/train.py index dc6b7b1b24..fd0cd1826d 100644 --- a/references/recognition/train.py +++ b/references/recognition/train.py @@ -68,7 +68,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): if torch.cuda.is_available(): @@ -79,7 +79,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -112,7 +112,7 @@ def record_lr( def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -125,7 +125,7 @@ def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, sche optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -167,7 +167,7 @@ def evaluate(model, device, val_loader, batch_transforms, val_metric, amp=False, images = images.to(device) images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) From e2ce2a733fef5f2c9a95618031befc78ee17dceb Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 12 Jun 2026 11:11:02 +0200 Subject: [PATCH 10/15] Update layout model & loss & train script --- doctr/models/layout/lw_detr/base.py | 39 +- doctr/models/layout/lw_detr/layers/pytorch.py | 198 +++--- doctr/models/layout/lw_detr/pytorch.py | 653 +++++++++--------- references/detection/train.py | 8 +- references/layout/train.py | 42 +- 5 files changed, 471 insertions(+), 469 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index a33a050397..cbe9c80229 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -143,31 +143,26 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int results: list[tuple[list[int], np.ndarray, list[float]]] = [] for b in range(boxes.shape[0]): - # Convert logits to probabilities and get scores and labels - prob = 1.0 / (1.0 + np.exp(-logits[b])) + # Sigmoid scores (the model is trained with a sigmoid-based (IA-BCE) loss without + # a background class, as in LW-DETR) + prob = 1.0 / (1.0 + np.exp(-logits[b])) # (num_queries, num_classes) + num_classes = prob.shape[-1] - # Remove background class - prob_fg = prob[:, :-1] + # Keep only the topk (query, class) pairs before NMS + flat_prob = prob.reshape(-1) + topk = min(self.topk, flat_prob.size) if self.topk is not None else flat_prob.size + topk_idxs = np.argsort(flat_prob)[::-1][:topk] - scores = prob_fg.max(axis=-1) - labels = prob_fg.argmax(axis=-1) + scores_b = flat_prob[topk_idxs] + labels_b = topk_idxs % num_classes + query_idxs = topk_idxs // num_classes + bboxes = boxes[b][query_idxs] - # Keep only topk predictions before NMS - if self.topk is not None and len(scores) > self.topk: - idxs = np.argpartition(-scores, self.topk)[: self.topk] - idxs = idxs[np.argsort(-scores[idxs])] - else: - idxs = np.arange(len(scores)) + mask = scores_b > self.score_thresh - scores_b = scores[idxs] - labels_b = labels[idxs] - bboxes = boxes[b][idxs] - - # Filter by score threshold - thresh_mask = scores_b >= self.score_thresh - scores_b = scores_b[thresh_mask] - labels_b = labels_b[thresh_mask] - bboxes = bboxes[thresh_mask] + bboxes = bboxes[mask] + scores_b = scores_b[mask] + labels_b = labels_b[mask] polys, _ = ( self._decode_boxes(bboxes) @@ -309,7 +304,7 @@ def to_quad(box: np.ndarray): labels_all.append(cls_id) targets.append({ - "boxes": np.asarray(boxes_all, dtype=np.float32), + "boxes": np.asarray(boxes_all, dtype=np.float32) if boxes_all else np.zeros((0, 6), dtype=np.float32), "labels": np.asarray(labels_all, dtype=np.int64), }) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index f0bc63c865..ed970f9aa9 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -12,7 +12,55 @@ from doctr.models.modules import ChannelLayerNorm from doctr.models.utils import conv_sequence_pt -__all__ = ["MultiScaleProjector", "C2fBottleneck", "LWDETRHead", "LWDETRDecoder", "LWDETRMultiscaleDeformableAttention"] +__all__ = [ + "MultiScaleProjector", + "C2fBottleneck", + "LWDETRHead", + "LWDETRDecoder", + "LWDETRMultiscaleDeformableAttention", + "refine_obb_boxes", +] + + +def refine_obb_boxes(reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + """Refine oriented bounding boxes by applying predicted deltas to reference points. + Both reference points and deltas are in the format (cx, cy, w, h, sinθ, cosθ). + + cx' = cx + delta_cx * w + cy' = cy + delta_cy * h + w' = w * exp(delta_w) + h' = h * exp(delta_h) + (sinθ', cosθ') = rotation composition of (sinθ, cosθ) with the normalized delta rotation + + Args: + reference_points: (..., 6) tensor containing the reference points + deltas: (..., 6) tensor containing the predicted deltas + + Returns: + refined_boxes: (..., 6) tensor containing the refined bounding boxes + """ + reference_points = reference_points.to(deltas.device) + # center + cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + # size: clamp deltas to prevent exp() from overflowing during early training. + # NOTE: the upper bound must allow the encoder proposals (w = h = 0.05) to reach full-page boxes + # in a single refinement step: exp(3.5) * 0.05 ~= 1.66, so boxes up to (and beyond) the full canvas + # remain reachable. A tighter bound (e.g. 2.0 -> 0.05 * exp(2) ~= 0.37) silently caps the size of the + # encoder proposals and creates an irreducible loss floor for large layout items (tables, pictures). + wh = torch.clamp(deltas[..., 2:4], min=-5.0, max=3.5).exp() * reference_points[..., 2:4] + # rotation (eps avoids division-by-zero NaN creation) + delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) + sin_delta = delta_rot[..., 0:1] + cos_delta = delta_rot[..., 1:2] + sin_ref = reference_points[..., 4:5] + cos_ref = reference_points[..., 5:6] + + # compose rotations + sin_new = sin_ref * cos_delta + cos_ref * sin_delta + cos_new = cos_ref * cos_delta - sin_ref * sin_delta + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) + + return torch.cat((cxcy, wh, rot), dim=-1) class LWDETRHead(nn.Module): @@ -109,7 +157,7 @@ def forward( # Crash prevention: ensure seq_len is perfectly divisible assert seq_len % self.group_detr == 0, ( f"Seq len {seq_len} must be divisible by group_detr {self.group_detr}" - ) # noqa: E501 + ) hidden_states_original = hidden_states if position_embeddings is not None: @@ -249,7 +297,7 @@ def forward( value = self.value_proj(encoder_hidden_states) if attention_mask is not None: - # we invert the attention_mask + # attention_mask contains True on padded positions -> zero-out the padded values value = value.masked_fill(attention_mask[..., None], float(0)) value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(hidden_states).view( @@ -293,8 +341,6 @@ def forward( else: raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}") - # clamp sampling locations to keep them within the valid range [0, 1] for grid sampling - sampling_locations = sampling_locations.clamp(0.0, 1.0) output = self.attn( value, spatial_shapes_list, @@ -411,28 +457,26 @@ def gen_sine_position_embeddings(pos_tensor: torch.Tensor, hidden_size: int = 25 """ scale = 2 * math.pi dim = hidden_size // 2 - # Keep dim_t in float32 for numerical precision; cast output to match caller dtype dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) - x_embed = pos_tensor[:, :, 0].float() * scale - y_embed = pos_tensor[:, :, 1].float() * scale + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale pos_x = x_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) if pos_tensor.size(-1) == 4: - w_embed = pos_tensor[:, :, 2].float() * scale + w_embed = pos_tensor[:, :, 2] * scale pos_w = w_embed[:, :, None] / dim_t pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) - h_embed = pos_tensor[:, :, 3].float() * scale + h_embed = pos_tensor[:, :, 3] * scale pos_h = h_embed[:, :, None] / dim_t pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") - # Cast back to the caller's dtype (supports bfloat16 / float16 AMP) return pos.to(pos_tensor.dtype) @@ -484,104 +528,94 @@ def __init__( self.bbox_embed = bbox_embed self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) - self.angle_proj = nn.Linear(2, self.d_model) + self.angle_proj = nn.Sequential( + nn.Linear(4, self.d_model), + nn.ReLU(), + nn.Linear(self.d_model, self.d_model), + ) def get_reference( - self, reference_points: torch.Tensor, valid_ratios: torch.Tensor + self, + reference_points: torch.Tensor, + num_levels: int, ) -> tuple[torch.Tensor, torch.Tensor]: """This function computes the reference point inputs and positional embeddings for the decoder layers. Args: reference_points: (batch_size, num_queries, 6) tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ) - valid_ratios: (batch_size, num_levels, 2) - tensor containing the valid ratios for each level of the input feature maps + num_levels: number of feature levels used in the decoder (1 for LW-DETR small and medium) Returns: - reference_points_inputs: (batch_size, num_queries, 1, num_levels, 6) - tensor containing the reference point inputs for the decoder layers, - which are the normalized center coordinates, - width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps + reference_points_inputs: (batch_size, num_queries, num_levels, 6) + tensor containing the reference point inputs for the decoder layers query_pos: (batch_size, num_queries, d_model) tensor containing the positional embeddings for the decoder layers, which are computed from the reference points using sine and cosine functions and a linear projection """ - obj_center = reference_points[..., :4] - spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] - # Extract angles - angle = reference_points[..., 4:6] # (sin, cos) - angle_expanded = angle[:, :, None] + ref_xywh = reference_points[..., :4] + angle = reference_points[..., 4:6] + + spatial_inputs = ref_xywh[:, :, None].expand(-1, -1, num_levels, -1) + angle_expanded = angle[:, :, None].expand(-1, -1, num_levels, -1) + reference_points_inputs = torch.cat([spatial_inputs, angle_expanded], dim=-1) - # DETR positional encoding - query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) + # generate sine positional embeddings from the reference points and + # project them to get the final query positional embeddings + query_sine_embed = gen_sine_position_embeddings(ref_xywh, self.d_model) base_query_pos = self.ref_point_head(query_sine_embed) - - angle_emb = self.angle_proj(angle) + # Angle embedding: use the same sine/cosine encoding scheme as for the spatial coordinates, + # but with a separate linear projection. + sin_t = angle[..., 0:1] + cos_t = angle[..., 1:2] + angle_feat = torch.cat( + [ + sin_t, + cos_t, + 2 * sin_t * cos_t, + cos_t**2 - sin_t**2, + ], + dim=-1, + ) + angle_emb = self.angle_proj(angle_feat) # Combine query_pos = base_query_pos + angle_emb return reference_points_inputs, query_pos - def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - """Refine bounding boxes by applying the predicted deltas to the reference points. - The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. - The refined boxes are computed as follows: - - cx' = cx + delta_cx * w - cy' = cy + delta_cy * h - w' = w * exp(delta_w) - h' = h * exp(delta_h) - sinθ' = sinθ * cosΔ + cosθ * sinΔ - cosθ' = cosθ * cosΔ - sinθ * sinΔ - - Args: - reference_points: (N, S, 6) tensor containing the reference points - deltas: (N, S, 6) tensor containing the predicted deltas - - Returns: - refined_boxes: (N, S, 6) tensor containing the refined bounding boxes - """ - reference_points = reference_points.to(deltas.device) - # size - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=4.0).exp() * reference_points[..., 2:4] - wh = wh.clamp(min=1e-4, max=1.0) - # center - raw_cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - half_wh = wh / 2 - cxcy = raw_cxcy.clamp( - min=half_wh, - max=1.0 - half_wh, - ) - # rotation - sin_d = deltas[..., 4:5] - cos_d = deltas[..., 5:6] + 1.0 - delta_rot = F.normalize(torch.cat([sin_d, cos_d], dim=-1), dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - # compose rotations - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - return torch.cat((cxcy, wh, rot), dim=-1) - def forward( self, inputs_embeds: torch.Tensor | None, reference_points: torch.Tensor, - spatial_shapes_list: torch.Tensor, - valid_ratios: torch.Tensor, + spatial_shapes_list: list[tuple[int, int]], encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor | None = None, ): + """Forward pass of the decoder with iterative (per-layer) box refinement. + + Returns: + last_hidden_state: (B, Q, d_model) normalized hidden states of the last decoder layer + intermediate: (num_layers, B, Q, d_model) normalized hidden states of every decoder layer + intermediate_reference_points: (num_layers, B, Q, 6) the reference points used as INPUT to each + decoder layer. Box predictions for layer i must be computed as + `refine_obb_boxes(intermediate_reference_points[i], bbox_embed(intermediate[i]))`. + Following LW-DETR, the refined reference points are detached before being fed to the next + layer, while the undetached versions are kept in the returned stack so that gradients can + flow one extra refinement step when computing the auxiliary losses. + """ intermediate: list[torch.Tensor] = [] + # reference points fed to each decoder layer (input of layer i at index i) intermediate_reference_points: list[torch.Tensor] = [reference_points] if inputs_embeds is not None: hidden_states = inputs_embeds - reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios) + num_levels = len(spatial_shapes_list) + + reference_points_inputs, query_pos = self.get_reference( + reference_points, + num_levels=num_levels, + ) for lid, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( @@ -594,21 +628,23 @@ def forward( ) hidden_states_norm = self.layernorm(hidden_states) + intermediate.append(hidden_states_norm) # iterative refinement - if self.bbox_embed is not None: + if self.bbox_embed is not None and lid < len(self.layers) - 1: delta = self.bbox_embed(hidden_states_norm) + new_reference_points = refine_obb_boxes(reference_points, delta) - reference_points = self.refine_bboxes(reference_points, delta) - intermediate_reference_points.append(reference_points) + # keep the undetached version for the auxiliary losses ("look forward" supervision) + intermediate_reference_points.append(new_reference_points) + # the next layer consumes detached reference points + reference_points = new_reference_points.detach() reference_points_inputs, query_pos = self.get_reference( reference_points, - valid_ratios, + num_levels=num_levels, ) - intermediate.append(hidden_states_norm) - intermediate_stack = torch.stack(intermediate) last_hidden_state = intermediate_stack[-1] diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index fdb59fb2c6..92473b7f92 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -18,7 +18,12 @@ from ...utils import load_pretrained_params from .base import _LWDETR, LWDETRPostProcessor -from .layers import LWDETRDecoder, LWDETRHead, LWDETRMultiscaleDeformableAttention, MultiScaleProjector +from .layers import ( + LWDETRDecoder, + LWDETRHead, + MultiScaleProjector, + refine_obb_boxes, +) __all__ = ["LWDETR", "lw_detr_s", "lw_detr_m"] @@ -69,6 +74,82 @@ } +def _obb_covariance_components(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the components (a, b, c) of the Gaussian covariance matrix [[a, c], [c, b]] associated with + oriented boxes in (cx, cy, w, h, sinθ, cosθ) format, following ProbIoU (https://arxiv.org/abs/2106.06072). + + Args: + boxes: (..., 6) tensor of oriented boxes + + Returns: + a, b, c: (...,) tensors of covariance components + """ + w = boxes[..., 2].clamp(min=1e-6) + h = boxes[..., 3].clamp(min=1e-6) + rot = F.normalize(boxes[..., 4:6], dim=-1, eps=1e-6) + sin, cos = rot[..., 0], rot[..., 1] + + var_w = w.pow(2) / 12.0 + var_h = h.pow(2) / 12.0 + cos2 = cos.pow(2) + sin2 = sin.pow(2) + + a = var_w * cos2 + var_h * sin2 + b = var_w * sin2 + var_h * cos2 + c = (var_w - var_h) * cos * sin + return a, b, c + + +def _probiou( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + pairwise: bool = False, + scale: float = 1.0, + eps: float = 1e-7, +) -> torch.Tensor: + """Compute the probabilistic IoU between oriented boxes in (cx, cy, w, h, sinθ, cosθ) format, + as described in `"Gaussian Bounding Boxes and Probabilistic IoU" `_. + + Args: + boxes1: (N, 6) tensor of oriented boxes + boxes2: (M, 6) tensor of oriented boxes (M = N if `pairwise` is False) + pairwise: if True, return the (N, M) matrix of IoUs, otherwise the element-wise (N,) IoUs + scale: factor applied to the spatial components (cx, cy, w, h) before computing the IoU. + ProbIoU is mathematically scale-invariant, but the stabilizing `eps` terms are not: in + normalized [0, 1] coordinates the covariance terms of small boxes (e.g. checkboxes) fall + below `eps` and the result degenerates (non-overlapping small boxes can score ~1). + eps: small value for numerical stability + + Returns: + probiou: (N,) or (N, M) tensor of IoU-like similarities in [0, 1] + """ + if scale != 1.0: + boxes1 = torch.cat([boxes1[..., :4] * scale, boxes1[..., 4:]], dim=-1) + boxes2 = torch.cat([boxes2[..., :4] * scale, boxes2[..., 4:]], dim=-1) + + x1, y1 = boxes1[..., 0], boxes1[..., 1] + x2, y2 = boxes2[..., 0], boxes2[..., 1] + a1, b1, c1 = _obb_covariance_components(boxes1) + a2, b2, c2 = _obb_covariance_components(boxes2) + + if pairwise: + x1, y1, a1, b1, c1 = (t.unsqueeze(-1) for t in (x1, y1, a1, b1, c1)) # (N, 1) + x2, y2, a2, b2, c2 = (t.unsqueeze(-2) for t in (x2, y2, a2, b2, c2)) # (1, M) + + denom = (a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps + t1 = ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / denom * 0.25 + t2 = ((c1 + c2) * (x2 - x1) * (y1 - y2)) / denom * 0.5 + t3 = ( + ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)).clamp(min=eps) + / (4 * ((a1 * b1 - c1.pow(2)).clamp(min=eps) * (a2 * b2 - c2.pow(2)).clamp(min=eps)).sqrt() + eps) + + eps + ).log() * 0.5 + + bhattacharyya = (t1 + t2 + t3).clamp(min=eps, max=100.0) + hellinger = (1.0 - (-bhattacharyya).exp() + eps).sqrt() + return 1.0 - hellinger + + class LWDETRBackbone(nn.Module): """Backbone of LW-DETR, based on a ViT Det architecture. The backbone is used as feature extractor. @@ -103,6 +184,42 @@ def __init__( num_blocks=num_blocks, ) + def _resize_padding_mask(self, mask: torch.Tensor, size: tuple[int, int]) -> torch.Tensor: + """Resize padding mask to feature-map size + + Args: + mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + size: the target size (H', W') for the resized mask + + Returns: + resized_mask: a binary mask of shape [batch_size x H' x W'], containing 1 on padded pixels + """ + if mask.dtype != torch.bool: + mask = mask.bool() + + valid = (~mask).float().unsqueeze(1) # True/1 = valid pixels + + if (valid.flatten(1).sum(dim=1) == 0).any(): + bad = torch.where(valid.flatten(1).sum(dim=1) == 0)[0].tolist() + raise RuntimeError(f"Input masks are fully padded before resizing: {bad}") + + valid_resized = ( + F.interpolate( + valid, + size=size, + mode="area", + ) + > 0 + ) + + resized_mask = ~valid_resized.squeeze(1) + + if resized_mask.flatten(1).all(dim=1).any(): + bad = torch.where(resized_mask.flatten(1).all(dim=1))[0].tolist() + raise RuntimeError(f"Feature masks became fully padded after resizing: {bad}") + + return resized_mask + def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> list[tuple[torch.Tensor, torch.Tensor]]: """Forward pass of the backbone. @@ -121,10 +238,7 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> list[tup # [(B, C, H, W)] if mask is None: # pragma: no cover mask = torch.zeros((x.shape[0], x.shape[2], x.shape[3]), dtype=torch.bool, device=x.device) - return [ - (feat, F.interpolate(mask.unsqueeze(1).float(), size=feat.shape[-2:], mode="nearest").squeeze(1).bool()) - for feat in feats - ] + return [(feat, self._resize_padding_mask(mask, feat.shape[2:])) for feat in feats] class LWDETR(nn.Module, _LWDETR): @@ -154,17 +268,17 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.5, + score_thresh: float = 0.1, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 195, # This is different from the paper which uses 300 queries, but 195 queries is sufficient for document layout analysis) # noqa: E501 + num_queries: int = 195, group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, ca_num_heads: int = 16, ff_dim: int = 2048, dec_n_points: int = 2, - dropout_prob: float = 0.1, + dropout_prob: float = 0.0, assume_straight_pages: bool = True, exportable: bool = False, cfg: dict[str, Any] | None = None, @@ -172,7 +286,8 @@ def __init__( super().__init__() self.class_names: list[str] = class_names - self.num_classes = len(self.class_names) + 1 # +1 for background class + # No background class: the model is trained with a sigmoid-based (IA-BCE) loss + self.num_classes = len(self.class_names) self.cfg = cfg self.exportable = exportable self.assume_straight_pages = assume_straight_pages @@ -184,13 +299,6 @@ def __init__( self.d_model = d_model self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) - # Initialize angle to (sin=0, cos=1) - with torch.no_grad(): - self.reference_point_embed.weight[:, 0:2].uniform_(0.05, 0.95) # cx, cy - self.reference_point_embed.weight[:, 2].uniform_(0.1, 0.6) # w - self.reference_point_embed.weight[:, 3].uniform_(0.02, 0.3) # h - self.reference_point_embed.weight[:, 4].zero_() # sinθ - self.reference_point_embed.weight[:, 5].fill_(1.0) # cosθ self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) @@ -245,43 +353,26 @@ def __init__( if hasattr(m, "bias") and m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): - if m is not self.reference_point_embed: - nn.init.normal_(m.weight, std=0.02) - elif isinstance(m, LWDETRMultiscaleDeformableAttention): - nn.init.constant_(m.sampling_offsets.weight, 0.0) - - thetas = torch.arange(m.n_heads, dtype=torch.float32) * (2.0 * math.pi / m.n_heads) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = ( - (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) - .view(m.n_heads, 1, 1, 2) - .repeat(1, m.n_levels, m.n_points, 1) - ) - for i in range(m.n_points): - grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): - m.sampling_offsets.bias.copy_(grid_init.view(-1)) - - nn.init.constant_(m.attention_weights.weight, 0.0) - nn.init.constant_(m.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(m.value_proj.weight) - nn.init.zeros_(m.value_proj.bias) - nn.init.xavier_uniform_(m.output_proj.weight) - nn.init.zeros_(m.output_proj.bias) + nn.init.normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.out_features == self.num_classes: + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) if m.bias is not None: - with torch.no_grad(): - # Focal-loss prior: foreground starts with low confidence (~0.01), - # preventing background from dominating gradients at the start of training. - prior_prob = 0.01 - bias_value = -math.log((1 - prior_prob) / prior_prob) - nn.init.constant_(m.bias, 0.0) - m.bias[:-1].fill_(bias_value) - if isinstance(m, LWDETRHead): - last = m.layers[-1] - if isinstance(last, nn.Linear): - nn.init.zeros_(last.weight) - nn.init.zeros_(last.bias) + nn.init.constant_(m.bias, bias_value) + + # Initialize the iterative refinement heads to predict zero deltas (i.e. identity refinement) + # at the start of training, to stabilize training in the early stages when the encoder proposals are still noisy + with torch.no_grad(): + for head in [self.bbox_embed, *self.enc_out_bbox_embed]: + last = head.layers[-1] + last.weight.zero_() + last.bias.zero_() + last.bias[5] = 1.0 # cosθ of the rotation delta -> identity rotation + + # The reference point embedding acts as a learned delta composed with the encoder proposals + self.reference_point_embed.weight.zero_() + self.reference_point_embed.weight[:, 5] = 1.0 # cosθ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -292,24 +383,6 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """ load_pretrained_params(self, path_or_url, **kwargs) - def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: - """Get the valid ratio of all feature maps. - - Args: - mask: (N, H, W) binary tensor containing 1 on padded pixels - dtype: the desired data type of the output tensor - - Returns: - valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch - """ - _, height, width = mask.shape - valid_height = torch.sum(~mask[:, :, 0], 1) - valid_width = torch.sum(~mask[:, 0, :], 1) - valid_ratio_height = valid_height.to(dtype) / height - valid_ratio_width = valid_width.to(dtype) / width - valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) - return valid_ratio - def gen_encoder_output_proposals( self, enc_output: torch.Tensor, padding_mask: torch.Tensor, spatial_shapes: list[tuple[int, int]] ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -331,56 +404,48 @@ def gen_encoder_output_proposals( batch_size = enc_output.shape[0] proposals = [] _cur = 0 - for level, (height, width) in enumerate(spatial_shapes): - mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1) - valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1) - valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + for level, (height, width) in enumerate(spatial_shapes): grid_y, grid_x = torch.meshgrid( - torch.linspace( - 0, - height - 1, - height, - dtype=enc_output.dtype, - device=enc_output.device, - ), - torch.linspace( - 0, - width - 1, - width, - dtype=enc_output.dtype, - device=enc_output.device, - ), + torch.arange(height, dtype=enc_output.dtype, device=enc_output.device), + torch.arange(width, dtype=enc_output.dtype, device=enc_output.device), indexing="ij", ) - grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) - scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) - grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale + grid = torch.stack([grid_x, grid_y], dim=-1) + grid = grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + + scale = torch.tensor( + [width, height], + dtype=enc_output.dtype, + device=enc_output.device, + ).view(1, 1, 1, 2) + + # Canvas-normalized center coordinates + grid = (grid + 0.5) / scale width_height = torch.ones_like(grid) * 0.05 * (2.0**level) - # add default rotation (sin=0, cos=1) sin = torch.zeros_like(grid[..., :1]) cos = torch.ones_like(grid[..., :1]) - proposal = torch.cat((grid, width_height, sin, cos), -1).view(batch_size, -1, 6) + proposal = torch.cat((grid, width_height, sin, cos), dim=-1).view(batch_size, -1, 6) proposals.append(proposal) _cur += height * width - output_proposals = torch.cat(proposals, 1) - spatial_valid = ((output_proposals[..., :2] > 0.01) & (output_proposals[..., :2] < 0.99)).all(-1, keepdim=True) - output_proposals_valid = spatial_valid - invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid - output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) + output_proposals = torch.cat(proposals, dim=1) + + spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all( + dim=-1, keepdim=True + ) + invalid_mask = padding_mask.unsqueeze(-1) | ~spatial_valid - # assign each pixel as an object query - object_query = enc_output - object_query = object_query.masked_fill(invalid_mask, 0.0) + output_proposals = output_proposals.masked_fill(invalid_mask, 0.0) + object_query = enc_output.masked_fill(invalid_mask, 0.0) return object_query, output_proposals, invalid_mask def forward( self, input: torch.Tensor, - masks: torch.Tensor, + masks: torch.Tensor | None = None, target: list[dict[str, np.ndarray]] | None = None, return_model_output: bool = False, return_preds: bool = False, @@ -419,7 +484,6 @@ def forward( mask_flatten_list.append(mask) source_flatten = torch.cat(source_flatten_list, 1) mask_flatten = torch.cat(mask_flatten_list, 1) - valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in feats_masks], 1) tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1) @@ -432,9 +496,8 @@ def forward( topk = self.num_queries topk_coords_logits_list: list[torch.Tensor] = [] - topk_content_list: list[torch.Tensor] = [] - # encoder predictions for auxiliary losses + # encoder predictions on the selected top-k proposals, kept undetached for the auxiliary loss all_group_enc_logits: list[torch.Tensor] = [] all_group_enc_coords: list[torch.Tensor] = [] @@ -443,47 +506,46 @@ def forward( group_object_query = self.enc_output_norm[group_id](group_object_query) group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) - all_group_enc_logits.append(group_enc_outputs_class) group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) - group_enc_outputs_coord = self.decoder.refine_bboxes(output_proposals, group_delta_bbox) + group_enc_outputs_coord = refine_obb_boxes(output_proposals, group_delta_bbox) - all_group_enc_coords.append(group_enc_outputs_coord) - - scores = group_enc_outputs_class_masked[..., :-1].max(-1).values - group_topk_proposals = torch.topk(scores, topk, dim=1)[1] + group_topk_proposals = torch.topk(group_enc_outputs_class_masked.max(-1)[0], topk, dim=1)[1] group_topk_coords_logits_undetach = torch.gather( group_enc_outputs_coord, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) - group_topk_coords_logits = group_topk_coords_logits_undetach.detach() - topk_coords_logits_list.append(group_topk_coords_logits) - group_topk_content = torch.gather( - group_object_query, + # the auxiliary loss supervises only the selected proposals, + # so gather the matching class logits as well + group_topk_logits_undetach = torch.gather( + group_enc_outputs_class, 1, - group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model), + group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.num_classes), ) - topk_content_list.append(group_topk_content) + all_group_enc_logits.append(group_topk_logits_undetach) + all_group_enc_coords.append(group_topk_coords_logits_undetach) - topk_coords_logits = torch.cat(topk_coords_logits_list, 1) + # the decoder consumes detached proposals as initial reference points + topk_coords_logits_list.append(group_topk_coords_logits_undetach.detach()) - reference_points = self.decoder.refine_bboxes(topk_coords_logits, reference_points) + topk_coords_logits = torch.cat(topk_coords_logits_list, 1) + reference_points = refine_obb_boxes(topk_coords_logits, reference_points) last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, reference_points=reference_points, spatial_shapes_list=spatial_shapes_list, - valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, encoder_attention_mask=mask_flatten, ) logits = self.class_embed(last_hidden_states) - pred_boxes = intermediate_reference_points[-1] + pred_boxes_delta = self.bbox_embed(last_hidden_states) + pred_boxes = refine_obb_boxes(intermediate_reference_points[-1], pred_boxes_delta) out: dict[str, Any] = {} @@ -507,25 +569,40 @@ def _postprocess(logits, boxes): # Build target processed_targets = self.build_target(target, self.class_names) - # Disable mixed precision for loss computation to ensure numerical stability, - # especially for the Bhattacharyya distance which involves - # logarithms and determinants of covariance matrices. - with torch.autocast(device_type=logits.device.type, enabled=False): - # Main loss from final decoder layer - loss = self.compute_loss(logits.float(), pred_boxes.float(), processed_targets) + # ProbIoU is computed in pixel coordinates + box_scale = float(max(input.shape[-2], input.shape[-1])) + + # Main loss from final decoder layer (group DETR: each group is matched independently) + split_logits = logits.chunk(group_detr, dim=1) + split_boxes = pred_boxes.chunk(group_detr, dim=1) + + main_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_logits, split_boxes): + main_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) + loss = main_loss / group_detr + + # Auxiliary losses from intermediate decoder layers + # (`intermediate_reference_points[i]` is the reference INPUT to decoder layer i) + for i in range(intermediate.shape[0] - 1): + aux_logits = self.class_embed(intermediate[i]) + aux_boxes_delta = self.bbox_embed(intermediate[i]) + aux_boxes = refine_obb_boxes(intermediate_reference_points[i], aux_boxes_delta) - # Auxiliary losses from intermediate decoder layers - for i in range(intermediate.shape[0] - 1): - aux_logits = self.class_embed(intermediate[i]).float() - aux_boxes = intermediate_reference_points[i + 1].float() - loss += self.compute_loss(aux_logits, aux_boxes, processed_targets) + split_aux_logits = aux_logits.chunk(group_detr, dim=1) + split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) - # Auxiliary losses for encoder proposals - enc_logits = torch.cat(all_group_enc_logits, dim=1).float() - enc_coords = torch.cat(all_group_enc_coords, dim=1).float() - loss += 0.2 * self.compute_loss(enc_logits, enc_coords, processed_targets) + aux_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): + aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) + loss += aux_loss / group_detr - out["loss"] = loss + # Auxiliary losses for the selected encoder proposals + enc_loss: float | torch.Tensor = 0.0 + for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): + enc_loss += self.compute_loss(group_logits, group_coords, processed_targets, box_scale=box_scale) + loss += enc_loss / group_detr + + out["loss"] = loss return out @@ -534,235 +611,121 @@ def compute_loss( logits: torch.Tensor, pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]], + cls_loss_weight: float = 1.0, + l1_loss_weight: float = 5.0, + iou_loss_weight: float = 2.0, + box_scale: float = 1024.0, ) -> torch.Tensor: - """Compute the loss using Grouped Hungarian Matching - and consistent ProbIoU semantics for rotated bounding boxes. + """Compute the LW-DETR loss for oriented bounding boxes. + + Predictions are matched one-to-one to the ground truth boxes with Hungarian matching, + using a cost combining a focal-style classification cost, an L1 box cost + and a (negated) ProbIoU cost. The loss then consists of: + + - an IoU-aware binary cross-entropy (IA-BCE) classification loss, as described in the LW-DETR paper, + where the target of a matched (query, class) pair is `p**alpha * IoU**(1 - alpha)` and the rotated + ProbIoU is used as IoU measure + - an L1 regression loss on the normalized (cx, cy, w, h) of the matched pairs + - a ProbIoU loss (1 - ProbIoU) on the matched oriented boxes, computed in absolute pixel + coordinates (as in O2-RT-DETR). The box rotation is supervised solely through this term: + ProbIoU is differentiable w.r.t. the angle even for non-overlapping boxes. + + All terms are normalized by the number of ground truth boxes in the batch. Args: logits: (B, Q, C) tensor containing the predicted class logits for each query - pred_boxes: (B, Q, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format - targets: list of length B, where each element is a dict with keys "labels" and "boxes", - containing the ground truth labels and boxes for each image in the batch. + pred_boxes: (B, Q, 6) tensor containing the predicted boxes (cx, cy, w, h, sinθ, cosθ) for each query + targets: list of B dictionaries with keys "boxes" ((N, 6) array in OBB format) + and "labels" ((N,) array of class indices) + cls_loss_weight: weight of the classification loss + l1_loss_weight: weight of the L1 box regression loss + iou_loss_weight: weight of the ProbIoU loss + box_scale: image size used to rescale normalized boxes to pixel coordinates for ProbIoU computation Returns: - A scalar tensor containing the computed loss. + loss: the computed loss value """ + alpha, gamma, eps = 0.25, 2.0, 1e-8 device = logits.device - dtype = logits.dtype - B, Q, C = logits.shape - - # Consistent coefficients across matcher and loss components - class_weight = 2.0 - bbox_weight = 5.0 - probiou_weight = 2.0 - rot_weight = 0.5 - - # Focal Loss Params - alpha = 0.25 - gamma = 2.0 - eps = 1e-7 - - group_detr = getattr(self, "group_detr", 1) - - def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Convert rotated boxes to Gaussian distributions using the true - variance of a uniform continuous rectangle (w^2 / 12).""" - cxcy = boxes[..., :2] - w = boxes[..., 2].clamp(min=1e-6) - h = boxes[..., 3].clamp(min=1e-6) - sin = boxes[..., 4] - cos = boxes[..., 5] - - R = torch.stack( - [ - torch.stack([cos, -sin], dim=-1), - torch.stack([sin, cos], dim=-1), - ], - dim=-2, - ) + batch_size = logits.shape[0] - sx = (w**2) / 12.0 - sy = (h**2) / 12.0 - - S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device, dtype=boxes.dtype) - S[..., 0, 0] = sx - S[..., 1, 1] = sy - - covariance = R @ S @ R.transpose(-1, -2) - return cxcy, covariance - - def _bhattacharyya_distance( - mu1: torch.Tensor, sigma1: torch.Tensor, mu2: torch.Tensor, sigma2: torch.Tensor - ) -> torch.Tensor: - """Compute Bhattacharyya distance with broadcast support.""" - delta = (mu1 - mu2).unsqueeze(-1) - sigma = (sigma1 + sigma2) * 0.5 - - eye = torch.eye(2, device=sigma.device, dtype=sigma.dtype) * 1e-6 - sigma_safe = sigma + eye - sigma1_safe = sigma1 + eye - sigma2_safe = sigma2 + eye - - L = torch.linalg.cholesky(sigma_safe) - sigma_inv = torch.cholesky_inverse(L) - - mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) - - det_sigma = torch.linalg.det(sigma_safe).clamp(min=1e-6) - det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=1e-6) - det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=1e-6) - - bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) - return bhattacharyya.clamp(min=0.0) - - # Prepare targets for matching - target_labels = [] - target_boxes = [] - sizes = [] - for t in targets: - lbls = torch.as_tensor(t["labels"], device=device, dtype=torch.long) - bxs = torch.as_tensor(t["boxes"], device=device, dtype=pred_boxes.dtype) - if bxs.ndim == 1 and bxs.numel() > 0: - bxs = bxs.unsqueeze(0) - target_labels.append(lbls) - target_boxes.append(bxs) - sizes.append(len(lbls)) - - # Unified formulation for empty batches - if sum(sizes) == 0: - prob = logits.sigmoid() - prob_safe = prob.clamp(min=eps, max=1.0 - eps) - neg_weights = prob.pow(gamma) - loss_ce = -neg_weights * (1.0 - prob_safe).log() - return class_weight * (loss_ce.sum() / (B * Q)) + tgt_boxes_list: list[torch.Tensor] = [] + tgt_labels_list: list[torch.Tensor] = [] + for sample in targets: + tgt_boxes_list.append( + torch.as_tensor(sample["boxes"], device=device, dtype=pred_boxes.dtype).reshape(-1, 6) + ) + tgt_labels_list.append(torch.as_tensor(sample["labels"], device=device, dtype=torch.long).reshape(-1)) - tgt_ids = torch.cat(target_labels) - tgt_bbox = torch.cat(target_boxes) + # Number of target boxes in the batch, for loss normalization + num_boxes = max(sum(int(labels.numel()) for labels in tgt_labels_list), 1) - # Matcher: Grouped Hungarian Assignment with a balanced cost matrix + # Hungarian matching (one-to-one), performed independently for each sample + indices: list[tuple[torch.Tensor, torch.Tensor]] = [] with torch.no_grad(): - out_prob = logits.flatten(0, 1).sigmoid() - out_bbox = pred_boxes.flatten(0, 1) - - # Classification Cost (Focal Loss based) - neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + eps).log()) - pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + eps).log()) - class_cost = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] - - # Box L1 Cost - out_bbox_f = out_bbox.to(torch.float32) - tgt_bbox_f = tgt_bbox.to(torch.float32) - bbox_cost = torch.cdist(out_bbox_f[:, :4], tgt_bbox_f[:, :4], p=1).to(dtype) - - # ProbIoU Cost - mu_pred, sig_pred = _rotated_boxes_to_gaussian(out_bbox_f) - mu_tgt, sig_tgt = _rotated_boxes_to_gaussian(tgt_bbox_f) + prob = logits.sigmoid() + for b in range(batch_size): + tgt_boxes, tgt_labels = tgt_boxes_list[b], tgt_labels_list[b] + if tgt_labels.numel() == 0: + empty = torch.empty(0, dtype=torch.long, device=device) + indices.append((empty, empty)) + continue + + out_prob = prob[b] + out_boxes = pred_boxes[b] + + # Focal-style classification cost + neg_cost = (1 - alpha) * out_prob.pow(gamma) * (-(1 - out_prob + eps).log()) + pos_cost = alpha * (1 - out_prob).pow(gamma) * (-(out_prob + eps).log()) + cost_class = pos_cost[:, tgt_labels] - neg_cost[:, tgt_labels] + + # L1 cost on normalized (cx, cy, w, h) + cost_bbox = torch.cdist(out_boxes[:, :4], tgt_boxes[:, :4], p=1) + + # Rotated IoU cost, computed in pixel coordinates + # this term also carries the angle signal for the matching + cost_iou = -_probiou(out_boxes, tgt_boxes, pairwise=True, scale=box_scale) + + cost = 2.0 * cost_class + 5.0 * cost_bbox + 2.0 * cost_iou + + query_idx, tgt_idx = linear_sum_assignment(cost.cpu().numpy()) + indices.append(( + torch.as_tensor(query_idx, dtype=torch.long, device=device), + torch.as_tensor(tgt_idx, dtype=torch.long, device=device), + )) + + # Flatten the matched pairs across the batch + batch_idx = torch.cat([torch.full_like(src, b) for b, (src, _) in enumerate(indices)]) + query_idx = torch.cat([src for (src, _) in indices]) + matched_tgt_boxes = torch.cat([tgt_boxes_list[b][tgt] for b, (_, tgt) in enumerate(indices)]) + matched_tgt_labels = torch.cat([tgt_labels_list[b][tgt] for b, (_, tgt) in enumerate(indices)]) + matched_pred_boxes = pred_boxes[batch_idx, query_idx] - bhat_dist = _bhattacharyya_distance( - mu_pred.unsqueeze(1), sig_pred.unsqueeze(1), mu_tgt.unsqueeze(0), sig_tgt.unsqueeze(0) - ) - probiou_cost = (1.0 - torch.exp(-bhat_dist)).to(dtype) - - # Rotation Cost - pred_rot = F.normalize(out_bbox_f[:, 4:6], dim=-1) - tgt_rot = F.normalize(tgt_bbox_f[:, 4:6], dim=-1) - rot_cost = (1.0 - torch.abs(pred_rot @ tgt_rot.T)).to(dtype) - - # Total balanced Cost Matrix - cost_matrix = ( - class_weight * class_cost - + bbox_weight * bbox_cost - + probiou_weight * probiou_cost - + rot_weight * rot_cost - ) - cost_matrix = cost_matrix.view(B, Q, -1).cpu() - - # Grouped Hungarian Assignment - indices = [] - group_num_queries = Q // group_detr - cost_matrix_groups = cost_matrix.split(group_num_queries, dim=1) - - for group_id in range(group_detr): - group_cost_matrix = cost_matrix_groups[group_id] - - # Split targets per batch element - group_indices = [] - for i, c in enumerate(group_cost_matrix.split(sizes, -1)): - if sizes[i] == 0: - group_indices.append((np.array([], dtype=np.int64), np.array([], dtype=np.int64))) - else: - row_ind, col_ind = linear_sum_assignment(c[i].numpy()) - group_indices.append((row_ind, col_ind)) - - if group_id == 0: - indices = group_indices - else: - indices = [ - ( - np.concatenate([idx1[0], idx2[0] + group_num_queries * group_id]), - np.concatenate([idx1[1], idx2[1]]), - ) - for idx1, idx2 in zip(indices, group_indices) - ] - - # Image lovel loss normalization: scale by the number of matched boxes, - # and the number of active groups in group DETR - # Scale denominator by the number of active assignment groups - num_boxes = max(sum(sizes) * group_detr, 1) - - batch_idx = torch.cat([torch.full((len(src),), i, dtype=torch.long) for i, (src, _) in enumerate(indices)]) - src_idx = torch.cat([torch.as_tensor(src, dtype=torch.long) for (src, _) in indices]) - - flat_tgt_idx_list = [] - offset = 0 - for i, (_, tgt) in enumerate(indices): - flat_tgt_idx_list.append(torch.as_tensor(tgt, dtype=torch.long) + offset) - offset += sizes[i] - flat_tgt_idx = torch.cat(flat_tgt_idx_list) - - target_classes_o = tgt_ids[flat_tgt_idx] - src_boxes = pred_boxes[batch_idx, src_idx] - target_boxes_matched = tgt_bbox[flat_tgt_idx] - - # Label Loss with Quality Mapping prob = logits.sigmoid() - mu1, sig1 = _rotated_boxes_to_gaussian(src_boxes.detach().to(torch.float32)) - mu2, sig2 = _rotated_boxes_to_gaussian(target_boxes_matched.detach().to(torch.float32)) - bhat_matched = _bhattacharyya_distance(mu1, sig1, mu2, sig2) - pos_ious = torch.exp(-bhat_matched).clamp(min=0.0, max=1.0).to(dtype) - + # IoU-aware BCE classification loss (IA-BCE) pos_weights = torch.zeros_like(logits) neg_weights = prob.pow(gamma) - pos_ind = (batch_idx, src_idx, target_classes_o) + if len(batch_idx) > 0: + with torch.no_grad(): + ious = _probiou(matched_pred_boxes, matched_tgt_boxes, scale=box_scale).clamp(min=0.0, max=1.0) + t = prob[batch_idx, query_idx, matched_tgt_labels].pow(alpha) * ious.pow(1 - alpha) + t = t.clamp(min=0.01) + pos_weights[batch_idx, query_idx, matched_tgt_labels] = t + neg_weights[batch_idx, query_idx, matched_tgt_labels] = 1 - t - pos_quality = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha) - pos_quality = torch.clamp(pos_quality, 0.01).detach() - - pos_weights[pos_ind] = pos_quality - neg_weights[pos_ind] = 1 - pos_quality - - prob_safe = prob.clamp(min=eps, max=1.0 - eps) - loss_ce = -pos_weights * prob_safe.log() - neg_weights * (1.0 - prob_safe).log() - loss_ce = loss_ce.sum() / num_boxes - - # Bounding Box Loss - loss_bbox = ( - F.smooth_l1_loss(src_boxes[:, :4], target_boxes_matched[:, :4], reduction="sum", beta=0.1) / num_boxes - ) + cls_loss = -(pos_weights * (prob + eps).log() + neg_weights * (1 - prob + eps).log()) + loss = cls_loss_weight * cls_loss.sum() / num_boxes - # ProbIoU Loss - mu1_l, sig1_l = _rotated_boxes_to_gaussian(src_boxes.to(torch.float32)) - mu2_l, sig2_l = _rotated_boxes_to_gaussian(target_boxes_matched.to(torch.float32)) - bhat_loss = _bhattacharyya_distance(mu1_l, sig1_l, mu2_l, sig2_l) - loss_probiou = (1.0 - torch.exp(-bhat_loss)).to(dtype).sum() / num_boxes + if len(batch_idx) == 0: + return loss - # Rotation Loss - pred_rot = F.normalize(src_boxes[:, 4:6], dim=-1, eps=1e-6) - tgt_rot = F.normalize(target_boxes_matched[:, 4:6], dim=-1, eps=1e-6) - loss_rot = (1.0 - torch.abs((pred_rot * tgt_rot).sum(dim=-1))).sum() / num_boxes + # L1 loss on normalized (cx, cy, w, h) + l1_loss = F.l1_loss(matched_pred_boxes[:, :4], matched_tgt_boxes[:, :4], reduction="sum") / num_boxes + # ProbIoU loss on the whole oriented box (position, size and rotation), in pixel coordinates + probiou_loss = (1 - _probiou(matched_pred_boxes, matched_tgt_boxes, scale=box_scale)).sum() / num_boxes - return class_weight * loss_ce + bbox_weight * loss_bbox + probiou_weight * loss_probiou + rot_weight * loss_rot + return loss + l1_loss_weight * l1_loss + iou_loss_weight * probiou_loss def _lw_detr( @@ -784,7 +747,7 @@ def _lw_detr( False, include_top=False, input_shape=default_cfgs[arch]["input_shape"], - patch_size=kwargs.get("patch_size", (16, 16)), + patch_size=kwargs.pop("patch_size", (16, 16)), ) feat_extractor = LWDETRBackbone(encoder_fn=backbone) diff --git a/references/detection/train.py b/references/detection/train.py index 37b439a6e5..14c8871794 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -333,8 +333,8 @@ def main(args): [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.15), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.15), ]), T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), ] @@ -342,8 +342,8 @@ def main(args): else [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.15), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.15), ]), # Rotation augmentation T.Resize(args.input_size, preserve_aspect_ratio=True), diff --git a/references/layout/train.py b/references/layout/train.py index bf8a3a600d..649289f6fe 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -89,14 +89,14 @@ def record_lr( scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) # Update the params scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) optimizer.step() # Update LR scheduler.step() @@ -137,14 +137,14 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) # Update the params scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) optimizer.step() scheduler.step() @@ -390,8 +390,8 @@ def main(args): [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.15), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.15), ]), T.Resize( (args.input_size, args.input_size), @@ -404,8 +404,8 @@ def main(args): else [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.15), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.15), ]), # Rotation augmentation T.Resize(args.input_size, preserve_aspect_ratio=True, return_padding_mask=True), @@ -479,15 +479,23 @@ def main(args): # construct DDP model model = DDP(model, device_ids=[rank]) - backbone_params = [p for n, p in model.named_parameters() if n.startswith("feat_extractor.") and p.requires_grad] - decoder_params = [p for n, p in model.named_parameters() if not n.startswith("feat_extractor.") and p.requires_grad] + def is_backbone_param(name: str) -> bool: + name = name.removeprefix("module.") + return name.startswith("feat_extractor.") + + backbone_params = [p for n, p in model.named_parameters() if is_backbone_param(n) and p.requires_grad] + + decoder_params = [p for n, p in model.named_parameters() if not is_backbone_param(n) and p.requires_grad] + + backbone_lr = args.lr if not args.pretrained else args.lr * 0.1 + decoder_lr = args.lr # Optimizer if args.optim == "adam": optimizer = torch.optim.Adam( [ - {"params": backbone_params, "lr": 1e-5, "weight_decay": args.weight_decay or 1e-4}, - {"params": decoder_params, "lr": args.lr, "weight_decay": args.weight_decay or 1e-4}, + {"params": backbone_params, "lr": backbone_lr, "weight_decay": args.weight_decay or 1e-4}, + {"params": decoder_params, "lr": decoder_lr, "weight_decay": args.weight_decay or 1e-4}, ], lr=args.lr, betas=(0.9, 0.999), @@ -498,8 +506,8 @@ def main(args): elif args.optim == "adamw": optimizer = torch.optim.AdamW( [ - {"params": backbone_params, "lr": 1e-5, "weight_decay": args.weight_decay or 1e-4}, - {"params": decoder_params, "lr": args.lr, "weight_decay": args.weight_decay or 1e-4}, + {"params": backbone_params, "lr": backbone_lr, "weight_decay": args.weight_decay or 1e-4}, + {"params": decoder_params, "lr": decoder_lr, "weight_decay": args.weight_decay or 1e-4}, ], lr=args.lr, betas=(0.9, 0.999), @@ -515,7 +523,7 @@ def main(args): # Scheduler total_steps = args.epochs * len(train_loader) - warmup_steps = min(1000, max(200, total_steps // 20)) + warmup_steps = max(1, min(1000, int(0.05 * total_steps))) if args.sched == "cosine": warmup = LinearLR( @@ -526,7 +534,7 @@ def main(args): ) cosine = CosineAnnealingLR( optimizer, - T_max=total_steps - warmup_steps, + T_max=max(1, total_steps - warmup_steps), eta_min=args.lr * 0.01, ) scheduler = SequentialLR( @@ -749,7 +757,7 @@ def parse_args(): "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" ) parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") - parser.add_argument("--lr", type=float, default=4e-4, help="learning rate for the optimizer (Adam or AdamW)") + parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for the optimizer (Adam or AdamW)") parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, help="weight decay", dest="weight_decay") parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") From 1b1b9a4ff9e69cf37f1d4c28004dbeccd9ea2f2e Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 12 Jun 2026 11:14:30 +0200 Subject: [PATCH 11/15] Update layout model & loss & train script --- references/detection/train.py | 8 ++++---- references/layout/train.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index 14c8871794..cba2c15d0a 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -333,8 +333,8 @@ def main(args): [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.15), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.15), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), ]), T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), ] @@ -342,8 +342,8 @@ def main(args): else [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.15), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.15), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), ]), # Rotation augmentation T.Resize(args.input_size, preserve_aspect_ratio=True), diff --git a/references/layout/train.py b/references/layout/train.py index 649289f6fe..7598b3df42 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -390,8 +390,8 @@ def main(args): [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.15), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.15), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), ]), T.Resize( (args.input_size, args.input_size), @@ -404,8 +404,8 @@ def main(args): else [ T.RandomHorizontalFlip(0.15), T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.15), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.15), + T.RandomApply(T.RandomCrop(ratio=(0.85, 1.15), scale=(0.75, 1.0)), 0.25), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), ]), # Rotation augmentation T.Resize(args.input_size, preserve_aspect_ratio=True, return_padding_mask=True), From 0a870ac0d3b826ad1096e8a2f3e982f087ab36db Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 12 Jun 2026 14:08:15 +0200 Subject: [PATCH 12/15] fixes --- doctr/models/layout/lw_detr/pytorch.py | 21 +++---------- references/detection/train.py | 4 ++- references/layout/train.py | 43 ++++++++------------------ references/layout/utils.py | 38 +++++++++++++++++++++++ references/recognition/train.py | 2 ++ tests/pytorch/test_models_layout.py | 4 +-- 6 files changed, 63 insertions(+), 49 deletions(-) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 92473b7f92..a065b8927d 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -198,23 +198,12 @@ def _resize_padding_mask(self, mask: torch.Tensor, size: tuple[int, int]) -> tor mask = mask.bool() valid = (~mask).float().unsqueeze(1) # True/1 = valid pixels - - if (valid.flatten(1).sum(dim=1) == 0).any(): - bad = torch.where(valid.flatten(1).sum(dim=1) == 0)[0].tolist() - raise RuntimeError(f"Input masks are fully padded before resizing: {bad}") - - valid_resized = ( - F.interpolate( - valid, - size=size, - mode="area", - ) - > 0 - ) - + valid_resized = F.interpolate(valid, size=size, mode="area") > 0 resized_mask = ~valid_resized.squeeze(1) - if resized_mask.flatten(1).all(dim=1).any(): + # Sanity check: no feature should be fully padded after resizing, + # otherwise it would cause NaNs in the attention weights + if self.training and resized_mask.flatten(1).all(dim=1).any(): # pragma: no cover bad = torch.where(resized_mask.flatten(1).all(dim=1))[0].tolist() raise RuntimeError(f"Feature masks became fully padded after resizing: {bad}") @@ -268,7 +257,7 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.1, + score_thresh: float = 0.25, iou_thresh: float = 0.5, d_model: int = 256, num_queries: int = 195, diff --git a/references/detection/train.py b/references/detection/train.py index cba2c15d0a..2070b63c3e 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -363,7 +363,7 @@ def main(args): ) if distributed: - sampler = DistributedSampler(train_set, rank=rank, shuffle=False, drop_last=True) + sampler = DistributedSampler(train_set, rank=rank, shuffle=True, drop_last=True) else: sampler = RandomSampler(train_set) @@ -518,6 +518,8 @@ def log_at_step(train_loss=None, val_loss=None, lr=None): # Training loop for epoch in range(args.epochs): + if distributed: + sampler.set_epoch(epoch) train_loss, actual_lr = fit_one_epoch( model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, log=log_at_step, rank=rank ) diff --git a/references/layout/train.py b/references/layout/train.py index 7598b3df42..68339f761a 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -38,7 +38,7 @@ from doctr.datasets import LayoutDataset from doctr.models import layout, login_to_hub, push_to_hf_hub from doctr.utils.metrics import ObjectDetectionMetric -from utils import EarlyStopper, convert_target, plot_recorder, plot_samples +from utils import EarlyStopper, build_param_groups, convert_target, plot_recorder, plot_samples def record_lr( @@ -431,7 +431,7 @@ def main(args): ) if distributed: - sampler = DistributedSampler(train_set, rank=rank, shuffle=False, drop_last=True) + sampler = DistributedSampler(train_set, rank=rank, shuffle=True, drop_last=True) else: sampler = RandomSampler(train_set) @@ -483,37 +483,18 @@ def is_backbone_param(name: str) -> bool: name = name.removeprefix("module.") return name.startswith("feat_extractor.") - backbone_params = [p for n, p in model.named_parameters() if is_backbone_param(n) and p.requires_grad] - - decoder_params = [p for n, p in model.named_parameters() if not is_backbone_param(n) and p.requires_grad] - - backbone_lr = args.lr if not args.pretrained else args.lr * 0.1 - decoder_lr = args.lr + param_groups = build_param_groups( + model, + lr=args.lr, + backbone_lr=args.lr if not args.pretrained else args.lr * 0.1, + weight_decay=args.weight_decay or 1e-4, + ) # Optimizer if args.optim == "adam": - optimizer = torch.optim.Adam( - [ - {"params": backbone_params, "lr": backbone_lr, "weight_decay": args.weight_decay or 1e-4}, - {"params": decoder_params, "lr": decoder_lr, "weight_decay": args.weight_decay or 1e-4}, - ], - lr=args.lr, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=args.weight_decay or 1e-4, - ) - + optimizer = torch.optim.Adam(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-8) elif args.optim == "adamw": - optimizer = torch.optim.AdamW( - [ - {"params": backbone_params, "lr": backbone_lr, "weight_decay": args.weight_decay or 1e-4}, - {"params": decoder_params, "lr": decoder_lr, "weight_decay": args.weight_decay or 1e-4}, - ], - lr=args.lr, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=args.weight_decay or 1e-4, - ) + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-8) # LR Finder if rank == 0 and args.find_lr: @@ -523,7 +504,7 @@ def is_backbone_param(name: str) -> bool: # Scheduler total_steps = args.epochs * len(train_loader) - warmup_steps = max(1, min(1000, int(0.05 * total_steps))) + warmup_steps = max(1, min(2000, int(0.05 * total_steps))) if args.sched == "cosine": warmup = LinearLR( @@ -658,6 +639,8 @@ def log_at_step(train_loss=None, val_loss=None, lr=None): # Training loop for epoch in range(args.epochs): + if distributed: + sampler.set_epoch(epoch) train_loss, actual_lr = fit_one_epoch( model, train_loader, diff --git a/references/layout/utils.py b/references/layout/utils.py index 6629487ab7..f7e56f049f 100644 --- a/references/layout/utils.py +++ b/references/layout/utils.py @@ -86,6 +86,44 @@ def plot_samples( plt.show() +def build_param_groups(model: Any, lr: float, backbone_lr: float, weight_decay: float): + """Build parameter groups for the optimizer, separating backbone and non-backbone parameters, + and applying weight decay only to non-bias and non-norm parameters. + + Args: + model: the model containing the parameters + lr: learning rate for non-backbone parameters + backbone_lr: learning rate for backbone parameters + weight_decay: weight decay to apply to non-bias and non-norm parameters + + Returns: + a list of parameter groups to be passed to the optimizer + """ + no_decay_keys = ("bias", "norm", ".bn", "embed") # Embedding, LayerNorm, BN + + def is_backbone(name: str) -> bool: + return name.removeprefix("module.").startswith("feat_extractor.") + + groups: dict[tuple[bool, bool], list[Any]] = { + (False, True): [], + (False, False): [], + (True, True): [], + (True, False): [], + } + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + decay = not (p.ndim <= 1 or any(k in n.lower() for k in no_decay_keys)) + groups[(is_backbone(n), decay)].append(p) + + return [ + {"params": groups[(False, True)], "lr": lr, "weight_decay": weight_decay}, + {"params": groups[(False, False)], "lr": lr, "weight_decay": 0.0}, + {"params": groups[(True, True)], "lr": backbone_lr, "weight_decay": weight_decay}, + {"params": groups[(True, False)], "lr": backbone_lr, "weight_decay": 0.0}, + ] + + def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: """Display the results of the LR grid search. Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py diff --git a/references/recognition/train.py b/references/recognition/train.py index fd0cd1826d..913f2ed017 100644 --- a/references/recognition/train.py +++ b/references/recognition/train.py @@ -559,6 +559,8 @@ def log_at_step(train_loss=None, val_loss=None, lr=None): early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta) # Training loop for epoch in range(args.epochs): + if distributed: + sampler.set_epoch(epoch) train_loss, actual_lr = fit_one_epoch( model, device, diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py index 1094eee76f..f42cb95385 100644 --- a/tests/pytorch/test_models_layout.py +++ b/tests/pytorch/test_models_layout.py @@ -27,7 +27,7 @@ def test_layout_models(arch_name, input_shape, train_mode, use_polygons): model = model.train() if train_mode else model.eval() assert isinstance(model, torch.nn.Module) input_tensor = torch.rand((batch_size, *input_shape)) - input_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) + input_masks = torch.zeros((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) class_names = model.class_names @@ -130,7 +130,7 @@ def test_models_onnx_export(arch_name, input_shape): batch_size = 2 model = layout.__dict__[arch_name](pretrained=True, exportable=True).eval() dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) - dummy_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) + dummy_masks = torch.zeros((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) pt = model(dummy_input, dummy_masks) pt_logits = pt["logits"].detach().cpu().numpy() pt_boxes = pt["pred_boxes"].detach().cpu().numpy() From cf7caddc87104e504d3dcaaaec2d930298d8f8c1 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 15 Jun 2026 07:01:33 +0200 Subject: [PATCH 13/15] Fix lwdetr model export and small hub factory fix --- doctr/models/factory/hub.py | 82 +++++++++++++------------- doctr/models/layout/lw_detr/pytorch.py | 26 ++++++-- tests/pytorch/test_models_factory.py | 2 +- tests/pytorch/test_models_layout.py | 2 +- 4 files changed, 63 insertions(+), 49 deletions(-) diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index 45d57a326d..57839f5646 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -9,7 +9,6 @@ import logging import subprocess import tempfile -import textwrap from pathlib import Path from typing import Any @@ -101,61 +100,62 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # raise ValueError("task must be one of classification, detection, recognition, layout") # default readme - readme = textwrap.dedent( - f"""--- - language: en - tags: - - ocr - - pytorch - - doctr - - {task} - --- + readme = f"""--- +language: en +tags: +- ocr +- pytorch +- doctr +- {task} +--- -

- -

+

+ +

- **Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch** +**Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch** - ## Task: {task} +## Task: {task} - https://github.com/mindee/doctr +https://github.com/mindee/doctr - ### Example usage: +### Example usage: - ```python - >>> from doctr.io import DocumentFile - >>> from doctr.models import ocr_predictor, from_hub +```python +>>> from doctr.io import DocumentFile +>>> from doctr.models import ocr_predictor, from_hub - >>> img = DocumentFile.from_images(['']) - >>> # Load your model from the hub - >>> model = from_hub('mindee/my-model') +>>> img = DocumentFile.from_images(['']) +>>> # Load your model from the hub +>>> model = from_hub('mindee/my-model') - >>> # Pass it to the predictor - >>> # If your model is a recognition model: - >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large', - >>> reco_arch=model, - >>> pretrained=True) +>>> # Pass it to the predictor +>>> # If your model is a recognition model: +>>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large', +>>> reco_arch=model, +>>> pretrained=True) - >>> # If your model is a detection model: - >>> predictor = ocr_predictor(det_arch=model, - >>> reco_arch='crnn_mobilenet_v3_small', - >>> pretrained=True) +>>> # If your model is a detection model: +>>> predictor = ocr_predictor(det_arch=model, +>>> reco_arch='crnn_mobilenet_v3_small', +>>> pretrained=True) - >>> # Get your predictions - >>> res = predictor(img) - ``` - """ - ) +>>> # Get your predictions +>>> res = predictor(img) +``` +""" # add run configuration to readme if available if run_config is not None: arch = run_config.arch - readme += textwrap.dedent( - f"""### Run Configuration - \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}""" - ) + readme += f""" +### Run Configuration + +```json +{json.dumps(vars(run_config), indent=2, ensure_ascii=False)} +``` +""" if arch not in AVAILABLE_ARCHS[task]: raise ValueError( diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index a065b8927d..96f2edd205 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -198,14 +198,28 @@ def _resize_padding_mask(self, mask: torch.Tensor, size: tuple[int, int]) -> tor mask = mask.bool() valid = (~mask).float().unsqueeze(1) # True/1 = valid pixels - valid_resized = F.interpolate(valid, size=size, mode="area") > 0 + + # Data-dependent sanity checks + if self.training: # pragma: no cover + if (valid.flatten(1).sum(dim=1) == 0).any(): + bad = torch.where(valid.flatten(1).sum(dim=1) == 0)[0].tolist() + raise RuntimeError(f"Input masks are fully padded before resizing: {bad}") + + # Use max pooling to resize the valid mask: + # a pixel in the resized mask is valid if at least one pixel + # in the corresponding window in the input mask is valid + h_in, w_in = int(mask.shape[-2]), int(mask.shape[-1]) + h_out, w_out = int(size[0]), int(size[1]) + kh, kw = h_in // h_out, w_in // w_out + valid_resized = F.max_pool2d(valid, kernel_size=(kh, kw), stride=(kh, kw)) > 0 + resized_mask = ~valid_resized.squeeze(1) - # Sanity check: no feature should be fully padded after resizing, - # otherwise it would cause NaNs in the attention weights - if self.training and resized_mask.flatten(1).all(dim=1).any(): # pragma: no cover - bad = torch.where(resized_mask.flatten(1).all(dim=1))[0].tolist() - raise RuntimeError(f"Feature masks became fully padded after resizing: {bad}") + # Data-dependent sanity checks + if self.training: # pragma: no cover + if resized_mask.flatten(1).all(dim=1).any(): + bad = torch.where(resized_mask.flatten(1).all(dim=1))[0].tolist() + raise RuntimeError(f"Feature masks became fully padded after resizing: {bad}") return resized_mask diff --git a/tests/pytorch/test_models_factory.py b/tests/pytorch/test_models_factory.py index 8cf0a9fa45..db9aed1cdf 100644 --- a/tests/pytorch/test_models_factory.py +++ b/tests/pytorch/test_models_factory.py @@ -50,7 +50,7 @@ def test_push_to_hf_hub(): ["vitstr_small", "recognition", "Felix92/doctr-dummy-torch-vitstr-small"], ["parseq", "recognition", "Felix92/doctr-dummy-torch-parseq"], ["viptr_tiny", "recognition", "Felix92/doctr-dummy-torch-viptr-tiny"], - # ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"], + ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"], ], ) def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir): diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py index f42cb95385..5eb3f582e3 100644 --- a/tests/pytorch/test_models_layout.py +++ b/tests/pytorch/test_models_layout.py @@ -76,7 +76,7 @@ def test_layout_models(arch_name, input_shape, train_mode, use_polygons): assert isinstance(results[1], np.ndarray) and results[1].shape == (len(results[0]), 4) assert isinstance(results[2], list) and all(isinstance(scores, float) for scores in results[2]) # Check class idxs are in the model's num_classes - assert all(0 <= idx < len(model.class_names) for idx in results[0]) + assert all(0 <= idx < model.num_classes for idx in results[0]) # Check scores are between 0 and 1 assert all(0 <= score <= 1 for score in results[2]) # Check that the number of boxes, labels and scores are the same From 90cb6ffe218203f274622b228b3d2b7978db7871 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 15 Jun 2026 11:48:50 +0200 Subject: [PATCH 14/15] minor fixes --- doctr/models/layout/lw_detr/base.py | 6 +-- doctr/transforms/modules/base.py | 37 ++++++++++++---- doctr/transforms/modules/pytorch.py | 2 +- tests/pytorch/test_transforms_pt.py | 69 ++++++++++++++++++++++++----- 4 files changed, 91 insertions(+), 23 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index cbe9c80229..11bab964a4 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -272,11 +272,9 @@ def to_quad(box: np.ndarray): if box.shape == (4,): x1, y1, x2, y2 = box return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32) - if box.shape == (8,): - return box.reshape(4, 2) if box.shape == (4, 2): return box.astype(np.float32) - raise ValueError(f"Unsupported box shape: {box.shape}") + raise ValueError(f"Unsupported box shape: {box.shape}") # pragma: no cover for sample in target: boxes_all = [] @@ -284,7 +282,7 @@ def to_quad(box: np.ndarray): for class_name, boxes in sample.items(): if class_name not in class_to_id: - raise ValueError(f"Unknown class name: {class_name}") + raise ValueError(f"Unknown class name: {class_name}") # pragma: no cover cls_id = class_to_id[class_name] boxes = np.asarray(boxes) diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py index a53acf038f..5b68d45b06 100644 --- a/doctr/transforms/modules/base.py +++ b/doctr/transforms/modules/base.py @@ -272,7 +272,14 @@ def __init__(self, scale: tuple[float, float] = (0.08, 1.0), ratio: tuple[float, def extra_repr(self) -> str: return f"scale={self.scale}, ratio={self.ratio}" - def _crop_array(self, img: Any, target: np.ndarray, crop_box): + def _crop_image_only(self, img: Any, crop_box: tuple[float, float, float, float]) -> Any: + dummy_box = np.array([[0, 0, 1, 1]], dtype=np.float32) + cropped_img, _ = F.crop_detection(img, dummy_box, crop_box) + return cropped_img + + def _crop_array( + self, img: Any, target: np.ndarray, crop_box: tuple[float, float, float, float] + ) -> tuple[Any, np.ndarray]: is_polygon = target.shape[1:] == (4, 2) if is_polygon: @@ -339,17 +346,15 @@ def __call__(self, sample: Sample) -> Sample: (y + crop_h) / h, ) - r_mask = None - if mask is not None: - r_mask, _ = self._crop_array(mask, np.zeros((0, 4)), crop_box) - if target is None: - r_img, _ = self._crop_array(img, np.zeros((0, 4)), crop_box) + r_img = self._crop_image_only(img, crop_box) + r_mask = self._crop_image_only(mask, crop_box) if mask is not None else None return sample.replace(image=r_img, mask=r_mask) if isinstance(target, dict): cropped_targets = {} cropped_img = None + crop_rejected = False for cls_name, arr in target.items(): if len(arr) == 0: @@ -358,13 +363,29 @@ def __call__(self, sample: Sample) -> Sample: c_img, c_arr = self._crop_array(img, arr, crop_box) + if c_img is img: + crop_rejected = True + break + if cropped_img is None: cropped_img = c_img cropped_targets[cls_name] = c_arr - final_img = cropped_img if cropped_img is not None else img - return sample.replace(image=final_img, mask=r_mask, target=cropped_targets) + if crop_rejected or cropped_img is None: + return sample.replace( + image=img, + mask=mask, + target={key: value.copy() for key, value in target.items()}, + ) + + r_mask = self._crop_image_only(mask, crop_box) if mask is not None else None + return sample.replace(image=cropped_img, mask=r_mask, target=cropped_targets) c_img, c_target = self._crop_array(img, target, crop_box) + if c_img is img: + r_mask = mask + else: + r_mask = self._crop_image_only(mask, crop_box) if mask is not None else None + return sample.replace(image=c_img, mask=r_mask, target=c_target) diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index c7245f3a82..c6b7881b56 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -380,7 +380,7 @@ def __call__(self, sample: Sample) -> Sample: else: shadowed_image = random_shadow(sample.image, self.opacity_range).clip(0, 1) return sample.replace(image=shadowed_image) - except ValueError: + except ValueError: # pragma: no cover return sample def extra_repr(self) -> str: diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py index 2a41a712b1..353dc87625 100644 --- a/tests/pytorch/test_transforms_pt.py +++ b/tests/pytorch/test_transforms_pt.py @@ -133,6 +133,21 @@ def test_resize(): assert mask.shape == (64, 64) assert mask.dtype == torch.bool + # Test with included mask in input sample + input_t = Sample( + image=torch.ones((3, 32, 64), dtype=torch.float32), + mask=torch.zeros((32, 64), dtype=torch.bool), + target=target_boxes, + ) + data = transfo(input_t) + out, mask, new_target = data.image, data.mask, data.target + + assert out.shape[-2:] == (64, 64) + assert new_target.shape == target_boxes.shape + assert np.all((0 <= new_target) & (new_target <= 1)) + assert mask.shape == (64, 64) + assert mask.dtype == torch.bool + # Test with invalid target shape input_t = torch.ones((3, 32, 64), dtype=torch.float32) target = np.ones((2, 5)) # Invalid shape @@ -312,6 +327,18 @@ def test_random_rotate(): assert r_targets["boxes"].shape == (0, 4) assert r_targets["polygons"].shape == (0, 4, 2) + # Test with mask in input sample + input_m = torch.zeros((50, 50), dtype=torch.bool) + data = rotator(Sample(image=input_t, mask=input_m, target=boxes)) + r_img, r_mask = data.image, data.mask + assert r_img.ndim == input_t.ndim + assert r_mask.ndim - 1 == input_m.ndim # Mask should be 2D + + # Test without target + data = rotator(Sample(image=input_t)) + r_img = data.image + assert r_img.ndim == input_t.ndim + # FP16 (only on GPU) if torch.cuda.is_available(): input_t = torch.ones((3, 50, 50), dtype=torch.float16).cuda() @@ -354,8 +381,8 @@ def test_crop_detection(): @pytest.mark.parametrize( "target", [ - np.array([[15, 20, 35, 30]]), # box - np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), # polygon + np.array([[15, 20, 35, 30]], dtype=np.float32), # box + np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]], dtype=np.float32), # polygon ], ) def test_random_crop(target): @@ -363,23 +390,28 @@ def test_random_crop(target): assert repr(cropper) == "RandomCrop(scale=(0.5, 1.0), ratio=(0.75, 1.33))" input_t = torch.ones((3, 50, 50), dtype=torch.float32) - sample = cropper(Sample(image=input_t, target=target)) - img, target = sample.image, sample.target + original_target = target.copy() + + sample = cropper(Sample(image=input_t, target=original_target.copy())) + img, cropped_target = sample.image, sample.target + # Check the scale assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] # Check aspect ratio assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 # Check the target - assert np.all(target >= 0) - if target.ndim == 2: - assert np.all(target[:, [0, 2]] <= img.shape[-1]) and np.all(target[:, [1, 3]] <= img.shape[-2]) + assert np.all(cropped_target >= 0) + if cropped_target.ndim == 2: + assert np.all(cropped_target[:, [0, 2]] <= img.shape[-1]) + assert np.all(cropped_target[:, [1, 3]] <= img.shape[-2]) else: - assert np.all(target[..., 0] <= img.shape[-1]) and np.all(target[..., 1] <= img.shape[-2]) + assert np.all(cropped_target[..., 0] <= img.shape[-1]) + assert np.all(cropped_target[..., 1] <= img.shape[-2]) # Test dict targets dict_target = { - "boxes": np.array([[15, 20, 35, 30]]), - "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), + "boxes": np.array([[15, 20, 35, 30]], dtype=np.float32), + "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]], dtype=np.float32), } sample = cropper(Sample(image=input_t, target=dict_target)) img, cropped_targets = sample.image, sample.target @@ -401,6 +433,22 @@ def test_random_crop(target): assert np.all(cropped_targets["polygons"][..., 0] <= img.shape[-1]) assert np.all(cropped_targets["polygons"][..., 1] <= img.shape[-2]) + # Test with mask in input sample + input_m = torch.ones((50, 50), dtype=torch.bool) + sample = cropper(Sample(image=input_t, mask=input_m, target=original_target.copy())) + img, mask, cropped_target = sample.image, sample.mask, sample.target + + assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] + assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 + assert mask.shape == img.shape[-2:] + assert mask.dtype == torch.bool + + # Test without target + sample = cropper(Sample(image=input_t)) + img = sample.image + assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] + assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 + @pytest.mark.parametrize( "input_dtype, input_size", @@ -450,6 +498,7 @@ def test_gaussian_noise(input_dtype, input_shape): assert torch.all(transformed <= 255) else: assert torch.all(transformed <= 1.0) + assert repr(transform) == "GaussianNoise(mean=0.0, std=1.0)" @pytest.mark.parametrize( From 1beceef4b0fab01514540f2e4fe686d160b9b9a5 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 15 Jun 2026 12:02:10 +0200 Subject: [PATCH 15/15] Remove unused func + docstring update --- doctr/models/layout/lw_detr/base.py | 2 +- references/layout/train.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 11bab964a4..b15ca54099 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -144,7 +144,7 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int for b in range(boxes.shape[0]): # Sigmoid scores (the model is trained with a sigmoid-based (IA-BCE) loss without - # a background class, as in LW-DETR) + # a background class) prob = 1.0 / (1.0 + np.exp(-logits[b])) # (num_queries, num_classes) num_classes = prob.shape[-1] diff --git a/references/layout/train.py b/references/layout/train.py index 68339f761a..8f9ce372b1 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -479,10 +479,6 @@ def main(args): # construct DDP model model = DDP(model, device_ids=[rank]) - def is_backbone_param(name: str) -> bool: - name = name.removeprefix("module.") - return name.startswith("feat_extractor.") - param_groups = build_param_groups( model, lr=args.lr,