Mechanistic interpretability lab that dissects transformer attention heads and traces reasoning circuits in real time.
Circuit Trace loads GPT-2 small, runs a forward pass on your prompt, and streams every attention pattern, head metric, and residual delta from the model into a browser-based lab. You watch attention heads light up layer by layer, hover any head to see its top source/destination tokens, follow a circuit back from the predicted token through the heads that wrote it, and run activation patches to measure causal effect. No notebooks, no `print(tensor.shape)`: just the model's internals, laid out interactively.
Above: the IOI prompt "When Mary and John went to the store, John gave a drink to", with top prediction "Mary" at 44.6% and the circuit graph (right of the model architecture) tracing the deep-layer name-mover heads that wrote the answer.
Most production engineers treat transformers as opaque black boxes. The mech-interp community has built powerful tools (TransformerLens, captum, neuroscope), but they are all Jupyter-first: you load activations into a notebook, plot a matplotlib heatmap, and squint. There's no way to watch the model think.
Circuit Trace fills that gap. It takes the same primitives (attention pattern extraction, head classification, direct logit attribution, activation patching) and exposes them as a real interactive surface. You type a prompt, the forward pass streams over a WebSocket, and all 144 attention heads animate into a 3D grid where colour and brightness encode contribution and value norm. Click a head to see its attention pattern and classification (induction, duplicate-token, previous-token, diffuse). Run an activation patch and watch which heads causally restore the lost prediction.
It's the kind of tool an interpretability team would build for itself, packaged as a standalone open-source app.
| Capability | What it does |
|---|---|
| Streaming forward pass | WebSocket emits per-layer attention + metrics so the UI animates the computation. |
| 3D model architecture | Three.js grid of 12 layers × 12 heads; brightness encodes value norm; circuit edges are drawn between top contributors. |
| Attention heatmap | D3 [seq × seq] matrix for any selected head, with token labels and tooltips. |
| Head classification | Each head labeled induction / duplicate / previous / diffuse from its attention signature. |
| Backward circuit trace | Direct logit attribution: per-head contribution to the predicted token's logit, computed as (z @ W_O) · U[t]. |
| Logit lens | Per-layer "current best guess": projects the residual stream through the final layer norm and the unembedding at every depth. |
| Token-flow arc diagram | Aggregate token-to-token attention summed across all layers and heads. |
| Activation patching | Swap a clean head's z into a corrupted run; sweep all 144 heads to measure causal contribution. |
| Preset circuits | Five well-studied prompts (IOI, induction copying, factual recall, year continuation, duplicate-token). |
| Pure local inference | Runs on Apple Silicon (MPS), CUDA, or CPU. No API keys, no telemetry. |
Requirements: macOS / Linux, Python 3.11 (managed via uv), Node 18+, ~600 MB free for the GPT-2 weights cache.
```bash
git clone https://github.com/rayancheca/circuit-trace
cd circuit-trace

# one-time install
cd backend && uv sync && cd ..
cd frontend && npm install && cd ..

# run both servers (backend on :8000, frontend on :5173)
./scripts/dev.sh
```

Open http://localhost:5173/. The first request loads GPT-2 small (~500 MB download from HuggingFace); subsequent requests use the cache.
- Pick a preset (e.g. Indirect Object Identification) or type your own prompt.
- Click analyze. The status badge transitions `connecting → streaming → ready` while layers animate in.
- Watch the 3D model architecture light up: emerald heads write toward the predicted token, magenta heads write away from it. The brightness of each sphere reflects its value norm.
- Click any head in the 3D view. The right column shows its attention heatmap and the inspector panel (entropy, concentration, value norm, classification badges).
- The lower row shows the circuit trace (a 2D `layers × heads` heatmap of direct logit attribution, with the path edges in magenta), the token-flow arc diagram (aggregate attention between every token pair), and the logit lens (per-layer best guesses).
- Run an activation patch sweep: type a clean prompt and a corrupted prompt, hit run patch sweep, and Circuit Trace replaces each head's `z` vector with its clean value (one head at a time, across all 144 heads) and reports which heads causally recover the lost prediction.
```
+--------------------------------------------------------------------+
|                       BROWSER (React + Vite)                       |
|                                                                    |
|  +----------+  +--------------+  +-------------+  +--------------+ |
|  |  Prompt  |  | 3D Model Map |  |   Heatmap   |  |   Circuit    | |
|  |  Panel   |  |  (Three.js)  |  |    (D3)     |  |  Trace (D3)  | |
|  +----+-----+  +------^-------+  +------^------+  +------^-------+ |
|       |               |                 |                |         |
|       +---------------+--------+--------+----------------+         |
|                                |  Zustand store + useAnalysis hook |
|                                |  WebSocket                        |
+--------------------------------+-----------------------------------+
                                 |
+--------------------------------v-----------------------------------+
|                          FASTAPI BACKEND                           |
|  /api/health  /api/presets  /api/analyze  /api/patch  /ws/analyze  |
|                                |                                   |
|  +------------------+   +------v------------------+                |
|  | runner.analyze() |-->| ActivationCache (hooks) |                |
|  +--------+---------+   +------------+------------+                |
|           |                          |                             |
|  +--------v--------------------------v--------------+              |
|  |  Head classifier · Circuit tracer · Patcher      |              |
|  |  Logit lens · Attention metrics                  |              |
|  +------------------------+-------------------------+              |
|                           |                                        |
|  +------------------------v-------------------------+              |
|  |  PyTorch GPT-2 small (HuggingFace transformers)  |              |
|  |  Forward hooks on every attn + mlp module        |              |
|  +--------------------------------------------------+              |
+--------------------------------------------------------------------+
```
- Browser opens a WebSocket to `/ws/analyze` and sends `{prompt, top_k, top_path_per_layer}`.
- Backend tokenizes the prompt (`gpt2` BPE), accepts the request, and dispatches the forward pass on a worker thread.
- `ActivationCache` hooks capture, per layer:
  - `z`: per-head attention output before the W_O projection, `[batch, seq, heads, head_dim]`.
  - `attn_out`: post-W_O attention output, `[batch, seq, d_model]`.
  - `mlp_out`, `residual_pre`, `residual_post`: used for direct logit attribution and the logit lens.
  - `attention`: the actual softmax-normalized weights, via HF's `output_attentions=True` (eager attention is forced so SDPA/Flash don't drop them).
- After each layer, a `LayerUpdate` message goes over the WebSocket carrying the attention pattern plus per-head metrics (entropy, concentration, value norm), and the frontend animates a layer-by-layer reveal.
- Once all 12 layers complete, the analyzer runs head classification, the circuit tracer, and the logit lens, then sends a `done` message with the full `AnalysisResult`.
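To make the per-head metrics in each `LayerUpdate` concrete, here is a minimal NumPy sketch for one layer. Only entropy has an unambiguous standard definition; treating concentration as the mean of each query's largest attention weight, and value norm as the mean L2 norm of the head's `z` vector, are assumptions for illustration. `head_metrics` is a hypothetical name, not the function in `analysis/attention.py`.

```python
import numpy as np


def head_metrics(attn: np.ndarray, z: np.ndarray, eps: float = 1e-12) -> dict:
    """Summary metrics for every head in one layer (illustrative definitions).

    attn: [heads, seq, seq] softmax-normalized attention, rows = query positions.
    z:    [seq, heads, head_dim] per-head outputs before the W_O projection.
    """
    # Shannon entropy of each query row, averaged over queries:
    # ~0 for a head that attends to a single token, log(seq) for a uniform head.
    row_entropy = -(attn * np.log(attn + eps)).sum(axis=-1)  # [heads, seq]
    entropy = row_entropy.mean(axis=-1)                       # [heads]

    # Assumed definition: mass given to the strongest source, averaged over queries.
    concentration = attn.max(axis=-1).mean(axis=-1)           # [heads]

    # Assumed definition: mean L2 norm of the head's z vector across positions.
    value_norm = np.linalg.norm(z, axis=-1).mean(axis=0)      # [heads]

    return {"entropy": entropy, "concentration": concentration, "value_norm": value_norm}
```

A sharp head (an identity pattern) scores entropy ≈ 0 and concentration 1; a uniform head over `seq` tokens scores entropy log(seq) and concentration 1/seq.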
Most mech-interp code reaches for TransformerLens, which provides a cache API for activations. Circuit Trace deliberately uses raw PyTorch hooks on HuggingFace transformers instead (see `backend/circuit_trace/model/hooks.py`). The cache:
- Registers a `forward_pre_hook` on each `attn.c_proj` Conv1D. Its input is the concatenated per-head `z` vectors, `[batch, seq, d_model]`, which we reshape to `[batch, seq, heads, head_dim]` to recover per-head outputs before the W_O projection.
- Registers a `forward_hook` on each `attn`, `mlp`, and the block itself to capture the post-attention output, the MLP output, and the residual stream pre/post.
- Pulls native `output_attentions` and `output_hidden_states` from the model output (with `attn_implementation="eager"` forced at load time so SDPA/Flash don't silently drop the weights).
- Supports interventions: `set_z_override(layer, head, replacement)` swaps a head's `z` during the next forward pass; this is the basis of activation patching.
`analysis/heads.py` computes three scalar scores per head from the attention pattern:

- `prev_token_score`: mean of `attn[i, i-1]` across positions `i ≥ 1`. A previous-token head scores ≈1.0.
- `duplicate_token_score`: for each `i` with a prior occurrence of `token[i]`, the max attention to any earlier same-token position. This reveals duplicate-token heads (the "what was it last time" lookback).
- `induction_score`: for each `i`, the max attention to the position after a previous occurrence of `token[i]`, i.e. "I saw 'A B' earlier, and now I'm at 'A' again, so attend to that earlier 'B'". This is the canonical induction-head signature from Olsson et al. 2022.
The highest of the three scores labels the head, provided it clears a configurable threshold; otherwise the head is labeled diffuse. Synthetic-attention unit tests in `backend/tests/test_heads.py` cover all four labels.
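A minimal NumPy sketch of the three scores and the labeling rule, written from the descriptions above rather than from `heads.py` itself. Two details are assumptions: per-position maxima are averaged into one score per head, and the threshold of 0.4 is a hypothetical default.

```python
import numpy as np


def prev_token_score(attn: np.ndarray) -> float:
    """Mean attention from each position i >= 1 to position i - 1."""
    seq = attn.shape[0]
    return float(np.mean([attn[i, i - 1] for i in range(1, seq)])) if seq > 1 else 0.0


def duplicate_token_score(attn: np.ndarray, tokens: list) -> float:
    """For positions whose token occurred before, the max attention to an earlier copy."""
    per_pos = []
    for i, tok in enumerate(tokens):
        earlier = [j for j in range(i) if tokens[j] == tok]
        if earlier:
            per_pos.append(max(attn[i, j] for j in earlier))
    return float(np.mean(per_pos)) if per_pos else 0.0


def induction_score(attn: np.ndarray, tokens: list) -> float:
    """Max attention to the token that followed an earlier copy of the current token."""
    per_pos = []
    for i, tok in enumerate(tokens):
        # j + 1 < i keeps the target strictly in the past (excludes attending to self).
        targets = [j + 1 for j in range(i - 1) if tokens[j] == tok]
        if targets:
            per_pos.append(max(attn[i, t] for t in targets))
    return float(np.mean(per_pos)) if per_pos else 0.0


def classify_head(attn: np.ndarray, tokens: list, threshold: float = 0.4) -> str:
    """attn: [seq, seq] for one head, rows = queries; tokens: token ids of the prompt."""
    scores = {
        "previous": prev_token_score(attn),
        "duplicate": duplicate_token_score(attn, tokens),
        "induction": induction_score(attn, tokens),
    }
    label, best = max(scores.items(), key=lambda kv: kv[1])
    return label if best >= threshold else "diffuse"
```

On a repeated sequence such as `A B C A B C`, a head that attends from the second `A` to the first `B` scores high on induction; one that attends to the first `A` scores high on duplicate.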
For a predicted token t at query position P, each attention head's direct contribution to its logit is
`contribution(layer L, head H) = (z[L][P, H] @ W_O[L][H*dh : (H+1)*dh, :]) · U[t]`

where `dh` is the head dimension (`d_model / n_heads`, i.e. 64 for GPT-2 small), W_O is `c_proj.weight` (shape `[d_model, d_model]`, rows = input dims), and U[t] is the unembedding row (`lm_head.weight[t]`, shape `[d_model]`). This is the direct logit attribution used in the IOI paper; see `analysis/circuits.py`. The `top_path` returned to the UI takes the top-N heads per layer, sorted by |contribution|.
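The formula vectorizes neatly once W_O's row axis is split per head. The NumPy sketch below mirrors it term by term; like the formula, it ignores the final layer norm's scaling and the `c_proj` bias. `direct_logit_attribution` and `top_path` are illustrative names, not necessarily the functions in `analysis/circuits.py`.

```python
import numpy as np


def direct_logit_attribution(z, W_O, U, position, token):
    """Per-head direct contribution to one token's logit.

    z:    [n_layers, seq, n_heads, d_head]  per-head outputs before W_O
    W_O:  [n_layers, d_model, d_model]      c_proj.weight per layer (Conv1D: rows = input dims)
    U:    [vocab, d_model]                  lm_head.weight
    Returns [n_layers, n_heads].
    """
    n_layers, _, n_heads, d_head = z.shape
    d_model = W_O.shape[-1]
    # Rows H*d_head:(H+1)*d_head of W_O belong to head H, so split the row axis per head.
    W_heads = W_O.reshape(n_layers, n_heads, d_head, d_model)
    head_out = np.einsum("lhd,lhdm->lhm", z[:, position], W_heads)  # [layers, heads, d_model]
    return head_out @ U[token]                                      # dot with the unembedding row


def top_path(contrib, n_per_layer=2):
    """Indices of the top-N heads per layer, ranked by |contribution|."""
    return [[int(h) for h in np.argsort(-np.abs(row))[:n_per_layer]] for row in contrib]
```

A useful sanity check: summing a layer's per-head contributions recovers `(concat(z) @ W_O) · U[t]` for that layer, since the heads' slices of W_O partition its rows.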
On the IOI prompt above, this surfaces the well-known name-mover heads in the deepest layers (L9–L11) as the dominant writers, visible as the magenta path through the model graph in the screenshot.
`analysis/patching.py` implements the standard resampling ablation: run the clean prompt and cache its `z` vectors, then run the corrupted prompt while overriding one head's `z` at the patch position with the clean value. The resulting change in P(target) is that head's causal contribution. The `/api/patch` endpoint sweeps all 144 heads (~12 s on Apple Silicon MPS), and the UI shows the results as a sorted horizontal bar chart; click a bar to select that head.
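The sweep itself is a double loop over layers and heads. This sketch assumes a hypothetical `run(prompt, overrides) -> (z, probs)` wrapper around the model (not the repo's API), with overrides keyed by `(layer, head, position)`. It also assumes the clean and corrupted prompts tokenize to the same length, so the patch position refers to the same slot in both runs.

```python
import numpy as np


def patch_sweep(run, clean, corrupted, target, n_layers, n_heads, position):
    """Single-head activation patching, clean -> corrupted.

    run(prompt, overrides) -> (z, probs) is a hypothetical model wrapper where
      z:         [n_layers, seq, n_heads, d_head] cached per-head outputs
      probs:     [vocab] next-token probabilities at the final position
      overrides: {(layer, head, position): replacement z vector}
    Returns [n_layers, n_heads]: change in P(target) when that head is patched.
    """
    clean_z, _ = run(clean, {})
    _, base_probs = run(corrupted, {})
    baseline = base_probs[target]
    effects = np.zeros((n_layers, n_heads))
    for layer in range(n_layers):
        for head in range(n_heads):
            patch = {(layer, head, position): clean_z[layer, position, head]}
            _, probs = run(corrupted, patch)
            effects[layer, head] = probs[target] - baseline
    return effects


def rank_heads(effects):
    """(layer, head, effect) tuples sorted by effect, largest recovery first (bar-chart order)."""
    n_layers, n_heads = effects.shape
    cells = [(l, h, float(effects[l, h])) for l in range(n_layers) for h in range(n_heads)]
    return sorted(cells, key=lambda c: c[2], reverse=True)
```

Each head costs one corrupted forward pass, so a 12 × 12 sweep is 144 passes plus the two baselines. That is why the sweep is a separate endpoint rather than part of the streamed analysis.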
The classic logit lens: for each hidden state across the model, project through the final layer norm and the unembedding to read off the model's "current best guess". It is implemented in `analysis/logit_lens.py` and surfaced as a per-layer table in the right column, so you can watch the answer crystallize as depth increases.
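A minimal NumPy sketch of the lens, assuming the hidden states arrive as `[n_layers + 1, seq, d_model]` (embeddings first, then each block's output). The `last_is_normed` flag reflects our understanding that in HuggingFace's GPT-2 the final entry of `output_hidden_states` is taken after `ln_f`, so applying `ln_f` again would double-normalize it; treat that as an assumption to verify against your transformers version.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last axis (GPT-2 uses eps = 1e-5)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * gamma + beta


def logit_lens(hidden_states, gamma, beta, U, position=-1, k=5, last_is_normed=True):
    """Top-k token ids the model would predict if it stopped at each depth.

    hidden_states:  [n_layers + 1, seq, d_model]
    gamma, beta:    final layer-norm (ln_f) parameters, shape [d_model]
    U:              [vocab, d_model] unembedding (lm_head.weight)
    last_is_normed: skip ln_f on the final entry if it has already been applied.
    """
    guesses = []
    for depth, h in enumerate(hidden_states):
        x = h[position]
        if not (last_is_normed and depth == len(hidden_states) - 1):
            x = layer_norm(x, gamma, beta)
        logits = U @ x  # [vocab]
        guesses.append([int(t) for t in np.argsort(-logits)[:k]])
    return guesses
```

Decoding each list of ids with the tokenizer gives the per-layer table of best guesses.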
`api/ws.py` runs the analysis on a worker thread (`loop.run_in_executor`) and uses an `asyncio.Queue` plus `loop.call_soon_threadsafe` to bridge the per-layer callback into the FastAPI event loop. Each `LayerUpdate` is JSON-serialized and sent, then the loop yields for ~30 ms so the frontend can animate without dropping frames.
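The thread-to-event-loop bridge can be shown with the standard library alone. In this sketch `blocking_forward` is a hypothetical stand-in for the model run, and `send` is any async callable; in FastAPI it would be `websocket.send_text`. The key point is that `asyncio.Queue` is not thread-safe, so the worker thread never touches it directly and instead schedules `put_nowait` on the loop.

```python
import asyncio
import json

_DONE = object()  # sentinel marking the end of the stream


def blocking_forward(n_layers, on_layer):
    """Hypothetical stand-in for the blocking forward pass + analysis."""
    for layer in range(n_layers):
        on_layer({"type": "layer", "layer": layer})
    return {"type": "done", "n_layers": n_layers}


async def stream_analysis(send, n_layers=12, frame_delay=0.03):
    """Run the blocking work on a thread and forward each layer update via `send`."""
    loop = asyncio.get_running_loop()
    queue = asyncio.Queue()

    def on_layer(update):
        # Called on the worker thread: schedule the put on the event loop.
        loop.call_soon_threadsafe(queue.put_nowait, update)

    def work():
        try:
            return blocking_forward(n_layers, on_layer)
        finally:
            # Always unblock the consumer, even if the forward pass raised.
            loop.call_soon_threadsafe(queue.put_nowait, _DONE)

    future = loop.run_in_executor(None, work)  # default ThreadPoolExecutor
    while (update := await queue.get()) is not _DONE:
        await send(json.dumps(update))
        await asyncio.sleep(frame_delay)  # give the frontend time to animate
    await send(json.dumps(await future))  # final "done" message (re-raises worker errors)
```

Because `call_soon_threadsafe` callbacks run in FIFO order, the sentinel always lands after the last layer update.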
| Preset | Prompt | Expected circuit |
|---|---|---|
| IOI | When Mary and John went to the store, John gave a drink to | Name-mover heads in L9–L11 write " Mary". |
| Induction | The cat sat on the mat. The cat sat on the | Induction heads (typically L5–L7) attend from the current token to the token after its previous occurrence and copy " mat". |
| Factual recall | The capital of France is | MLPs do most of the work; few attention heads dominate. |
| Greater-than | The war lasted from the year 1732 to 17 | Early positional heads + deep digit heads. |
| Duplicate token | The password is SWORDFISH. Again, the password is | Duplicate-token heads in early layers feed downstream induction heads. |
```
circuit-trace/
├── backend/
│   ├── pyproject.toml          # uv project + deps (torch, transformers, fastapi, ...)
│   ├── .python-version         # 3.11
│   ├── circuit_trace/
│   │   ├── __init__.py
│   │   ├── __main__.py         # `python -m circuit_trace` banner + load
│   │   ├── config.py           # env-var driven Settings
│   │   ├── logger.py           # rich-styled logger + banner
│   │   ├── schemas.py          # Pydantic wire format (analyze, patch, stream)
│   │   ├── runner.py           # orchestrator: prompt → AnalysisResult
│   │   ├── app.py              # FastAPI app factory + lifespan + CORS
│   │   ├── model/
│   │   │   ├── loader.py       # GPT-2 + tokenizer + device pick + ModelBundle
│   │   │   └── hooks.py        # ActivationCache with forward hooks + interventions
│   │   ├── analysis/
│   │   │   ├── attention.py    # entropy, concentration, head_norm, summary
│   │   │   ├── heads.py        # induction / duplicate / previous classifier
│   │   │   ├── circuits.py     # backward direct-logit-attribution tracing
│   │   │   ├── logit_lens.py   # per-layer top predictions
│   │   │   └── patching.py     # activation patching (per-head causal sweep)
│   │   └── api/
│   │       ├── http.py         # /api/health /api/presets /api/analyze /api/patch
│   │       ├── ws.py           # /ws/analyze streaming endpoint
│   │       ├── presets.py      # 5 preset prompts with expected circuits
│   │       └── state.py        # singleton model bundle (lazy + warmup)
│   └── tests/                  # pytest: attention primitives, head classifier, runner
│
├── frontend/
│   ├── package.json
│   ├── vite.config.ts          # dev proxy /api + /ws → :8000
│   ├── tailwind.config.js
│   ├── index.html
│   └── src/
│       ├── main.tsx
│       ├── App.tsx             # 3-row dashboard layout
│       ├── index.css           # Tailwind + dark theme tokens
│       ├── types.ts            # TS mirror of Pydantic schemas
│       ├── lib/{api,ws,colors,format}.ts
│       ├── store/useAnalysisStore.ts   # Zustand store
│       ├── hooks/{useAnalysis,useResizeObserver}.ts
│       └── components/
│           ├── Header.tsx              # logo + model badges + status pill
│           ├── PromptPanel.tsx         # prompt textarea + presets + tokens
│           ├── PredictionBar.tsx       # top-k bar chart
│           ├── ModelGraph3D.tsx        # Three.js layer × head grid + circuit edges
│           ├── AttentionHeatmap.tsx    # D3 [seq×seq] matrix
│           ├── HeadInspector.tsx       # selected-head metrics + top pairs
│           ├── CircuitTrace.tsx        # D3 layers×heads contribution map + path
│           ├── TokenFlow.tsx           # D3 arc diagram
│           ├── LogitLensTable.tsx      # per-layer top guesses
│           ├── PatchingPanel.tsx       # clean/corrupted prompts + causal bar chart
│           └── ui/{Panel,Spinner,Badge}.tsx
│
├── docs/screenshot.png
└── scripts/{dev,check}.sh
```
```bash
./scripts/check.sh   # backend pytest + frontend tsc + frontend build
```

Backend tests:

```bash
cd backend && uv run pytest -v
# tests/test_attention.py ... 8 passed
# tests/test_heads.py     ... 5 passed
# tests/test_runner.py    ... 1 passed (loads GPT-2 once)
```

Environment variables (all optional):
| Var | Default | Purpose |
|---|---|---|
| `CT_MODEL` | `gpt2` | Any HF causal LM with the GPT-2 module layout (e.g. `distilgpt2`, `gpt2-medium`). |
| `CT_DEVICE` | `auto` | `cuda` / `mps` / `cpu` / `auto`. |
| `CT_PORT` | `8000` | Backend HTTP+WS port. |
| `CT_HOST` | `127.0.0.1` | Bind host. |
| `CT_MAX_TOKENS` | `64` | Max prompt tokens (memory cap for the attention cube). |
| `CT_CORS` | `http://localhost:5173` | Allowed CORS origin for the frontend. |
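For readers wiring up their own deployment, here is a minimal sketch of an env-driven settings object that mirrors the table's variables and defaults. The `Settings.from_env` and `resolve_device` names, the validation, and the CUDA-then-MPS-then-CPU order for `auto` are assumptions for illustration, not the contents of `config.py` or `loader.py`.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

_DEVICES = {"auto", "cuda", "mps", "cpu"}


@dataclass(frozen=True)
class Settings:
    """Hypothetical env-driven settings mirroring the CT_* table (defaults match it)."""
    model: str = "gpt2"
    device: str = "auto"
    port: int = 8000
    host: str = "127.0.0.1"
    max_tokens: int = 64
    cors: str = "http://localhost:5173"

    @classmethod
    def from_env(cls, env: Optional[Mapping[str, str]] = None) -> "Settings":
        env = os.environ if env is None else env
        device = env.get("CT_DEVICE", cls.device).lower()
        if device not in _DEVICES:
            raise ValueError(f"CT_DEVICE must be one of {sorted(_DEVICES)}, got {device!r}")
        return cls(
            model=env.get("CT_MODEL", cls.model),
            device=device,
            port=int(env.get("CT_PORT", cls.port)),
            host=env.get("CT_HOST", cls.host),
            max_tokens=int(env.get("CT_MAX_TOKENS", cls.max_tokens)),
            cors=env.get("CT_CORS", cls.cors),
        )


def resolve_device(requested: str, has_cuda: bool, has_mps: bool) -> str:
    """Assumed 'auto' policy: prefer CUDA, then MPS, then CPU."""
    if requested != "auto":
        return requested
    if has_cuda:
        return "cuda"
    return "mps" if has_mps else "cpu"
```

With PyTorch installed, the availability flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.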
- GPT-2 small (124M) is the default target. Larger models will work, but the streamed attention cube grows as O(L · H · seq²).
- Head classification thresholds are heuristic. The included unit tests use synthetic patterns; on noisy real prompts, a head can drift between `previous` and `duplicate` for short sequences.
- Activation patching is single-head; multi-head intersections are out of scope here.
- The model is treated as fixed; there is no fine-tuning or training UI.
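The O(L · H · seq²) growth is easy to quantify. This short calculation assumes raw float32 values (4 bytes each); the JSON actually streamed over the WebSocket will typically be larger. GPT-2 small has 12 layers × 12 heads and `gpt2-medium` has 24 layers × 16 heads.

```python
def attention_cube_values(n_layers: int, n_heads: int, seq: int) -> int:
    """One [seq x seq] attention pattern per head per layer: L * H * seq^2 values."""
    return n_layers * n_heads * seq * seq


def attention_cube_bytes(n_layers: int, n_heads: int, seq: int, bytes_per_value: int = 4) -> int:
    """Raw size assuming float32; serialized JSON will typically be larger."""
    return attention_cube_values(n_layers, n_heads, seq) * bytes_per_value


print(attention_cube_bytes(12, 12, 64))   # GPT-2 small at CT_MAX_TOKENS=64 -> 2359296 (2.25 MiB)
print(attention_cube_bytes(24, 16, 64))   # gpt2-medium at 64 tokens -> 6291456 (6 MiB)
print(attention_cube_bytes(12, 12, 128))  # doubling seq quadruples the cube -> 9437184
```

The quadratic term dominates: raising `CT_MAX_TOKENS` from 64 to 128 costs more than switching from GPT-2 small to `gpt2-medium`.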
MIT.
The interpretability primitives implemented here come straight from these works:
- Elhage et al. (2021), A Mathematical Framework for Transformer Circuits
- Olsson et al. (2022), In-context Learning and Induction Heads
- Wang et al. (2022), Interpretability in the Wild: A Circuit for Indirect Object Identification in GPT-2 Small
- Meng et al. (2022), Locating and Editing Factual Associations in GPT
- nostalgebraist (2020), Interpreting GPT: the logit lens (LessWrong)
