@@ -3,6 +3,7 @@
 from functools import reduce
 
 import torch
+import torch.nn as nn
 
 from ptflops import get_model_complexity_info
 
@@ -72,7 +73,7 @@ def calc_complexity_nn_part1_plyr(vision_model, img):
     return kmacs, pixels
 
 
-def calc_complexity_nn_part2_plyr(vision_model, data, dec_features):
+def calc_complexity_nn_part2_plyr(vision_model, dec_features, data):
     if isinstance(data[0], list):  # image task
         data = {k: v[0] for k, v in data.items()}
 
@@ -147,6 +148,133 @@ def get_downsampled_shape(h, w, ratio):
     return h, w
 
 
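+# Encoder-side slice of YOLOX, built only for complexity measurement:
+# "l13" stops after the first block of the backbone's dark3 stage
+# (optionally squeezing the features at the split), "l37" after all of dark3.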
+class YoloxPart1(nn.Module):
+    def __init__(self, vision_model, split_id):
+        super().__init__()
+        self.backbone = vision_model.backbone
+        self.split_id = split_id
+        self.squeeze_at_split_enabled = vision_model.squeeze_at_split_enabled
+        if self.squeeze_at_split_enabled:
+            self.squeeze_model = vision_model.squeeze_model
+
+    def forward(self, x):
+        if self.split_id == "l13":
+            y = self.backbone.stem(x)
+            y = self.backbone.dark2(y)
+            y = self.backbone.dark3[0](y)
+            if self.squeeze_at_split_enabled:
+                y = self.squeeze_model.squeeze_(y)
+        elif self.split_id == "l37":
+            y = self.backbone.stem(x)
+            y = self.backbone.dark2(y)
+            y = self.backbone.dark3(y)
+        return y
+
+
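+# Decoder-side remainder of YOLOX: expands the optionally squeezed features,
+# finishes dark3 when split at "l13", then runs dark4/dark5, the FPN
+# top-down fusion branches and the detection head.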
+class YoloxPart2(nn.Module):
+    def __init__(self, vision_model, split_id):
+        super().__init__()
+        self.backbone = vision_model.backbone
+        self.out1_cbl = vision_model.yolo_fpn.out1_cbl
+        self.out1 = vision_model.yolo_fpn.out1
+        self.out2_cbl = vision_model.yolo_fpn.out2_cbl
+        self.out2 = vision_model.yolo_fpn.out2
+        self.upsample = vision_model.yolo_fpn.upsample
+        self.head = vision_model.head
+        self.split_id = split_id
+        self.squeeze_at_split_enabled = vision_model.squeeze_at_split_enabled
+        if self.squeeze_at_split_enabled:
+            self.squeeze_model = vision_model.squeeze_model
+        # self.postprocess = vision_model.postprocess  # Not needed for MAC calc
+
+    def forward(self, x):
+        y = x
+        if self.split_id == "l13":
+            if self.squeeze_at_split_enabled:
+                y = self.squeeze_model.expand_(y)
+            for proc_module in self.backbone.dark3[1:]:
+                y = proc_module(y)
+
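+        # feature pyramid levels: fp_lvl2/1/0 from dark3, dark4 and dark5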
+        fp_lvl2 = y
+        fp_lvl1 = self.backbone.dark4(fp_lvl2)
+        fp_lvl0 = self.backbone.dark5(fp_lvl1)
+
+        # yolo branch 1
+        b1_in = self.out1_cbl(fp_lvl0)
+        b1_in = self.upsample(b1_in)
+        b1_in = torch.cat([b1_in, fp_lvl1], 1)
+        fp_lvl1 = self.out1(b1_in)
+
+        # yolo branch 2
+        b2_in = self.out2_cbl(fp_lvl1)
+        b2_in = self.upsample(b2_in)
+        b2_in = torch.cat([b2_in, fp_lvl2], 1)
+        fp_lvl2 = self.out2(b2_in)
+
+        outputs = self.head((fp_lvl2, fp_lvl1, fp_lvl0))
+        return outputs
+
+
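+# kMACs of the encoder-side partial model, measured via the measure_mac
+# helper (assumed to wrap ptflops' get_model_complexity_info imported above),
+# returned together with the input tensor's element count.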
+def calc_complexity_nn_part1_yolox(vision_model, img):
+    device = torch.device(vision_model.device)
+    img = img[0]["image"].unsqueeze(0).to(device)
+
+    partial_model = YoloxPart1(vision_model, vision_model.split_id)
+
+    C, H, W = img.shape[1:]
+
+    kmacs, _ = measure_mac(
+        partial_model=partial_model,
+        input_res=(C, H, W),
+        input_constructor=None,
+    )
+
+    pixels = reduce(operator.mul, img.shape)
+    return kmacs, pixels
+
+
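+# Rebuilds the split-point tensor from the decoded features keyed by split
+# layer, then measures the kMACs of the decoder-side partial model on an
+# input of the same (C, H, W) shape.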
+def calc_complexity_nn_part2_yolox(vision_model, dec_features):
+    assert "data" in dec_features
+
+    x_data = dec_features["data"]
+
+    x_data = {
+        k: (v[0] if isinstance(x_data[0], list) else v).to(vision_model.device)
+        for k, v in zip(vision_model.split_layer_list, x_data.values())
+    }
+
+    input_tensor = x_data[vision_model.split_id]
+
+    if input_tensor.dim() == 3:
+        input_tensor = input_tensor.unsqueeze(0)
+
+    C, H, W = input_tensor.shape[1:]
+    partial_model = YoloxPart2(vision_model, vision_model.split_id)
+
+    kmacs, _ = measure_mac(
+        partial_model=partial_model,
+        input_res=(C, H, W),
+        input_constructor=None,
+    )
+
+    pixels = reduce(operator.mul, input_tensor.shape)
+
+    return kmacs, pixels
+
+
 def prepare_proposal_input_fpn(resolutions):
     b, c, h, w = resolutions[1]
     resized_img = resolutions[0]