In INN.BatchNorm1d, the forward function is:
```python
def forward(self, x, log_p=0, log_det_J=0):
    if self.compute_p:
        if not self.training:
            # if in self.eval()
            var = self.running_var  # [dim]
        else:
            # if in training
            # TODO: Do we need to add .detach() after var?
            var = torch.var(x, dim=0, unbiased=False)  # [dim]
        x = super(BatchNorm1d, self).forward(x)
        log_det = -0.5 * torch.log(var + self.eps)
        log_det = torch.sum(log_det, dim=-1)
        return x, log_p, log_det_J + log_det
    else:
        return super(BatchNorm1d, self).forward(x)
```
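For context, this is where the `log_det` formula comes from. If the mean and variance are treated as constants for a given sample (exactly the case in eval mode, where `running_mean` and `running_var` are fixed), the normalization acts elementwise, so its Jacobian is diagonal:

$$
\hat{x}_d = \frac{x_d - \mu_d}{\sqrt{\sigma_d^2 + \epsilon}}, \qquad \log\left|\det \frac{\partial \hat{x}}{\partial x}\right| = -\frac{1}{2}\sum_d \log\left(\sigma_d^2 + \epsilon\right).
$$

If the layer is affine (`y_d = \gamma_d \hat{x}_d + \beta_d`), an extra $\sum_d \log|\gamma_d|$ term would appear; the snippet above does not show whether `INN.BatchNorm1d` uses affine weights. The sketch below checks the formula numerically, assuming a plain `torch.nn.BatchNorm1d` with `affine=False` in eval mode (not the `INN` class itself):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, eps = 3, 1e-5

# affine=False, so in eval mode the layer is exactly (x - running_mean) / sqrt(running_var + eps)
bn = nn.BatchNorm1d(dim, affine=False, eps=eps)
bn.train()
bn(torch.randn(64, dim) * 2.0 + 1.0)  # one training pass updates the running statistics
bn.eval()  # freeze running_mean / running_var

x = torch.randn(dim)

def f(v):
    # BatchNorm1d expects [batch, dim]; push a single sample through it
    return bn(v.unsqueeze(0)).squeeze(0)

J = torch.autograd.functional.jacobian(f, x)  # [dim, dim], diagonal for this layer
_, logabsdet = torch.linalg.slogdet(J)

# Same expression as the eval branch of forward()
closed_form = torch.sum(-0.5 * torch.log(bn.running_var + eps))
print(logabsdet.item(), closed_form.item())
```

In training mode the batch statistics depend on every sample in the batch, so this per-sample diagonal Jacobian is only exact if those statistics are treated as constants.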
Does `var` need to carry gradient information (i.e., should it be left un-detached)? Keeping the gradient does not seem to train `BatchNorm1d` itself, but rather the modules before it. Are there any references on this?
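A small experiment makes the effect concrete. This is a minimal sketch, assuming a hypothetical upstream `nn.Linear` layer standing in for "the modules before it" and a hypothetical helper `batch_logdet` that reproduces the training-branch formula. It backpropagates the log-det term alone, with and without `.detach()`, and reports the gradients that reach the upstream parameters. It shows the mechanics only; it does not settle which choice is correct.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def batch_logdet(h, eps=1e-5, detach=False):
    # Same formula as the training branch: biased batch variance per feature
    var = torch.var(h, dim=0, unbiased=False)
    if detach:
        var = var.detach()  # treat batch statistics as constants
    return torch.sum(-0.5 * torch.log(var + eps), dim=-1)

def upstream_grads(detach):
    linear = nn.Linear(4, 4)  # hypothetical stand-in for the upstream modules
    h = linear(torch.randn(32, 4))
    logdet = batch_logdet(h, detach=detach)
    if not logdet.requires_grad:
        return None  # no path from log_det back to the upstream parameters
    logdet.backward()
    # (norm of weight gradient, largest absolute bias gradient)
    return linear.weight.grad.norm().item(), linear.bias.grad.abs().max().item()

live = upstream_grads(detach=False)
detached = upstream_grads(detach=True)
print(live, detached)
```

Without `.detach()`, the log-det term pushes the upstream weights to change the scale (variance) of their outputs, while the bias gradient is zero up to rounding because a shift does not change the variance. With `.detach()`, this term contributes no gradient upstream, although gradients through `x` itself (via `log_p`) still flow as usual.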