Hello 😁

While running example_inference.py, I observed numerical instability in the forward pass of the model: the activations grow rapidly in magnitude across layers and reach extreme values, starting at layer (block) 11. To track this, I added the following prints to the StripedHyena class:
```python
# StripedHyena methods, with a debug print added after each block
def stateful_forward(self, x, inference_params_dict=None):
    for block_idx, block in enumerate(self.blocks):
        block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
        inference_params = inference_params_dict[block_name]
        x, _ = block(x, inference_params=inference_params)
        # Added code
        print(f"Block {block_idx} stats - min: {x.min().item():.3f}, max: {x.max().item():.3f}, mean: {x.mean().item():.3f}")
    return x, inference_params_dict

def stateless_forward(self, x, padding_mask=None):
    if type(padding_mask) == torch.Tensor:
        x = x * padding_mask[..., None]
    for block_idx, block in enumerate(self.blocks):
        x, _ = block(x, inference_params=None, padding_mask=padding_mask)
        # Added code
        print(f"Block {block_idx} stats - min: {x.min().item():.3f}, max: {x.max().item():.3f}, mean: {x.mean().item():.3f}")
    return x, None
```
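As an alternative to editing `stateful_forward` / `stateless_forward`, the same per-block statistics can be collected with PyTorch forward hooks, which leave the model code untouched. This is a minimal sketch, assuming the blocks sit in an `nn.ModuleList` (the attribute name `blocks` comes from the snippet above; how you reach the StripedHyena instance from example_inference.py is an assumption). `attach_block_stats` and `ToyBlock` are hypothetical names used only for illustration; `ToyBlock` just mimics the `(x, state)` return convention of the real blocks.

```python
import torch
import torch.nn as nn


def attach_block_stats(blocks: nn.ModuleList):
    """Register a forward hook on every block that records activation stats.

    Returns (stats, handles): `stats` is a list of dicts appended in call
    order, `handles` lets you remove the hooks afterwards.
    """
    stats = []
    handles = []
    for idx, block in enumerate(blocks):
        # `idx=idx` binds the current index; PyTorch calls hook(module, args, output).
        def hook(module, args, output, idx=idx):
            # Blocks return (x, state); keep only the activations x.
            x = output[0] if isinstance(output, tuple) else output
            # Upcast so the statistics themselves are computed in float32.
            x = x.detach().float()
            stats.append({
                "block": idx,
                "min": x.min().item(),
                "max": x.max().item(),
                "mean": x.mean().item(),
                "absmax": x.abs().max().item(),
                "finite": bool(torch.isfinite(x).all()),
            })
        handles.append(block.register_forward_hook(hook))
    return stats, handles


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a StripedHyena block: returns (x, state)."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x, inference_params=None, padding_mask=None):
        return x * self.scale, None


blocks = nn.ModuleList([ToyBlock(s) for s in (1.0, 2.0, 10.0)])
stats, handles = attach_block_stats(blocks)

x = torch.ones(1, 4, 8)
for block in blocks:
    x, _ = block(x)

# Detach the hooks once the statistics have been collected.
for h in handles:
    h.remove()

for s in stats:
    print(f"Block {s['block']} absmax: {s['absmax']:.3f}, finite: {s['finite']}")
```

With the real model, one would call `attach_block_stats(...)` on its block list before running inference. Note that in cached (stateful) generation every new token triggers another pass, so `stats` grows by one entry per block per decoding step.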
After running example_inference.py, I got the following results:

```
Block 0 stats - min: -1.430, max: 1.688, mean: 0.000
Block 1 stats - min: -51.000, max: 54.500, mean: -0.003
Block 2 stats - min: -60.250, max: 74.000, mean: -0.007
Block 3 stats - min: -69.500, max: 88.000, mean: -0.009
Block 4 stats - min: -73.500, max: 93.000, mean: -0.009
Block 5 stats - min: -78.000, max: 99.000, mean: -0.012
Block 6 stats - min: -91.500, max: 115.500, mean: -0.015
Block 7 stats - min: -140.000, max: 180.000, mean: -0.008
Block 8 stats - min: -75.000, max: 79.500, mean: 0.004
Block 9 stats - min: -72.000, max: 88.500, mean: -0.011
Block 10 stats - min: -752.000, max: 724.000, mean: -0.586
Block 11 stats - min: -464896.000, max: 462848.000, mean: -130.000
Block 12 stats - min: -14548992.000, max: 18743296.000, mean: -18304.000
Block 13 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 14 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 15 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 16 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 17 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 18 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 19 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 20 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 21 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 22 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 23 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 24 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 25 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 26 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 27 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 28 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 29 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 30 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 31 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
```
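To put numbers on "rapidly grow", one can compute the block-to-block growth of the maximum absolute activation directly from the printed lines. This is a minimal sketch; `growth_table`, `first_blowup` and the default `factor` of 100 are hypothetical choices for illustration, and the regex only assumes the exact print format used above.

```python
import re

# Matches the debug lines printed by the modified forward passes.
LINE_RE = re.compile(
    r"Block (\d+) stats - min: (-?[\d.]+), max: (-?[\d.]+), mean: (-?[\d.]+)"
)


def growth_table(lines):
    """Return (block, absmax, ratio vs. previous parsed line) for each log line."""
    rows = []
    prev = None
    for line in lines:
        m = LINE_RE.search(line)
        if m is None:
            continue  # skip anything that is not a stats line
        block = int(m.group(1))
        lo, hi = float(m.group(2)), float(m.group(3))
        absmax = max(abs(lo), abs(hi))
        # The first parsed line (or a zero predecessor) has no meaningful ratio.
        ratio = absmax / prev if prev else float("nan")
        rows.append((block, absmax, ratio))
        prev = absmax
    return rows


def first_blowup(lines, factor=100.0):
    """First block whose max |activation| grows by more than `factor` in one step."""
    for block, _, ratio in growth_table(lines):
        if ratio > factor:  # NaN compares False, so the first row never matches
            return block
    return None


log = """\
Block 9 stats - min: -72.000, max: 88.500, mean: -0.011
Block 10 stats - min: -752.000, max: 724.000, mean: -0.586
Block 11 stats - min: -464896.000, max: 462848.000, mean: -130.000
Block 12 stats - min: -14548992.000, max: 18743296.000, mean: -18304.000
Block 13 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 14 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000"""

for block, absmax, ratio in growth_table(log.splitlines()):
    print(f"block {block:2d}  absmax {absmax:14.1f}  x{ratio:.1f}")
print("first >100x jump:", first_blowup(log.splitlines()))  # -> 11
```

The threshold matters: on the full log, block 1 already grows about 32x over block 0 (1.688 to 54.5), so `factor=10` would flag block 1, while `factor=100` singles out the roughly 618x jump at block 11.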
I got a good inference result at the end:
Is this what is supposed to happen, or am I missing something?
Thank you!