
[Bug Report] Incorrect logit lens implementation: apply_ln=True uses cached normalization instead of recomputing statistics #1076

@hartigel

Describe the bug

The accumulated_resid() method in ActivationCache is explicitly documented for logit lens analysis in both its docstring and the TransformerLens documentation. However, the apply_ln=True parameter does not implement the logit lens as standardly defined in the literature.

Instead of recomputing layer normalization statistics for each intermediate layer's state, accumulated_resid() calls apply_ln_to_stack() which normalizes all layers using the cached scale from the final layer norm. This produces fundamentally different probability distributions and entropy/KL divergence measurements than the standard logit lens approach, making it unsuitable for distribution analysis despite being the documented method for logit lens experiments.

Current Behavior

The accumulated_resid() method in ActivationCache contains this code:

```python
if layer is None or layer == -1:
    # Default to the residual stream immediately pre unembed
    layer = self.model.cfg.n_layers

if apply_ln:
    components = self.apply_ln_to_stack(
        components, layer, pos_slice=pos_slice, mlp_input=mlp_input
    )
```

When apply_ln=True and layer is None or -1 (the default, representing all layers up to the unembed), this passes layer = self.model.cfg.n_layers to apply_ln_to_stack().

Inside apply_ln_to_stack(), this results in:

```python
if layer == self.model.cfg.n_layers or layer is None:
    scale = self["ln_final.hook_scale"]
else:
    hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale"
    scale = self[hook_name]

# ...
return residual_stack / scale
```

This means the final cached layer norm scale (ln_final.hook_scale) is used to normalize ALL intermediate layers' residual streams, not their own freshly computed statistics.

The implementation:

  1. Centers each layer's residual stream by subtracting that layer's mean (if using LayerNorm)
  2. Divides by the single cached scale (ln_final.hook_scale) that was computed during the original forward pass on the complete accumulated residual stream

All intermediate layers are normalized using the standard deviation from this single cached computation, not their own freshly computed statistics.
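In effect, the apply_ln=True path reduces to something like this (a toy re-expression with made-up shapes, not the library source):

```python
import torch

# Toy shapes standing in for [n_layers + 1, batch, pos, d_model]
residual_stack = torch.randn(4, 1, 5, 8)
cached_scale = torch.rand(1, 5, 1) + 10.0  # stands in for ln_final.hook_scale

centered = residual_stack - residual_stack.mean(-1, keepdim=True)  # per-layer mean
normed = centered / cached_scale  # one shared scale for every layer in the stack
```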

Expected Behavior (Standard Logit Lens Definition)

Every major logit lens paper recomputes normalization statistics for each layer's accumulated residual stream.

The common principle is: "What would the model predict if we stopped processing at layer L and immediately applied the unembedding procedure?" This requires applying the model's actual layer norm operation (recomputing both mean and scale) to that layer's current state.
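As a sketch of what this looks like in TransformerLens (the model and prompt mirror the GPT-2 XL example below; apply_ln=False plus the model's own ln_final stands in for the standard logit lens):

```python
# Minimal sketch of the standard logit lens, assuming TransformerLens is
# installed. `apply_ln=False` returns the raw accumulated residual streams;
# `model.ln_final` then recomputes mean and scale from each layer's own state.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-xl")
_, cache = model.run_with_cache("The capital of France is")

# [n_layers + 1, batch, d_model]: the residual stream at the final position
# before each layer (plus the pre-unembed state), unnormalized
accum = cache.accumulated_resid(apply_ln=False, pos_slice=-1)

# ln_final computes fresh statistics on whatever it is given
logits_fresh = model.ln_final(accum) @ model.W_U + model.b_U
probs_fresh = logits_fresh.softmax(dim=-1)  # one distribution per layer
```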

Why This Matters: The Temperature Parameter Effect

The normalization scale acts exactly like a temperature parameter in softmax. Since the residual stream norm grows exponentially through the network, the cached final-layer scale is much larger than what would be computed from earlier layers' states. When you divide early-layer activations by this large scale factor, you get extremely small normalized values. Just as dividing logits by a high temperature flattens the distribution toward uniformity, dividing the residual by a large scale produces small values that, after unembedding and softmax, create an artificially uniform distribution.
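A toy illustration of this effect (plain PyTorch, made-up numbers):

```python
import torch

logits = torch.tensor([8.0, 2.0, 1.0, 0.0])
print(logits.softmax(dim=-1))         # peaked: ~[0.996, 0.002, 0.001, 0.000]
print((logits / 50).softmax(dim=-1))  # near-uniform: ~[0.28, 0.25, 0.24, 0.24]
```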

This means:

  • One could take the most peaked distribution (near a delta function) and transform it into a uniform distribution simply by changing the normalization scale
  • The distribution statistics (entropy, KL divergence, probabilities) are completely determined by which scale is used
  • The measurements reflect the exponentially growing layer norm scale rather than the model's actual intermediate "beliefs"

Code example

Concrete Example: GPT-2 XL

Prompt: "The capital of France is"

Correct Implementation (recomputed normalization):

  • Layer 0: " is" at 99.997% (expected for uncontextualized input)
  • Layer 36: " Paris" at 59.634%
  • Layer 47: " Paris" at 14.769%

Current apply_ln=True (cached scale):

  • Layer 0: " is" at 0.002% (near-uniform over all 50,257 tokens)
  • Layer 36: " Paris" at 1.731%
  • Layer 47: " Paris" at 18.961%

Both predict the same top tokens (normalization preserves direction), but with completely different probability distributions.
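A sketch of how these numbers can be reproduced, reusing model, cache, accum, and logits_fresh from the snippet above (exact percentages will depend on weights and library version):

```python
# Current apply_ln=True behavior, re-expressed explicitly: normalize the whole
# stack with the cached final scale via apply_ln_to_stack, then unembed.
logits_cached = (
    cache.apply_ln_to_stack(accum, layer=None, pos_slice=-1) @ model.W_U + model.b_U
)

paris = model.to_single_token(" Paris")
for layer in (0, 36, 47):
    fresh = logits_fresh[layer, 0].softmax(dim=-1)[paris].item()
    cached = logits_cached[layer, 0].softmax(dim=-1)[paris].item()
    print(f"layer {layer}: fresh {fresh:.3%}, cached {cached:.3%}")
```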

Entropy and KL Divergence Patterns for GPT-2

The cached-scale approach produces:

  • Entropy: ~15.6 bits at layer 0 (near maximum for uniform distribution), remaining high across most layers
  • KL divergence: Smooth monotonic decrease from ~8 bits

The correct implementation produces:

  • Entropy: ~0.5 bits at layer 0 (input embedding), sharp spike to ~3.5 bits (context integration), then gradual sharpening
  • KL divergence: ~20 bits at layer 0 (uncontextualized input is very different from final output), sharp drop, then monotonic decrease

The input-layer pattern (very low entropy for the correct implementation vs. near-maximum entropy for the cached scale) holds across model families. This is expected: at layer 0 the residual stream is just the embedded, uncontextualized input token at the given position, so its distribution should concentrate most probability mass on that token (very low entropy) and differ substantially from the final output (much higher KL than the later, somewhat contextualized layers).
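The curves can be computed along the same lines, reusing logits_fresh and logits_cached from the snippets above (KL is taken against the final layer's distribution; entropies are in bits, so the uniform maximum over GPT-2's 50,257 tokens is ~15.6 bits):

```python
import math

for name, logits in [("fresh", logits_fresh), ("cached", logits_cached)]:
    log_probs = logits.log_softmax(dim=-1)
    probs = log_probs.exp()
    entropy_bits = -(probs * log_probs).sum(-1) / math.log(2)  # [layer, batch]
    # KL(final || layer), per layer, in bits
    kl_bits = (probs[-1] * (log_probs[-1] - log_probs)).sum(-1) / math.log(2)
    print(name, entropy_bits[:, 0], kl_bits[:, 0])
```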

Additional Context

The accumulated_resid docstring itself suggests fresh computation:

"This is useful for Logit Lens style analysis, where it can be thought of as what the model 'believes' at each point in the residual stream. To project this into the vocabulary space, remember that there is a final layer norm in most decoder-only transformers. Therefore, you need to first apply the final layer norm (which can be done with apply_ln)..."

The language of "applying" the layer norm and measuring model "beliefs" suggests fresh computation at each layer to simulate what the model would predict at that point, not viewing all layers through a single cached normalization scale.

Proposed Solution

I would recommend changing the default behavior of apply_ln=True to recompute normalization statistics for each layer's residual stream, i.e., to actually apply the model's own final layer normalization. While this is a breaking change, it aligns with the standard definition of the logit lens in the literature and matches what the documentation describes. The current behavior produces measurement artifacts rather than accurate intermediate model beliefs.

Alternative options:

  1. Add a parameter (e.g., recompute_ln=True) to accumulated_resid() and apply_ln_to_stack() that recomputes the layer norm mean and scale freshly for each layer's residual stream instead of using cached values, preserving backward compatibility (a sketch follows this list)
  2. Clarify in the documentation that apply_ln=True should not be used for distribution analysis (entropy, KL divergence, probability measurements) as it produces artifacts of the cached normalization rather than accurate intermediate predictions
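A minimal sketch of option 1, with the caveat that the recompute_ln name and its placement are hypothetical, not an existing API:

```python
def apply_ln_to_stack(self, residual_stack, layer=None, mlp_input=False,
                      pos_slice=None, recompute_ln=False):
    if recompute_ln:
        # Re-derive LN statistics from each layer's own residual stream,
        # LayerNormPre-style: center, then divide by the RMS of the centered
        # values (mirroring what ln_final does during the forward pass)
        centered = residual_stack - residual_stack.mean(-1, keepdim=True)
        scale = (centered.pow(2).mean(-1, keepdim=True) + self.model.cfg.eps).sqrt()
        return centered / scale
    ...  # existing cached-scale path unchanged
```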

Related Issues

This bug may help clarify confusion in #523, where a user reports unexpected behavior with accumulated_resid and apply_ln=True when fold_ln=False. While that issue stems from needing to manually apply layer norm weights/biases when folding is disabled, the proposed solution here (having apply_ln=True actually call the model's layer norm operations instead of just using cached scales) would make the behavior more intuitive and work correctly regardless of the fold_ln setting.

Checklist

  • I have checked that there is no similar issue in the repo
