### Proposal
I propose adding a new tutorial notebook to `demos/` titled "Training Dynamics with VSM Telemetry." The tutorial introduces a lightweight bridge class, `VSMTelemetry`, that lets researchers monitor mechanistic properties, specifically Attention Coherence ($\sigma_p$) and Head Specialization ($\sigma_a$), in real time during a training loop using standard TransformerLens hooks.
### Motivation
Currently, TransformerLens is the gold standard for analyzing frozen checkpoints. However, there are few standardized tools or examples for monitoring mechanistic properties during the training loop itself.
Researchers studying phenomena like "grokking," "phase transitions," or circuit formation often have to write custom, complex logging hooks from scratch. This proposal fills that gap by providing a copy-pasteable, physics-inspired telemetry solution that works out-of-the-box with `ActivationCache`.
### Pitch
I would like to contribute a self-contained notebook that:
- **Defines `VSMTelemetry`:** a lightweight (~30-line) class that computes Vector-Space-Mapping (VSM) metrics from standard `ActivationCache` objects (a hedged sketch follows below).
- **Demonstrates real-time analysis:** trains a toy model (2 layers, induction-head task) and logs the entropy ($\sigma_p$) and variance ($\sigma_a$) of the attention heads at every training step.
- **Visualizes phase transitions:** includes pre-built code to generate "entropy heatmaps" that reveal exactly when and where a model "groks" a task.
This aligns with the library's goal of exploratory analysis by extending it into the temporal domain (training dynamics).
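To make the shape of the contribution concrete, here is a minimal sketch of what `VSMTelemetry` could look like. The metric formulas are illustrative assumptions rather than a fixed spec: $\sigma_p$ is taken as the mean per-head Shannon entropy of the attention pattern over key positions, and $\sigma_a$ as the cross-head variance of the patterns.

```python
import pandas as pd
import torch
from transformer_lens import ActivationCache


class VSMTelemetry:
    """Collects VSM-style attention metrics from TransformerLens caches.

    Sketch only: the exact metric definitions are assumptions made for
    illustration and may differ from the prototype.
    """

    def __init__(self, n_layers: int, n_heads: int):
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.records = []  # one flat dict per (step, layer, head)

    @staticmethod
    def _entropy(pattern: torch.Tensor) -> torch.Tensor:
        # pattern: [batch, head, query_pos, key_pos]. Shannon entropy over
        # key positions, averaged over batch and query -> one value per head.
        eps = 1e-10
        ent = -(pattern * (pattern + eps).log()).sum(dim=-1)
        return ent.mean(dim=(0, 2))

    def log_step(self, step: int, cache: ActivationCache) -> None:
        for layer in range(self.n_layers):
            # "pattern" is the standard post-softmax attention hook.
            pattern = cache["pattern", layer].detach()
            sigma_p = self._entropy(pattern)            # [head]
            sigma_a = pattern.var(dim=1).mean().item()  # cross-head variance
            for head in range(self.n_heads):
                self.records.append({
                    "step": step, "layer": layer, "head": head,
                    "sigma_p": sigma_p[head].item(), "sigma_a": sigma_a,
                })

    def to_frame(self) -> pd.DataFrame:
        return pd.DataFrame(self.records)
```

Storing flat records rather than tensors keeps the class tiny and makes the pandas pivot for the heatmap (further below) a one-liner.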
### Alternatives
- **Manual hooking:** users can write their own hooks to extract entropy and variance, but this is repetitive and prone to shape/dimension errors.
- **External loggers (WandB/Neptune):** while powerful, these require external accounts and extra dependencies. This proposal is fully self-contained, requiring only `torch`, `pandas`, and `matplotlib`, which keeps the barrier to entry low.
### Additional context
I have prototyped this on a 2-layer model trained on a toy induction-head task. The telemetry successfully captured the phase transition at step 60, detecting the exact moment the model learned the pattern.
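For reference, a minimal sketch of the kind of loop the prototype runs. The config values and the `make_induction_batch` helper are illustrative placeholders rather than the prototype's actual code:

```python
import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=2, d_model=64, n_heads=4, d_head=16,
    n_ctx=32, d_vocab=12, act_fn="relu",
)
model = HookedTransformer(cfg)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
telemetry = VSMTelemetry(cfg.n_layers, cfg.n_heads)


def make_induction_batch(batch_size: int = 32, half_len: int = 16) -> torch.Tensor:
    # Hypothetical toy task: a random block repeated twice, so predicting
    # the second half rewards forming an induction circuit.
    half = torch.randint(0, cfg.d_vocab, (batch_size, half_len))
    return torch.cat([half, half], dim=1)


for step in range(200):
    tokens = make_induction_batch()
    logits, cache = model.run_with_cache(tokens)  # standard TransformerLens API
    loss = torch.nn.functional.cross_entropy(
        logits[:, :-1].reshape(-1, cfg.d_vocab), tokens[:, 1:].reshape(-1)
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    telemetry.log_step(step, cache)  # metrics logged at every step
```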
**Prototype results:**
The image below (generated by the proposed notebook) shows the coherence metric ($\sigma_p$) spiking at the exact moment the loss collapses. The heatmap further suggests a causal ordering, with layer 0 (bottom rows) stabilizing before layer 1 (top rows).
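For the heatmap itself, here is a sketch of the pre-built plotting code, assuming the flat `step`/`layer`/`head`/`sigma_p` records produced by the class sketch above:

```python
import matplotlib.pyplot as plt

df = telemetry.to_frame()
# Rows are (layer, head) pairs, columns are training steps, values are sigma_p.
grid = df.pivot_table(index=["layer", "head"], columns="step", values="sigma_p")

plt.figure(figsize=(10, 4))
# origin="lower" puts layer 0 at the bottom, matching the figure description.
plt.imshow(grid.values, aspect="auto", origin="lower", cmap="viridis")
plt.colorbar(label=r"$\sigma_p$ (attention entropy)")
plt.xlabel("training step")
plt.ylabel("(layer, head) row")
plt.title("Entropy heatmap across training")
plt.tight_layout()
plt.show()
```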
### Implementation details
- **No new dependencies:** uses only `transformer_lens` and the standard scientific stack (`torch`, `pandas`, `matplotlib`).
- **Style:** code follows the Google Python Style Guide for docstrings.
- **Speed:** the demo runs in seconds on Colab (CPU or GPU).
### Checklist
- [x] I have checked that there is no similar issue in the repo (required)
