diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 96f2edd205..521f6df7cc 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -123,6 +123,10 @@ def _probiou( Returns: probiou: (N,) or (N, M) tensor of IoU-like similarities in [0, 1] """ + # AMP safe-guard + boxes1 = boxes1.float() + boxes2 = boxes2.float() + if scale != 1.0: boxes1 = torch.cat([boxes1[..., :4] * scale, boxes1[..., 4:]], dim=-1) boxes2 = torch.cat([boxes2[..., :4] * scale, boxes2[..., 4:]], dim=-1) @@ -649,6 +653,9 @@ def compute_loss( loss: the computed loss value """ alpha, gamma, eps = 0.25, 2.0, 1e-8 + # AMP safe-guard + logits = logits.float() + pred_boxes = pred_boxes.float() device = logits.device batch_size = logits.shape[0] @@ -683,7 +690,7 @@ def compute_loss( cost_class = pos_cost[:, tgt_labels] - neg_cost[:, tgt_labels] # L1 cost on normalized (cx, cy, w, h) - cost_bbox = torch.cdist(out_boxes[:, :4], tgt_boxes[:, :4], p=1) + cost_bbox = torch.cdist(out_boxes[:, :4].float(), tgt_boxes[:, :4].float(), p=1) # Rotated IoU cost, computed in pixel coordinates # this term also carries the angle signal for the matching