@@ -43,132 +43,9 @@ def forward(self, preds, targets):
4343 :param targets: (N, S*S, (B*5+C))
4444 :return:
4545 """
46- # print('loss 1', self._process1(preds, targets))
47- # print('loss 2', self._process2(preds, targets))
48- return self ._process3 (preds , targets )
46+ return self ._process (preds , targets )
4947
def _process1(self, preds, targets):
    """Compute the detection loss with explicit per-image, per-cell loops.

    Each cell vector is read as ``[C class probabilities | B confidences |
    B * 4 box coordinates]``; that layout is fixed by the slicing below.
    Every row of an image is treated as one grid cell.

    :param preds: predictions, indexed as (N, S*S, (B*5+C))
    :param targets: ground truth with the same layout as ``preds``
    :return: loss summed over every cell of every image, divided by N
    """
    batch_size = preds.shape[0]
    conf_start = self.C
    bbox_start = self.C + self.B

    total_loss = 0.0
    for pred, target in zip(preds, targets):
        # One image at a time: pred / target are [cells, (B*5+C)].
        for pred_cell, target_cell in zip(pred, target):
            pred_conf = pred_cell[conf_start:bbox_start]
            target_conf = target_cell[conf_start:bbox_start]

            # Every cell first pays the confidence loss of all B boxes at the
            # no-object weight, as if the cell were empty.
            total_loss += self.noobj * self.sum_squared_error(pred_conf, target_conf)

            # A cell is treated as containing an object only when its first
            # target confidence is positive; empty cells contribute nothing more.
            if not (target_conf[0] > 0):
                continue

            # Classification loss.
            total_loss += self.sum_squared_error(pred_cell[:conf_start],
                                                 target_cell[:conf_start])

            # Score every predicted box against the annotated boxes and keep
            # the index of the highest score.
            # NOTE(review): assumes self.iou returns one score per box row -- confirm.
            pred_boxes = pred_cell[bbox_start:].reshape(-1, 4)
            target_boxes = target_cell[bbox_start:].reshape(-1, 4)
            best = torch.argmax(self.iou(pred_boxes, target_boxes))

            # The selected box was already charged at the no-object weight
            # above; add the remainder so its confidence loss has weight 1.
            total_loss += (1 - self.noobj) * self.sum_squared_error(pred_conf[best],
                                                                    target_conf[best])
            # Coordinate loss for the selected box only.
            total_loss += self.coord * self.bbox_loss(pred_boxes[best].reshape(-1, 4),
                                                      target_boxes[best].reshape(-1, 4))

    return total_loss / batch_size
109-
def _process2(self, preds, targets):
    """Compute the detection loss per image, vectorized over grid cells.

    Each cell vector is read as ``[C class probabilities | B confidences |
    B * 4 box coordinates]``; that layout is fixed by the slicing below.

    :param preds: predictions, indexed as (N, S*S, (B*5+C))
    :param targets: ground truth with the same layout as ``preds``
    :return: loss summed over all images, divided by N
    """
    batch_size = preds.shape[0]
    conf_start = self.C
    bbox_start = self.C + self.B

    total_loss = 0.0
    for pred, target in zip(preds, targets):
        # One image at a time: pred / target are [cells, (B*5+C)].
        # The cell count is read from the tensor instead of assuming S*S.
        num_cells = pred.shape[0]

        # Class probabilities: [cells, C]
        pred_probs = pred[:, :conf_start]
        target_probs = target[:, :conf_start]
        # Confidences: [cells, B]
        pred_conf = pred[:, conf_start:bbox_start]
        target_conf = target[:, conf_start:bbox_start]
        # Box coordinates: [cells, B*4] -> [cells, B, 4]
        pred_boxes = pred[:, bbox_start:].reshape(num_cells, self.B, 4)
        target_boxes = target[:, bbox_start:].reshape(num_cells, self.B, 4)

        # Every box of every cell first pays the confidence loss at the
        # no-object weight.
        total_loss += self.noobj * self.sum_squared_error(pred_conf, target_conf)

        # Score all boxes at once: [cells*B, 4] in, [cells*B] out -> [cells, B].
        # NOTE(review): assumes self.iou returns one score per box row -- confirm.
        iou_scores = self.iou(pred_boxes.reshape(-1, 4),
                              target_boxes.reshape(-1, 4)).reshape(num_cells, self.B)
        # Index of the best-scoring box in each cell: [cells]
        best_box = torch.argmax(iou_scores, dim=1)

        # Cells whose target class probabilities sum to a positive value are
        # treated as containing an object.
        # NOTE(review): _process1 tests the first target confidence instead;
        # confirm both criteria agree on the training targets.
        has_obj = torch.sum(target_probs, dim=1) > 0

        # (cell, box) index pairs of the selected box in each object cell.
        cell_ids = torch.arange(num_cells, device=best_box.device)
        obj_cells = cell_ids[has_obj]
        obj_boxes = best_box[has_obj]

        # The selected boxes were already charged at the no-object weight
        # above; add the remainder so their confidence loss has weight 1.
        total_loss += (1 - self.noobj) * self.sum_squared_error(pred_conf[obj_cells, obj_boxes],
                                                                target_conf[obj_cells, obj_boxes])
        # Classification loss over object cells.
        total_loss += self.sum_squared_error(pred_probs[has_obj], target_probs[has_obj])
        # Coordinate loss over the selected boxes of object cells.
        total_loss += self.coord * self.bbox_loss(pred_boxes[obj_cells, obj_boxes],
                                                  target_boxes[obj_cells, obj_boxes])

    return total_loss / batch_size
170-
171- def _process3 (self , preds , targets ):
48+ def _process (self , preds , targets ):
17249 N = preds .shape [0 ]
17350 ## 预测
17451 # 提取每个网格的分类概率
0 commit comments