Skip to content

Commit ae2a45e

Browse files
committed
[fix] Modify Mask R-CNN wrapper to use tensor-based inputs for fvcore MAC computation
1 parent 96e48b0 commit ae2a45e

1 file changed

Lines changed: 29 additions & 12 deletions

File tree

compressai_vision/utils/measure_complexity.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,32 @@ def __init__(self, image_sizes):
166166
and roi_heads.mask_head is not None
167167
and hasattr(roi_heads, "mask_pooler")
168168
):
169-
with torch.no_grad():
170-
mask_pooled = roi_heads.mask_pooler(feat_list, boxes)
169+
# Only run when mask task is actually enabled
170+
if sum(len(p) for p in proposals) > 0:
171171

172-
mask_head_model = MaskHeadFvcoreWrapper(roi_heads).eval()
173-
kmacs_sum += measure_kmacs(mask_head_model, mask_pooled)
172+
# Run ROIHeads once to obtain pred_instances
173+
with torch.no_grad():
174+
pred_instances, _ = roi_heads(images, feature_pyramid, proposals, None)
175+
176+
# Skip if no detected objects
177+
if sum(len(p) for p in pred_instances) > 0:
178+
179+
# Mask pooling requires pred_boxes
180+
mask_boxes = [p.pred_boxes for p in pred_instances]
181+
182+
with torch.no_grad():
183+
mask_pooled = roi_heads.mask_pooler(feat_list, mask_boxes)
184+
185+
pred_classes = torch.cat([p.pred_classes for p in pred_instances])
186+
mask_head_model = MaskHeadFvcoreWrapper(roi_heads, pred_classes).eval()
187+
kmacs_sum += measure_kmacs(mask_head_model, mask_pooled)
188+
189+
#with torch.no_grad():
190+
# mask_pooled = roi_heads.mask_pooler(feat_list, boxes)
191+
#
192+
#if sum(len(p) for p in proposals) > 0:
193+
# mask_head_model = MaskHeadFvcoreWrapper(roi_heads, proposals).eval()
194+
# kmacs_sum += measure_kmacs(mask_head_model, mask_pooled)
174195

175196
# ---------- Pixel count (unchanged) ----------
176197
pixels = sum([reduce(operator.mul, list(d.shape)) for d in data.values()])
@@ -385,18 +406,14 @@ def forward(self, box_features):
385406

386407

387408
class MaskHeadFvcoreWrapper(nn.Module):
388-
"""
389-
Wrapper to measure FLOPs only for the mask head (if available).
390-
391-
Input shape:
392-
(num_boxes, C, pool_h, pool_w)
393-
"""
394-
def __init__(self, roi_heads):
409+
def __init__(self, roi_heads, pred_classes):
395410
super().__init__()
396411
self.mask_head = roi_heads.mask_head
412+
self.pred_classes = pred_classes
397413

398414
def forward(self, mask_features):
399-
return self.mask_head(mask_features)
415+
# simulate detectron2 mask inference
416+
return self.mask_head.layers(mask_features)
400417

401418
class DarknetBackboneOnlyFvcoreWrapper(nn.Module):
402419
"""

0 commit comments

Comments
 (0)