diff --git a/pycvvdp/cvvdp_ml_metric.py b/pycvvdp/cvvdp_ml_metric.py index 440ff7d..d5ae20d 100644 --- a/pycvvdp/cvvdp_ml_metric.py +++ b/pycvvdp/cvvdp_ml_metric.py @@ -1697,44 +1697,45 @@ def do_pooling_and_jods(self, features): # features[band][batch,frames,width,height,channels,stat] # disables_features is an array of indices of the stat to be disabled - - # no_channels = features[0].shape[3] - # no_frames = features[0].shape[0] no_bands = len(features) batch_sz = features[0].shape[0] + is_image = (features[0].shape[4] == 3) # if 3 channels, it is an image - Q_JOD = torch.ones((batch_sz), device=self.device)*10. - - is_image = (features[0].shape[4]==3) # if 3 channels, it is an image - + sizes, fTR_chunks, fD_chunks = [], [], [] for bb in range(no_bands): - - #F[batch,frames,width,height,channels,stat] f = features[bb] - + # Variance into std - f[...,1::2] = torch.sqrt(torch.abs(f[...,1::2])) + f[..., 1::2] = torch.sqrt(torch.abs(f[..., 1::2])) if is_image: - f = torch.cat( (f, torch.zeros((f.shape[0:4] + (1,f.shape[5])), device=self.device)), dim=4) # Add the missing channel + f = torch.cat( + (f, torch.zeros((f.shape[0:4] + (1, f.shape[5])), device=self.device)), + dim=4, + ) # Add the missing channel if self.disabled_features is not None: - f[..., self.disabled_features] = 0 + f[..., self.disabled_features] = 0 - f_TR = f[..., 0:4].flatten( start_dim=4 ) - f_D = f[..., 4:].flatten( start_dim=4 ) + f_TR = f[..., 0:4].flatten(start_dim=4) + f_D = f[..., 4:].flatten(start_dim=4) - Att = self.att_net(f_TR) - Att = F.relu(Att) - D_all = self.feature_net(f_D) - D_all = F.relu(D_all) * Att /no_bands + fTR_chunks.append(f_TR.reshape(batch_sz, -1, f_TR.shape[-1])) + fD_chunks.append(f_D.reshape(batch_sz, -1, f_D.shape[-1])) + sizes.append(fTR_chunks[-1].shape[1]) - is_base_band = (bb==no_bands-1) - if is_base_band: - D_all *= self.baseband_weight + # One fused MLP call across all bands (vs. no_bands calls per loop) + f_TR_cat = torch.cat(fTR_chunks, dim=1) + f_D_cat = torch.cat(fD_chunks, dim=1) + Att_cat = F.relu(self.att_net(f_TR_cat)) + D_all_cat = F.relu(self.feature_net(f_D_cat)) * Att_cat / no_bands + Q_JOD = torch.ones((batch_sz), device=self.device) * 10. + for bb, D_all in enumerate(D_all_cat.split(sizes, dim=1)): + is_base_band = (bb == no_bands - 1) + if is_base_band: + D_all = D_all * self.baseband_weight if is_image: - D_all *= self.image_int - + D_all = D_all * self.image_int Q_JOD -= self.spatiotemporal_pooling(D_all) assert(not Q_JOD.isnan().any()) @@ -1745,7 +1746,7 @@ def full_name(self): def spatiotemporal_pooling(self, D_all): return D_all.view(D_all.shape[0],-1).mean(dim=1) - + register_metric( cvvdp_ml_saliency )