@@ -67,8 +67,10 @@ def forward(self, preds, targets):
6767 ## 首先计算所有边界框的置信度损失(假定不存在obj)
6868 loss = self .noobj * self .sum_squared_error (pred_confidences , target_confidences )
6969
70- # 选取每个网格中置信度最高的边界框
71- top_idxs = torch .argmax (pred_confidences , dim = 1 )
70+ # 计算每个预测边界框与对应目标边界框的IoU
71+ iou_scores = self .iou (pred_bboxs .reshape (- 1 , 4 ), target_bboxs .reshape (- 1 , 4 )).reshape (- 1 , 2 )
72+ # 选取每个网格中IoU最高的边界框
73+ top_idxs = torch .argmax (iou_scores , dim = 1 )
7274 top_len = len (top_idxs )
7375 # 获取相应的置信度以及边界框
7476 top_pred_confidences = pred_confidences [range (top_len ), top_idxs ]
@@ -199,7 +201,7 @@ def iou(self, pred_boxs, target_boxs):
199201 xB = np .minimum (pred_boxs [:, 0 ] + pred_boxs [:, 2 ] / 2 , target_boxs [:, 0 ] + target_boxs [:, 2 ] / 2 )
200202 yB = np .minimum (pred_boxs [:, 1 ] + pred_boxs [:, 3 ] / 2 , target_boxs [:, 1 ] + target_boxs [:, 3 ] / 2 )
201203 # 计算交集面积
202- intersection = np .maximum (0.0 , xB - xA ) * np .maximum (0.0 , yB - yA )
204+ intersection = np .maximum (0.0 , xB - xA + 1 ) * np .maximum (0.0 , yB - yA + 1 )
203205 # 计算两个边界框面积
204206 boxAArea = pred_boxs [:, 2 ] * pred_boxs [:, 3 ]
205207 boxBArea = target_boxs [:, 2 ] * target_boxs [:, 3 ]
0 commit comments