Numerical Instability in Forward Pass #112

@thalaby

Hello 😁

While running example_inference.py, I observed numerical instability in the forward pass of the model. The activations grow rapidly in magnitude across layers and reach extreme values, with the sharpest jump at block 11. To track this, I added the following prints to the StripedHyena class:

    def stateful_forward(self, x, inference_params_dict=None):
        for block_idx, block in enumerate(self.blocks):
            block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
            inference_params = inference_params_dict[block_name]
            x, _ = block(x, inference_params=inference_params)
            # Added code
            print(f"Block {block_idx} stats - min: {x.min().item():.3f}, max: {x.max().item():.3f}, mean: {x.mean().item():.3f}")
        return x, inference_params_dict

    def stateless_forward(self, x, padding_mask=None):
        if type(padding_mask) == torch.Tensor:
            x = x * padding_mask[..., None]

        for block_idx, block in enumerate(self.blocks):
            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
            # Added code
            print(f"Block {block_idx} stats - min: {x.min().item():.3f}, max: {x.max().item():.3f}, mean: {x.mean().item():.3f}")
        return x, None
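
The same statistics can probably be collected without editing the class, by registering forward hooks on each block. The sketch below is only an illustration under assumptions: it assumes the blocks are `torch.nn.Module` objects that return a tuple whose first element is the activations (as in `x, _ = block(...)` above). `attach_stat_hooks` and `ToyBlock` are hypothetical names made up for this example and are not part of the repository.

```python
import torch
import torch.nn as nn


def attach_stat_hooks(blocks, records):
    """Register a forward hook on each block that records min/max/mean of its output."""
    handles = []
    for block_idx, block in enumerate(blocks):
        def hook(module, args, output, block_idx=block_idx):
            # The blocks above return a tuple (x, _); take the first element.
            x = output[0] if isinstance(output, tuple) else output
            # Cast to float32 so the statistics are computed in full precision.
            x = x.detach().float()
            records.append((block_idx, x.min().item(), x.max().item(), x.mean().item()))
        handles.append(block.register_forward_hook(hook))
    return handles


class ToyBlock(nn.Module):
    """Hypothetical stand-in block: scales its input and returns (x, None)."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x, inference_params=None, padding_mask=None):
        return x * self.scale, None


blocks = nn.ModuleList([ToyBlock(2.0), ToyBlock(10.0)])
records = []
handles = attach_stat_hooks(blocks, records)

x = torch.ones(1, 4, 8)
for block in blocks:
    x, _ = block(x, inference_params=None)

# Remove the hooks once the statistics have been collected.
for handle in handles:
    handle.remove()

for block_idx, lo, hi, mean in records:
    print(f"Block {block_idx} stats - min: {lo:.3f}, max: {hi:.3f}, mean: {mean:.3f}")
```

With the real model, the list passed to `attach_stat_hooks` would be the model's `blocks` attribute, and the hooks would fire in both `stateful_forward` and `stateless_forward`.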

After running example_inference.py I got the following results:

Block 0 stats - min: -1.430, max: 1.688, mean: 0.000
Block 1 stats - min: -51.000, max: 54.500, mean: -0.003
Block 2 stats - min: -60.250, max: 74.000, mean: -0.007
Block 3 stats - min: -69.500, max: 88.000, mean: -0.009
Block 4 stats - min: -73.500, max: 93.000, mean: -0.009
Block 5 stats - min: -78.000, max: 99.000, mean: -0.012
Block 6 stats - min: -91.500, max: 115.500, mean: -0.015
Block 7 stats - min: -140.000, max: 180.000, mean: -0.008
Block 8 stats - min: -75.000, max: 79.500, mean: 0.004
Block 9 stats - min: -72.000, max: 88.500, mean: -0.011
Block 10 stats - min: -752.000, max: 724.000, mean: -0.586
Block 11 stats - min: -464896.000, max: 462848.000, mean: -130.000
Block 12 stats - min: -14548992.000, max: 18743296.000, mean: -18304.000
Block 13 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 14 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 15 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 16 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 17 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 18 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 19 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 20 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 21 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 22 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 23 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 24 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 25 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 26 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 27 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 28 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 29 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 30 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 31 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000

Issues

  • The activations explode in magnitude as they propagate through the layers.
  • By block 10, values reach ±750, and by block 11, they explode to ±460,000.
  • From block 13 onward, the printed statistics stop changing: every block reports min -394,264,576, max 362,807,296, and mean 55,040.
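
To locate these points programmatically rather than by eye, the printed lines can be parsed with the standard library. This is a minimal sketch: the helper names (`parse_stats`, `first_jump`, `plateau_start`) and the 100x growth threshold are assumptions chosen for illustration, and the sample text is an excerpt of the output above (blocks 9 to 15).

```python
import re

# Matches the print format used in the added code.
LINE_RE = re.compile(
    r"Block (\d+) stats - min: (-?[\d.]+), max: (-?[\d.]+), mean: (-?[\d.]+)"
)

LOG_EXCERPT = """
Block 9 stats - min: -72.000, max: 88.500, mean: -0.011
Block 10 stats - min: -752.000, max: 724.000, mean: -0.586
Block 11 stats - min: -464896.000, max: 462848.000, mean: -130.000
Block 12 stats - min: -14548992.000, max: 18743296.000, mean: -18304.000
Block 13 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 14 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 15 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
"""


def parse_stats(log_text):
    """Return a list of (block_idx, min, max, mean) tuples from the printed lines."""
    rows = []
    for match in LINE_RE.finditer(log_text):
        idx, lo, hi, mean = match.groups()
        rows.append((int(idx), float(lo), float(hi), float(mean)))
    return rows


def first_jump(rows, factor=100.0):
    """First block whose peak magnitude grew by more than `factor` over the previous block."""
    for prev, cur in zip(rows, rows[1:]):
        prev_peak = max(abs(prev[1]), abs(prev[2]))
        cur_peak = max(abs(cur[1]), abs(cur[2]))
        if prev_peak > 0 and cur_peak / prev_peak > factor:
            return cur[0]
    return None


def plateau_start(rows):
    """First block from which every later row has identical statistics."""
    if not rows:
        return None
    start = rows[-1][0]
    # Walk backwards while consecutive rows report the same min/max/mean.
    for prev, cur in zip(reversed(rows[:-1]), reversed(rows[1:])):
        if prev[1:] != cur[1:]:
            break
        start = prev[0]
    return start


rows = parse_stats(LOG_EXCERPT)
print("first >100x jump at block:", first_jump(rows))
print("statistics constant from block:", plateau_start(rows))
```

On this excerpt the jump is reported at block 11 and the constant statistics start at block 13, matching the bullets above.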

Despite these values, I got a good inference result at the end:

Image

Is this what is supposed to happen, or am I missing something?

Thank you!
