[Bug Report] Incorrect logit lens implementation: apply_ln=True uses cached normalization instead of recomputing statistics
Describe the bug
The accumulated_resid() method in ActivationCache is explicitly documented for logit lens analysis in both its docstring and the TransformerLens documentation. However, the apply_ln=True parameter does not implement the logit lens as standardly defined in the literature.
Instead of recomputing layer normalization statistics for each intermediate layer's state, accumulated_resid() calls apply_ln_to_stack(), which normalizes all layers using the cached scale from the final layer norm. This produces fundamentally different probability distributions and entropy/KL divergence measurements from the standard logit lens approach, making it unsuitable for distribution analysis despite being the documented method for logit lens experiments.
Current Behavior
The accumulated_resid() method in ActivationCache contains this code:
```python
if layer is None or layer == -1:
    # Default to the residual stream immediately pre unembed
    layer = self.model.cfg.n_layers
if apply_ln:
    components = self.apply_ln_to_stack(
        components, layer, pos_slice=pos_slice, mlp_input=mlp_input
    )
```
When apply_ln=True and layer is None or -1 (the default, representing all layers up to the unembed), this passes layer = self.model.cfg.n_layers to apply_ln_to_stack().
Inside apply_ln_to_stack(), this results in:
```python
if layer == self.model.cfg.n_layers or layer is None:
    scale = self["ln_final.hook_scale"]
else:
    hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale"
    scale = self[hook_name]
# ...
return residual_stack / scale
```
This means the final cached layer norm scale (`ln_final.hook_scale`) is used to normalize ALL intermediate layers' residual streams, not their own freshly computed statistics.
The implementation:
- Centers each layer's residual stream by subtracting that layer's mean (if using LayerNorm)
- Divides by the single cached scale (`ln_final.hook_scale`) that was computed during the original forward pass on the complete accumulated residual stream

All intermediate layers are therefore normalized using the standard deviation from this single cached computation, not their own freshly computed statistics.
Expected Behavior (Standard Logit Lens Definition)
Every major logit lens paper recomputes normalization statistics for each layer's accumulated residual stream:
- [Nostalgebraist's logit lens](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) (2020): The [follow-up implementation](https://colab.research.google.com/drive/1MjdfK2srcerLrAJDRaJQKO0sUiZ-hQtA) explicitly applies ln_f to each intermediate layer separately to see what the model would predict at that point.
- [Tuned Lens](https://arxiv.org/abs/2303.08112) (Belrose et al., 2023): Recomputes layer normalization for each layer's residual stream before projection. Their standard logit lens implementation applies the final layer norm freshly to each intermediate state.
- [Jump to Conclusions](https://arxiv.org/abs/2411.04229) (Gur et al., 2024): Defines the approach as applying "the final layer norm to the intermediate activations", with fresh computation for each layer.
- [Patchscopes](https://arxiv.org/abs/2401.06102) (Ghandeharioun et al., 2024): Frames the logit lens as patching intermediate representations into the final layer position of a forward pass, which inherently recomputes normalization from the current intermediate state.
- [A Primer on the Inner Workings of Transformer-based Language Models](https://arxiv.org/abs/2405.00208) (Ferrando et al., 2024): Explicitly differentiates between the logit lens (recomputing normalization for intermediate predictions) and direct logit attribution (using cached normalization statistics).
The common principle is: "What would the model predict if we stopped processing at layer L and immediately applied the unembedding procedure?" This requires applying the model's actual layer norm operation (recomputing both mean and scale) to that layer's current state.
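For concreteness, here is a minimal sketch of this recomputed-normalization logit lens in TransformerLens. It assumes the model is loaded with the default fold_ln=True; the variable names are illustrative and not part of the library API.

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # fold_ln=True by default
_, cache = model.run_with_cache("The capital of France is")

# Accumulated residual stream at the final position, WITHOUT the cached normalization
accumulated, labels = cache.accumulated_resid(
    layer=-1, pos_slice=-1, apply_ln=False, return_labels=True
)  # shape: [n_layers + 1, batch, d_model]

# Standard logit lens: run every intermediate state through the model's own final
# layer norm, so mean and scale are recomputed from that layer's residual stream.
normalized = model.ln_final(accumulated)
logits = normalized @ model.W_U + model.b_U  # [n_layers + 1, batch, d_vocab]
probs = logits.softmax(dim=-1)

for label, p in zip(labels, probs[:, 0]):
    top_prob, top_tok = p.max(dim=-1)
    print(f"{label}: {model.to_string(top_tok.unsqueeze(0))!r} at {top_prob.item():.3%}")
```

The key step is the model.ln_final(accumulated) call: each layer's state gets its own mean and scale, which is exactly what "stopping at layer L and unembedding" implies.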
Why This Matters: The Temperature Parameter Effect
The normalization scale acts exactly like a temperature parameter in softmax. Since the residual stream norm grows exponentially through the network, the cached final-layer scale is much larger than what would be computed from earlier layers' states. When you divide early-layer activations by this large scale factor, you get extremely small normalized values. Just as dividing logits by a high temperature flattens the distribution toward uniformity, dividing the residual by a large scale produces small values that, after unembedding and softmax, create an artificially uniform distribution.
This means:
- One could take the most peaked distribution (near a delta function) and transform it into a uniform distribution simply by changing the normalization scale
- The distribution statistics (entropy, KL divergence, probabilities) are completely determined by which scale is used
- The measurements reflect the exponentially growing layer norm scale rather than the model's actual intermediate "beliefs"
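A toy, model-free illustration of the temperature effect (the numbers are synthetic, not taken from any model): dividing the same vector by a larger scale before softmax flattens the distribution exactly as described above.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(50_257) * 5.0  # a fairly peaked synthetic "logit" vector

def entropy_bits(p: torch.Tensor) -> torch.Tensor:
    return -(p * p.clamp_min(1e-12).log2()).sum()

for scale in [1.0, 4.0, 16.0, 64.0]:
    probs = (logits / scale).softmax(dim=-1)
    print(f"scale={scale:5.1f}  top prob={probs.max().item():.4f}  "
          f"entropy={entropy_bits(probs).item():.2f} bits")

# As the divisor grows, the top probability collapses and the entropy climbs toward
# log2(50257) ≈ 15.6 bits -- the same near-uniform pattern described above.
```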
Code example
Concrete Example: GPT-2 XL
Prompt: "The capital of France is"
Correct Implementation (recomputed normalization):
- Layer 0: " is" at 99.997% (expected for uncontextualized input)
- Layer 36: " Paris" at 59.634%
- Layer 47: " Paris" at 14.769%
Current apply_ln=True (cached scale):
- Layer 0: " is" at 0.002% (near-uniform over all 50,257 tokens)
- Layer 36: " Paris" at 1.731%
- Layer 47: " Paris" at 18.961%
Both predict the same top tokens (normalization preserves direction), but with completely different probability distributions.
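A sketch of how the two variants can be compared side by side (assuming fold_ln=True and enough memory for GPT-2 XL; exact percentages may differ slightly from those above depending on library version and precision):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-xl")
_, cache = model.run_with_cache("The capital of France is")
paris = model.to_single_token(" Paris")

# Residual stream at the final position, before any normalization
stack, labels = cache.accumulated_resid(
    layer=-1, pos_slice=-1, apply_ln=False, return_labels=True
)

# (a) current apply_ln=True behaviour: divide by the single cached ln_final scale
cached = cache.apply_ln_to_stack(stack, layer=-1, pos_slice=-1)
# (b) standard logit lens: recompute the final layer norm per intermediate state
fresh = model.ln_final(stack)

for name, normed in [("cached scale", cached), ("recomputed LN", fresh)]:
    probs = (normed @ model.W_U + model.b_U).softmax(dim=-1)
    for label, p in zip(labels, probs[:, 0]):
        print(f"{name:13s} {label:10s} P(' Paris') = {p[paris].item():.3%}")
```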
Entropy and KL Divergence Patterns for GPT-2
The cached-scale approach produces:
- Entropy: ~15.6 bits at layer 0 (near maximum for uniform distribution), remaining high across most layers
- KL divergence: Smooth monotonic decrease from ~8 bits
The correct implementation produces:
- Entropy: ~0.5 bits at layer 0 (input embedding), sharp spike to ~3.5 bits (context integration), then gradual sharpening
- KL divergence: ~20 bits at layer 0 (uncontextualized input is very different from final output), sharp drop, then monotonic decrease
The input layer pattern (very low entropy for the correct implementation vs. near-maximum entropy for the cached scale) holds across model families. This is expected: the input layer is simply the embedding of the uncontextualized input token at the given position, so it should have very low entropy, with most of the probability mass on that token, and a substantially higher KL divergence from the final output than the later, contextualized layers.
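For reference, a sketch of how the per-layer entropy and KL numbers can be computed from a stack of per-layer probability distributions (the helper name, the probs_cached/probs_fresh tensors, and the KL direction, layer vs. final output, are choices made here for illustration):

```python
import math
import torch

def layerwise_stats(probs: torch.Tensor):
    """probs: [n_layers + 1, batch, d_vocab] per-layer next-token distributions.
    Returns (entropy, kl) in bits, where kl is KL(layer distribution || final distribution)."""
    logp = probs.clamp_min(1e-12).log()
    entropy = -(probs * logp).sum(-1) / math.log(2)          # [n_layers + 1, batch]
    kl = (probs * (logp - logp[-1])).sum(-1) / math.log(2)   # divergence from the final layer
    return entropy, kl

# Hypothetical usage with probability stacks built as in the previous sketch:
# ent_cached, kl_cached = layerwise_stats(probs_cached)
# ent_fresh,  kl_fresh  = layerwise_stats(probs_fresh)
```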
Additional Context
The accumulated_resid docstring itself suggests fresh computation:
"This is useful for Logit Lens style analysis, where it can be thought of as what the model 'believes' at each point in the residual stream. To project this into the vocabulary space, remember that there is a final layer norm in most decoder-only transformers. Therefore, you need to first apply the final layer norm (which can be done with
apply_ln)..."
The language of "applying" the layer norm and measuring model "beliefs" suggests fresh computation at each layer to simulate what the model would predict at that point, not viewing all layers through a single cached normalization scale.
Proposed Solution
I would recommend changing the default behavior of apply_ln=True to recompute normalization statistics for each layer's residual stream, i.e., actually applying the model's own final layer normalization. While this is a breaking change, it aligns with the standard definition of the logit lens in the literature and matches what the documentation describes. The current behavior produces measurement artifacts rather than accurate intermediate model beliefs.
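For illustration only, a minimal sketch of what the recomputed path could look like; the helper below is hypothetical, not existing TransformerLens API, and assumes fold_ln=True so the learned LN weight and bias are already folded into the unembedding:

```python
import torch

def apply_ln_recomputed(residual_stack: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: normalize each layer's residual stream with its OWN
    statistics, mirroring what ln_final would compute if run freshly on that state."""
    centered = residual_stack - residual_stack.mean(dim=-1, keepdim=True)
    scale = (centered.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
    return centered / scale
```

Either the default apply_ln=True path or an opt-in flag could dispatch to logic like this instead of dividing by the cached hook_scale.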
Alternative options:
- Add a parameter (e.g., `recompute_ln=True`) to `accumulated_resid()` and `apply_ln_to_stack()` that recomputes the layer norm mean and scale freshly for each layer's residual stream instead of using cached values (allows backward compatibility)
- Clarify in the documentation that `apply_ln=True` should not be used for distribution analysis (entropy, KL divergence, probability measurements), as it produces artifacts of the cached normalization rather than accurate intermediate predictions
Related Issues
This bug may help clarify confusion in #523, where a user reports unexpected behavior with accumulated_resid and apply_ln=True when fold_ln=False. While that issue stems from needing to manually apply layer norm weights/biases when folding is disabled, the proposed solution here (having apply_ln=True actually call the model's layer norm operations instead of just using cached scales) would make the behavior more intuitive and work correctly regardless of the fold_ln setting.
Checklist
- I have checked that there is no similar issue in the repo