Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions pycvvdp/cvvdp_ml_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,44 +1697,45 @@ def do_pooling_and_jods(self, features):

# features[band][batch,frames,width,height,channels,stat]
# disabled_features is an array of indices of the stats to be disabled

# no_channels = features[0].shape[3]
# no_frames = features[0].shape[0]
no_bands = len(features)
batch_sz = features[0].shape[0]
is_image = (features[0].shape[4] == 3) # if 3 channels, it is an image

Q_JOD = torch.ones((batch_sz), device=self.device)*10.

is_image = (features[0].shape[4]==3) # if 3 channels, it is an image

sizes, fTR_chunks, fD_chunks = [], [], []
for bb in range(no_bands):

#F[batch,frames,width,height,channels,stat]
f = features[bb]

# Variance into std
f[...,1::2] = torch.sqrt(torch.abs(f[...,1::2]))
f[..., 1::2] = torch.sqrt(torch.abs(f[..., 1::2]))

if is_image:
f = torch.cat( (f, torch.zeros((f.shape[0:4] + (1,f.shape[5])), device=self.device)), dim=4) # Add the missing channel
f = torch.cat(
(f, torch.zeros((f.shape[0:4] + (1, f.shape[5])), device=self.device)),
dim=4,
) # Add the missing channel
if self.disabled_features is not None:
f[..., self.disabled_features] = 0
f[..., self.disabled_features] = 0

f_TR = f[..., 0:4].flatten( start_dim=4 )
f_D = f[..., 4:].flatten( start_dim=4 )
f_TR = f[..., 0:4].flatten(start_dim=4)
f_D = f[..., 4:].flatten(start_dim=4)

Att = self.att_net(f_TR)
Att = F.relu(Att)
D_all = self.feature_net(f_D)
D_all = F.relu(D_all) * Att /no_bands
fTR_chunks.append(f_TR.reshape(batch_sz, -1, f_TR.shape[-1]))
fD_chunks.append(f_D.reshape(batch_sz, -1, f_D.shape[-1]))
sizes.append(fTR_chunks[-1].shape[1])

is_base_band = (bb==no_bands-1)
if is_base_band:
D_all *= self.baseband_weight
# One fused MLP call across all bands (vs. no_bands calls per loop)
f_TR_cat = torch.cat(fTR_chunks, dim=1)
f_D_cat = torch.cat(fD_chunks, dim=1)
Att_cat = F.relu(self.att_net(f_TR_cat))
D_all_cat = F.relu(self.feature_net(f_D_cat)) * Att_cat / no_bands

Q_JOD = torch.ones((batch_sz), device=self.device) * 10.
for bb, D_all in enumerate(D_all_cat.split(sizes, dim=1)):
is_base_band = (bb == no_bands - 1)
if is_base_band:
D_all = D_all * self.baseband_weight
if is_image:
D_all *= self.image_int

D_all = D_all * self.image_int
Q_JOD -= self.spatiotemporal_pooling(D_all)

assert(not Q_JOD.isnan().any())
Expand All @@ -1745,7 +1746,7 @@ def full_name(self):

def spatiotemporal_pooling(self, D_all):
    """Average ``D_all`` over every dimension except the first.

    Args:
        D_all: Tensor with at least one dimension. Dim 0 is preserved;
            all remaining dimensions are flattened together and averaged.

    Returns:
        A 1-D tensor of length ``D_all.shape[0]`` holding, for each entry
        along dim 0, the mean of all of its remaining elements.
    """
    # reshape() rather than view(): view() raises a RuntimeError when the
    # input's strides do not allow flattening without a copy (for example
    # a permuted tensor). reshape() returns the same view whenever that is
    # possible and only copies otherwise, so results are unchanged for
    # inputs that already worked.
    return D_all.reshape(D_all.shape[0], -1).mean(dim=1)


# Pass cvvdp_ml_saliency to register_metric. Both names are defined outside
# this view; presumably this adds the metric to the project's registry so it
# can be looked up by name -- TODO confirm against register_metric.
register_metric( cvvdp_ml_saliency )

Expand Down