x = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
Why are µ_t0 and α_t0 set to zero vectors after passing through MetaBlock? This is done by the following code:
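```python
x = self.proj_out(x)  # (b, t, 32)
# Shift right along time: position 0 gets zeros, position t gets the output of position t-1
x = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
if self.nvp:  # True
    xa, xb = x.chunk(2, dim=-1)  # alpha, mu
```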
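The effect of that shift can be checked on a toy tensor. The sketch below is a minimal illustration under stated assumptions, not the repository's code: `h` is a random stand-in for the `proj_out` output, and the affine step `z = (x_in - mu) * exp(-alpha)` is an assumed form of the transform that consumes `xa` (alpha) and `xb` (mu). The hypothetical helper `shift_right` reproduces the quoted `torch.cat` line. After the shift, position 0 has no predecessor to condition on, so its alpha and mu are exactly zero, and under the assumed affine form the first token passes through unchanged (scale `exp(0) = 1`, offset `0`).

```python
import torch


def shift_right(x: torch.Tensor) -> torch.Tensor:
    """Shift a (batch, time, dim) tensor one step along time, zero-padding position 0.

    Same operation as the quoted line from the issue.
    """
    return torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)


def affine_step(x_in: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
    """Hypothetical affine transform z = (x_in - mu) * exp(-alpha) (assumed form)."""
    alpha, mu = params.chunk(2, dim=-1)  # same order as the snippet: alpha, mu
    return (x_in - mu) * torch.exp(-alpha)


torch.manual_seed(0)
b, t, d = 2, 4, 3
x_in = torch.randn(b, t, d)       # tokens entering the flow block
h = torch.randn(b, t, 2 * d)      # stand-in for the proj_out output

params = shift_right(h)

# Position 0 has no predecessor, so its alpha and mu are all zeros.
assert torch.all(params[:, 0] == 0)
# Position k receives the features computed at position k - 1.
assert torch.equal(params[:, 1:], h[:, :-1])

z = affine_step(x_in, params)
# With alpha = mu = 0 the first token is mapped by the identity.
assert torch.allclose(z[:, 0], x_in[:, 0])
```

If the features in `proj_out` come from causally masked attention (an assumption here), this shift guarantees that the parameters for token t depend only on tokens before t, which is what keeps the transform autoregressive and its Jacobian triangular.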