Kronos/model/module.py 中的class BinarySphericalQuantizer(nn.Module):
def forward(self, z, collect_metrics=True):
# if self.input_format == 'bchw':
# z = rearrange(z, 'b c h w -> b h w c')
zq = self.quantize(z)
q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
zq = zq * q_scale
if not collect_metrics:
return zq, zq.new_zeros(()), {}
indices = self.codes_to_indexes(zq.detach())
这部分的 indices = self.codes_to_indexes(zq.detach()) 是否应该放到 zq = self.quantize(z) 之后?zq = zq * q_scale之后已经不{1,-1}了,之后再求indices是否有问题。
Kronos/model/module.py 中的class BinarySphericalQuantizer(nn.Module):
def forward(self, z, collect_metrics=True):
# if self.input_format == 'bchw':
# z = rearrange(z, 'b c h w -> b h w c')
zq = self.quantize(z)
这部分的 indices = self.codes_to_indexes(zq.detach()) 是否应该放到 zq = self.quantize(z) 之后?zq = zq * q_scale之后已经不{1,-1}了,之后再求indices是否有问题。