@@ -166,11 +166,32 @@ def __init__(self, image_sizes):
166166 and roi_heads .mask_head is not None
167167 and hasattr (roi_heads , "mask_pooler" )
168168 ):
169- with torch . no_grad ():
170- mask_pooled = roi_heads . mask_pooler ( feat_list , boxes )
169+ # Mask task is enabled here; only proceed when there is at least one proposal
170+ if sum ( len ( p ) for p in proposals ) > 0 :
171171
172- mask_head_model = MaskHeadFvcoreWrapper (roi_heads ).eval ()
173- kmacs_sum += measure_kmacs (mask_head_model , mask_pooled )
172+ # Run ROIHeads once to obtain pred_instances
173+ with torch .no_grad ():
174+ pred_instances , _ = roi_heads (images , feature_pyramid , proposals , None )
175+
176+ # Skip if no detected objects
177+ if sum (len (p ) for p in pred_instances ) > 0 :
178+
179+ # Mask pooling requires pred_boxes
180+ mask_boxes = [p .pred_boxes for p in pred_instances ]
181+
182+ with torch .no_grad ():
183+ mask_pooled = roi_heads .mask_pooler (feat_list , mask_boxes )
184+
185+ pred_classes = torch .cat ([p .pred_classes for p in pred_instances ])
186+ mask_head_model = MaskHeadFvcoreWrapper (roi_heads , pred_classes ).eval ()
187+ kmacs_sum += measure_kmacs (mask_head_model , mask_pooled )
188+
189+ #with torch.no_grad():
190+ # mask_pooled = roi_heads.mask_pooler(feat_list, boxes)
191+ #
192+ #if sum(len(p) for p in proposals) > 0:
193+ # mask_head_model = MaskHeadFvcoreWrapper(roi_heads, proposals).eval()
194+ # kmacs_sum += measure_kmacs(mask_head_model, mask_pooled)
174195
175196 # ---------- Pixel count (unchanged) ----------
176197 pixels = sum ([reduce (operator .mul , list (d .shape )) for d in data .values ()])
@@ -385,18 +406,14 @@ def forward(self, box_features):
385406
386407
class MaskHeadFvcoreWrapper(nn.Module):
    """
    Wrapper to measure FLOPs only for the mask head (if available).

    Input shape:
        (num_boxes, C, pool_h, pool_w)

    Args:
        roi_heads: object exposing a ``mask_head`` attribute; that head must
            provide a ``layers(mask_features)`` method.
        pred_classes: optional predicted class indices for the pooled boxes.
            They are stored on the wrapper but are not used by ``forward``.
            Defaults to ``None`` so callers that only pass ``roi_heads``
            keep working.
    """

    def __init__(self, roi_heads, pred_classes=None):
        super().__init__()
        self.mask_head = roi_heads.mask_head
        # Kept for reference only; forward() does not read it.
        self.pred_classes = pred_classes

    def forward(self, mask_features):
        """Run the mask head's ``layers`` on the pooled mask features."""
        # Call `layers` directly instead of `mask_head(...)`: only the
        # network layers are measured, with no instance-level post-processing.
        # NOTE(review): presumably the head's full forward also expects the
        # predicted instances -- confirm against the mask head in use.
        return self.mask_head.layers(mask_features)
400417
401418class DarknetBackboneOnlyFvcoreWrapper (nn .Module ):
402419 """
0 commit comments