
Circuit Trace

Mechanistic interpretability lab that dissects transformer attention heads and traces reasoning circuits in real time.

Python 3.11 · TypeScript strict · PyTorch 2.x · License MIT

Circuit Trace loads GPT-2 small, runs a forward pass on your prompt, and streams every attention pattern, head metric, and residual delta from the model into a browser-based lab. You watch attention heads light up layer-by-layer, hover any head to see its top source/destination tokens, follow a circuit back from the predicted token through the heads that wrote it, and run activation patches to measure causal effect. No notebooks, no print(tensor.shape): just the model's internals laid out interactively.

[Screenshot: Circuit Trace on the IOI prompt]

Above: the IOI prompt "When Mary and John went to the store, John gave a drink to", with top prediction "Mary" at 44.6% and the circuit graph (to the right of the model architecture view) tracing the deep-layer name-mover heads that wrote the answer.


Why this exists

Most production engineers treat transformers as opaque black boxes. The mech-interp community has built powerful tools (TransformerLens, captum, neuroscope), but they are all Jupyter-first: you load activations into a notebook, plot a matplotlib heatmap, and squint. There's no way to watch the model think as it runs.

Circuit Trace fills that gap. It takes the same primitives (attention pattern extraction, head classification, direct logit attribution, activation patching) and exposes them as a live interactive surface. You type a prompt, the forward pass streams over a WebSocket, and all 144 attention heads animate into a 3D grid where colour encodes direct contribution and brightness encodes value norm. Click a head to see its attention pattern and classification (induction, duplicate-token, previous-token, diffuse). Run an activation patch sweep to see which heads causally restore the lost prediction.

It's the kind of tool an interpretability team would build for itself, packaged as a standalone open-source app.

Features

| Capability | What it does |
|---|---|
| Streaming forward pass | WebSocket emits per-layer attention + metrics so the UI animates the computation. |
| 3D model architecture | Three.js grid of 12 layers × 12 heads; brightness ∝ value norm; circuit edges drawn between top contributors. |
| Attention heatmap | D3 [seq × seq] matrix for any selected head, with token labels and tooltips. |
| Head classification | Each head labeled induction / duplicate / previous / diffuse from its attention signature. |
| Backward circuit trace | Direct logit attribution: per-head contribution to the predicted token's logit, computed via (z @ W_O) · U[t]. |
| Logit lens | Per-layer "current best guess": projects the residual stream through the final layer norm and the unembedding at every depth. |
| Token-flow arc diagram | Aggregate token-to-token attention summed across all layers and heads. |
| Activation patching | Swap a clean head's z into a corrupted run; sweep all 144 heads to measure causal contribution. |
| Preset circuits | Five well-studied prompts (IOI, induction copying, factual recall, year continuation, duplicate-token). |
| Pure local inference | Runs on Apple Silicon (MPS), CUDA, or CPU. No API keys, no telemetry. |
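The token-flow arc diagram aggregates attention over every layer and head. A minimal sketch of that aggregation, assuming the full pattern is available as nested lists indexed [layer][head][query][key]; `token_flow` and `top_arcs` are hypothetical helpers, not functions from the repository:

```python
from typing import List, Tuple

# Hypothetical layout: pattern[layer][head][query][key]; each query row sums to 1.
Pattern = List[List[List[List[float]]]]


def token_flow(pattern: Pattern) -> List[List[float]]:
    """Sum attention weights over all layers and heads into one [seq x seq] matrix."""
    seq = len(pattern[0][0])
    flow = [[0.0] * seq for _ in range(seq)]
    for layer in pattern:
        for head in layer:
            for q, row in enumerate(head):
                for k, w in enumerate(row):
                    flow[q][k] += w
    return flow


def top_arcs(flow: List[List[float]], n: int = 3) -> List[Tuple[int, int, float]]:
    """Strongest (source_key, dest_query, weight) arcs, ignoring self-attention."""
    arcs = [(k, q, w) for q, row in enumerate(flow) for k, w in enumerate(row) if k != q]
    return sorted(arcs, key=lambda arc: arc[2], reverse=True)[:n]


# Usage: one layer, two heads, two tokens.
example = [[[[1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]
flow = token_flow(example)    # [[2.0, 0.0], [1.0, 1.0]]
print(top_arcs(flow, n=1))    # [(0, 1, 1.0)]
```

Dropping the diagonal keeps the arcs about information moving between tokens; with a causal mask, self-attention would otherwise dominate short prompts.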

Quick start

Requirements: macOS / Linux, Python 3.11 (managed via uv), Node 18+, ~600 MB free for the GPT-2 weights cache.

```bash
git clone https://github.com/rayancheca/circuit-trace
cd circuit-trace

# one-time install
cd backend && uv sync && cd ..
cd frontend && npm install && cd ..

# run both servers (backend on :8000, frontend on :5173)
./scripts/dev.sh
```

Open http://localhost:5173/. The first request loads GPT-2 small (~500 MB download from HuggingFace); subsequent requests reuse the cached weights.

Usage flow

  1. Pick a preset (e.g. Indirect Object Identification) or type your own prompt.
  2. Click analyze. The status badge transitions connecting → streaming → ready while layers animate in.
  3. Watch the 3D model architecture light up: emerald heads write toward the predicted token, magenta heads write away from it. The brightness of each sphere reflects its value norm.
  4. Click any head in the 3D view. The right column shows its attention heatmap and the inspector panel (entropy, concentration, value norm, classification badges).
  5. The lower row shows the circuit trace (a 2D layers × heads heatmap of direct logit attribution, with the path edges in magenta), the token-flow arc diagram (aggregate attention between every token pair), and the logit lens (per-layer best guesses).
  6. Run an activation patch sweep: type a clean prompt and a corrupted prompt, hit run patch sweep, and Circuit Trace will replace each head's z vector with the clean value (one head at a time, all 144 heads) and report which heads causally recover the lost prediction.

Architecture

```text
┌────────────────────────────────────────────────────────────────────┐
│                         BROWSER (React + Vite)                     │
│  ┌─────────┐  ┌──────────────┐  ┌─────────────┐  ┌────────────┐    │
│  │  Prompt │  │ 3D Model Map │  │ Heatmap     │  │ Circuit    │    │
│  │  Panel  │  │  (Three.js)  │  │  (D3)       │  │  Trace (D3)│    │
│  └────┬────┘  └──────▲───────┘  └──────▲──────┘  └──────▲─────┘    │
│       │              │ Zustand store + useAnalysis hook │          │
│       └──────────────┴───────────────┬─ WebSocket ──────┘          │
└──────────────────────────────────────┼─────────────────────────────┘
                                       │
┌──────────────────────────────────────┼─────────────────────────────┐
│                         FASTAPI BACKEND                            │
│   /api/health  /api/presets  /api/analyze  /api/patch  /ws/analyze │
│                                      │                             │
│  ┌─────────────────┐   ┌─────────────▼─────────────┐               │
│  │ runner.analyze()│ → │   ActivationCache (hooks) │               │
│  └────────┬────────┘   └─────────────┬─────────────┘               │
│           │                          │                             │
│   ┌───────▼──────────────────────────▼─────────────┐               │
│   │ Head classifier · Circuit tracer · Patcher     │               │
│   │ Logit lens · Attention metrics                 │               │
│   └───────────────────────────────────┬────────────┘               │
│                                       │                            │
│   ┌───────────────────────────────────▼────────────┐               │
│   │ PyTorch GPT-2 small (HuggingFace transformers) │               │
│   │ Forward hooks on every attn + mlp module       │               │
│   └────────────────────────────────────────────────┘               │
└────────────────────────────────────────────────────────────────────┘
```

Data flow

  1. Browser opens WebSocket to /ws/analyze, sends {prompt, top_k, top_path_per_layer}.
  2. The backend accepts the request, tokenizes the prompt with the GPT-2 BPE tokenizer, and dispatches the forward pass on a worker thread.
  3. ActivationCache hooks capture per layer:
    • z: per-head attention output before the W_O projection, [batch, seq, heads, head_dim].
    • attn_out: post-W_O attention output, [batch, seq, d_model].
    • mlp_out, residual_pre, residual_post: used for direct logit attribution and the logit lens.
    • attention: the actual softmax-normalized weights via HF's output_attentions=True (eager attention is forced so SDPA/Flash don't drop them).
  4. After each layer, a LayerUpdate message goes over the WebSocket carrying the attention pattern + per-head metrics (entropy, concentration, value norm). The frontend animates layer-by-layer reveals.
  5. Once all 12 layers complete, the analyzer runs head classification, the circuit tracer, and the logit lens, then sends a done message with the full AnalysisResult.
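Step 4 streams three per-head metrics with each layer. The sketch below shows one plausible way to compute them and serialise a layer message. The definitions are assumptions, not taken from analysis/attention.py: entropy is the mean Shannon entropy of the query rows, concentration is the mean of each row's largest weight, value norm is the mean L2 norm of the head's z vectors, and the JSON field names are hypothetical rather than the project's Pydantic schema.

```python
import json
import math
from typing import Dict, List

Row = List[float]


def row_entropy(row: Row) -> float:
    """Shannon entropy (nats) of one softmax-normalised attention row."""
    return -sum(p * math.log(p) for p in row if p > 0.0)


def head_metrics(pattern: List[Row], z: List[Row]) -> Dict[str, float]:
    """Summarise one head: pattern is [query][key], z is [position][head_dim]."""
    n = len(pattern)
    return {
        "entropy": sum(row_entropy(r) for r in pattern) / n,
        # Assumed definition: mean of each query row's largest attention weight.
        "concentration": sum(max(r) for r in pattern) / n,
        # Assumed definition: mean L2 norm of the head's z vectors.
        "value_norm": sum(math.sqrt(sum(x * x for x in v)) for v in z) / len(z),
    }


def layer_update(layer: int, patterns: List[List[Row]], zs: List[List[Row]]) -> str:
    """Serialise one hypothetical per-layer message (field names are illustrative)."""
    return json.dumps({
        "type": "layer",
        "layer": layer,
        "attention": patterns,  # [heads][seq][seq]
        "heads": [head_metrics(p, z) for p, z in zip(patterns, zs)],
    })


pattern = [[1.0, 0.0], [0.5, 0.5]]
z = [[3.0, 4.0], [0.0, 0.0]]
print(head_metrics(pattern, z))  # entropy = ln(2) / 2, concentration = 0.75, value_norm = 2.5
```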

Technical deep-dive

Forward hooks instead of TransformerLens

Most mech-interp code reaches for TransformerLens, which provides a cache API for activations. Circuit Trace deliberately uses raw PyTorch hooks on HuggingFace transformers (see backend/circuit_trace/model/hooks.py). The cache:

  • Registers a forward_pre_hook on each attn.c_proj Conv1D. Its input is the concatenated per-head z vectors [batch, seq, d_model], which we reshape to [batch, seq, heads, head_dim] to recover per-head outputs before the W_O projection.
  • Registers a forward_hook on each attn, mlp, and block module to capture the post-attention output, the MLP output, and the residual stream pre/post.
  • Pulls native output_attentions and output_hidden_states from the model output (with attn_implementation="eager" forced at load time so SDPA/Flash don't silently drop the weights).
  • Supports interventions: set_z_override(layer, head, replacement) swaps a head's z during the next forward pass; this is the basis of activation patching.
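The reshape and the one-shot override can be illustrated without PyTorch. In a real cache the hook would be attached with `module.register_forward_pre_hook` and the reshape would be a `.view(batch, seq, n_heads, head_dim)` on the tensor; the toy `MiniZCache` class below is hypothetical and works on a single position's concatenated z as a plain list so the logic is easy to check:

```python
from typing import Dict, List

Vector = List[float]


class MiniZCache:
    """Toy stand-in for a z-capturing pre-hook on one layer's attn.c_proj.

    Hypothetical simplification: it handles one position's concatenated z
    (length n_heads * head_dim) instead of a [batch, seq, d_model] tensor.
    """

    def __init__(self, n_heads: int, head_dim: int) -> None:
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.captured: List[Vector] = []        # per-head z seen on the last call
        self.overrides: Dict[int, Vector] = {}  # head index -> replacement z

    def set_z_override(self, head: int, replacement: Vector) -> None:
        if len(replacement) != self.head_dim:
            raise ValueError("replacement must have length head_dim")
        self.overrides[head] = list(replacement)

    def pre_hook(self, concat_z: Vector) -> Vector:
        """Split into [heads][head_dim], record, apply overrides, re-concatenate."""
        d = self.head_dim
        heads = [concat_z[h * d:(h + 1) * d] for h in range(self.n_heads)]
        self.captured = [list(v) for v in heads]  # record the unpatched values
        for h, rep in self.overrides.items():
            heads[h] = rep
        self.overrides.clear()  # one-shot: only the next forward pass is patched
        return [x for v in heads for x in v]


cache = MiniZCache(n_heads=2, head_dim=2)
cache.set_z_override(1, [9.0, 9.0])
print(cache.pre_hook([1.0, 2.0, 3.0, 4.0]))  # [1.0, 2.0, 9.0, 9.0]
print(cache.pre_hook([1.0, 2.0, 3.0, 4.0]))  # [1.0, 2.0, 3.0, 4.0]
```

Recording z before applying the override is a design choice in this sketch: the cache always reflects what the model computed, even on a patched pass.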

Head classification (Anthropic induction-heads paper)

analysis/heads.py computes three scalar scores per head from the attention pattern:

  • prev_token_score: mean of attn[i, i-1] across positions i ≥ 1. A previous-token head scores ≈1.0.
  • duplicate_token_score: for each i with a prior occurrence of token[i], the max attention to any earlier same-token position. Reveals duplicate-token heads (the "what was it last time" lookback).
  • induction_score: for each i, the max attention to the position after a previous occurrence of token[i], i.e. "I saw 'A B' earlier, and now I'm at 'A' again, so attend to that earlier 'B'". This is the canonical induction-head signature from Olsson et al. 2022.

Each head takes the label of its highest score if that score exceeds a configurable threshold; otherwise it is labeled diffuse. Synthetic-attention unit tests in backend/tests/test_heads.py cover all four labels.
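The three scores can be sketched directly from the definitions above. Two details are assumptions, not taken from analysis/heads.py: per-position values are averaged into one score per head, and the 0.4 threshold is purely illustrative. Each pattern is a single head's [query][key] list.

```python
from typing import Dict, List, Sequence

Pattern = List[List[float]]  # one head: [query][key], causal, rows sum to 1


def _mean(vals: List[float]) -> float:
    return sum(vals) / len(vals) if vals else 0.0


def prev_token_score(attn: Pattern) -> float:
    """Mean of attn[i][i-1] over positions i >= 1."""
    return _mean([attn[i][i - 1] for i in range(1, len(attn))])


def duplicate_token_score(attn: Pattern, tokens: Sequence[int]) -> float:
    """Max attention to an earlier position holding the same token (averaged over i)."""
    vals = []
    for i, tok in enumerate(tokens):
        earlier = [j for j in range(i) if tokens[j] == tok]
        if earlier:
            vals.append(max(attn[i][j] for j in earlier))
    return _mean(vals)


def induction_score(attn: Pattern, tokens: Sequence[int]) -> float:
    """Max attention to j+1, where tokens[j] == tokens[i] and j+1 < i (averaged over i)."""
    vals = []
    for i, tok in enumerate(tokens):
        targets = [j + 1 for j in range(i - 1) if tokens[j] == tok]
        if targets:
            vals.append(max(attn[i][t] for t in targets))
    return _mean(vals)


def classify(attn: Pattern, tokens: Sequence[int], threshold: float = 0.4) -> str:
    """Label by the highest score; fall back to 'diffuse' below the threshold."""
    scores: Dict[str, float] = {
        "previous": prev_token_score(attn),
        "duplicate": duplicate_token_score(attn, tokens),
        "induction": induction_score(attn, tokens),
    }
    label, best = max(scores.items(), key=lambda kv: kv[1])
    return label if best >= threshold else "diffuse"


# Usage: tokens "A B C A B C"; each repeated token attends to the token after its earlier copy.
tokens = [1, 2, 3, 1, 2, 3]
attn = [[0.0] * 6 for _ in range(6)]
for i in range(3):
    attn[i][0] = 1.0      # early positions park on the first token
for i in range(3, 6):
    attn[i][i - 2] = 1.0  # e.g. the second "A" (i=3) attends to the first "B" (position 1)
print(classify(attn, tokens))  # induction
```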

Backward circuit trace via direct logit attribution

For a predicted token t at query position P, each attention head's direct contribution to its logit is

```text
contribution(layer L, head H) = (z[L][P, H] @ W_O[L][H*dh : (H+1)*dh, :]) · U[t]
```

where W_O is c_proj.weight (shape [d_model, d_model], rows = input dims), dh is the per-head dimension, and U[t] is the unembedding row (lm_head.weight[t], shape [d_model]). This is the direct logit attribution used in the IOI paper; see analysis/circuits.py. The top_path returned to the UI takes the top-N heads per layer, sorted by |contribution|.

On the IOI prompt above, this surfaces the well-known name-mover heads in the deepest layers (L9–L11) as the dominant writers, visible as the magenta path through the model graph in the screenshot.
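A pure-Python sketch of the formula and the top_path selection, following the equation exactly as written (no final layer norm, no c_proj bias). `head_contribution` and `top_path` are illustrative stand-ins, not the code in analysis/circuits.py:

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def head_contribution(z_pos: Matrix, w_o: Matrix, u_t: Vector, head: int) -> float:
    """(z[P, H] @ W_O[H*dh:(H+1)*dh, :]) . U[t] for one layer.

    z_pos: [n_heads][head_dim] z vectors at the query position P.
    w_o:   [d_model][d_model] c_proj weight, rows = input dims (Conv1D layout).
    u_t:   [d_model] unembedding row of the target token.
    """
    dh = len(z_pos[0])
    rows = w_o[head * dh:(head + 1) * dh]  # this head's [dh][d_model] slice
    out = [sum(z_pos[head][k] * rows[k][c] for k in range(dh)) for c in range(len(u_t))]
    return sum(o * u for o, u in zip(out, u_t))


def top_path(contribs: Matrix, n: int = 2) -> List[Tuple[int, int, float]]:
    """Top-n (layer, head, contribution) per layer, ranked by |contribution|."""
    path = []
    for layer, heads in enumerate(contribs):
        ranked = sorted(enumerate(heads), key=lambda hv: abs(hv[1]), reverse=True)[:n]
        path.extend((layer, h, v) for h, v in ranked)
    return path


# Toy layer: 2 heads, head_dim 1, d_model 2, identity W_O.
z_pos = [[2.0], [3.0]]
w_o = [[1.0, 0.0], [0.0, 1.0]]
u_t = [1.0, -1.0]
contribs = [[head_contribution(z_pos, w_o, u_t, h) for h in range(2)]]
print(contribs)                 # [[2.0, -3.0]]
print(top_path(contribs, n=1))  # [(0, 1, -3.0)]
```

Because the projection is linear, the per-head contributions sum to the direct effect of the whole attention output on that logit (here 2.0 - 3.0 = -1.0, the same as [2, 3] · [1, -1]), which makes a handy sanity check.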

Activation patching

analysis/patching.py implements standard activation patching (clean activations patched into a corrupted run): run the clean prompt and cache its z vectors, then run the corrupted prompt while overriding one head's z at the patch position with the clean value. The resulting change in P(target) is that head's causal contribution. The /api/patch endpoint sweeps all 144 heads (≈12 s on Apple Silicon MPS), and the UI shows the results as a sorted horizontal bar chart; click a bar to select that head.
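The sweep itself is a loop over heads. The sketch below assumes a hypothetical `run_corrupted(key, z)` callable that runs the corrupted prompt with one head's z replaced and returns P(target); the toy sigmoid model stands in for GPT-2 so the example runs on its own:

```python
import math
from typing import Callable, Dict, List, Optional, Tuple

Key = Tuple[int, int]  # (layer, head)
RunFn = Callable[[Optional[Key], Optional[List[float]]], float]


def patch_sweep(clean_z: Dict[Key, List[float]], run_corrupted: RunFn) -> List[Tuple[Key, float]]:
    """Patch each head's clean z into the corrupted run, one head at a time.

    run_corrupted(key, z) returns P(target) on the corrupted prompt with head
    `key`'s z replaced by `z` (no patch when key is None).
    Returns (head, change in P(target)) pairs, largest recovery first.
    """
    baseline = run_corrupted(None, None)
    effects = [(key, run_corrupted(key, z) - baseline) for key, z in clean_z.items()]
    return sorted(effects, key=lambda kv: kv[1], reverse=True)


def toy_model(weights: Dict[Key, float], corrupted_z: Dict[Key, List[float]]) -> RunFn:
    """Hypothetical stand-in for GPT-2: P(target) = sigmoid(sum_h w_h * z_h[0])."""
    def run(key: Optional[Key], z: Optional[List[float]]) -> float:
        zs = dict(corrupted_z)
        if key is not None and z is not None:
            zs[key] = z
        logit = sum(weights[k] * zs[k][0] for k in weights)
        return 1.0 / (1.0 + math.exp(-logit))
    return run


run = toy_model({(0, 0): 0.0, (0, 1): 2.0}, {(0, 0): [1.0], (0, 1): [-1.0]})
for head, delta in patch_sweep({(0, 0): [-1.0], (0, 1): [1.0]}, run):
    print(head, round(delta, 3))
```

Only head (0, 1) carries weight toward the target in the toy, so patching it recovers most of the lost probability while patching (0, 0) changes nothing, which mirrors how the bar chart separates causal heads from bystanders.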

Logit lens

The classic logit lens: for each layer's residual stream, project through the final layer norm + unembedding to read the model's "current best guess". It is implemented in analysis/logit_lens.py and surfaced as a per-layer table in the right column, so you can watch the answer crystallize as depth increases.
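A minimal logit-lens sketch over plain lists, assuming each residual is a d_model vector and ln_f is a standard LayerNorm with GPT-2's default epsilon of 1e-5. `logit_lens` here is an illustrative helper, not the function in analysis/logit_lens.py. One practical caveat if you build on HF's output_hidden_states: in the transformers GPT-2 implementation the last entry is already passed through ln_f, so it should not be normalised a second time.

```python
import math
from typing import List, Sequence

Vector = List[float]


def layer_norm(x: Vector, gamma: Vector, beta: Vector, eps: float = 1e-5) -> Vector:
    """LayerNorm over one d_model vector (biased variance, as in PyTorch)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv * g + b for v, g, b in zip(x, gamma, beta)]


def logit_lens(
    residuals: Sequence[Vector],
    gamma: Vector,
    beta: Vector,
    unembed: Sequence[Vector],
    vocab: Sequence[str],
) -> List[str]:
    """Best-guess token after each layer: argmax(unembed @ ln_f(resid))."""
    guesses = []
    for resid in residuals:
        h = layer_norm(resid, gamma, beta)
        logits = [sum(w * x for w, x in zip(row, h)) for row in unembed]
        guesses.append(vocab[max(range(len(logits)), key=logits.__getitem__)])
    return guesses


# Toy model: d_model 2, vocabulary of two tokens, identity ln_f parameters.
print(logit_lens([[1.0, 0.0], [0.0, 1.0]], [1.0, 1.0], [0.0, 0.0],
                 [[1.0, -1.0], [-1.0, 1.0]], ["A", "B"]))  # ['A', 'B']
```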

Streaming over WebSocket without blocking the event loop

api/ws.py runs the analysis on a worker thread (loop.run_in_executor) and uses an asyncio.Queue + loop.call_soon_threadsafe to bridge the per-layer callback into the FastAPI event loop. Each LayerUpdate is JSON-serialized, sent, then the loop yields for ~30 ms so the frontend can animate without dropping frames.
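The thread-to-loop bridge can be sketched with the standard library alone. `analyze_blocking` stands in for the real forward pass and `send` for the WebSocket's send method (both hypothetical here); the worker always posts a final done or error message so the consumer loop cannot hang:

```python
import asyncio
import json
import time
from typing import Awaitable, Callable


def analyze_blocking(prompt: str, on_layer: Callable[[dict], None], n_layers: int = 3) -> dict:
    """Hypothetical stand-in for the blocking forward pass (runs on a worker thread)."""
    for layer in range(n_layers):
        time.sleep(0.01)  # pretend to compute one layer
        on_layer({"type": "layer", "layer": layer})
    return {"type": "done", "prompt": prompt}


async def stream_analysis(
    prompt: str, send: Callable[[str], Awaitable[None]], delay: float = 0.03
) -> None:
    loop = asyncio.get_running_loop()
    queue: "asyncio.Queue[dict]" = asyncio.Queue()

    def post(msg: dict) -> None:
        # Called from the worker thread: asyncio.Queue is not thread-safe,
        # so schedule put_nowait on the event loop instead of calling it directly.
        loop.call_soon_threadsafe(queue.put_nowait, msg)

    def worker() -> None:
        try:
            result = analyze_blocking(prompt, post)
        except Exception as exc:  # always send a terminal message so the loop below exits
            result = {"type": "error", "detail": str(exc)}
        post(result)

    future = loop.run_in_executor(None, worker)
    while True:
        msg = await queue.get()
        await send(json.dumps(msg))
        if msg["type"] in ("done", "error"):
            break
        await asyncio.sleep(delay)  # give the frontend time to animate this layer
    await future  # the worker has already posted its last message; this joins it


async def main() -> None:
    async def fake_send(text: str) -> None:
        print(text)

    await stream_analysis("The cat sat on the", fake_send)


if __name__ == "__main__":
    asyncio.run(main())
```

Messages keep their order because every `call_soon_threadsafe` callback lands in the event loop's FIFO ready queue.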

Preset circuits

| Preset | Prompt | Expected circuit |
|---|---|---|
| IOI | "When Mary and John went to the store, John gave a drink to" | Name-mover heads in L9–L11 write " Mary". |
| Induction | "The cat sat on the mat. The cat sat on the" | Induction heads (L5–L7 typically) attend from the current token to the token after its previous occurrence and copy " mat". |
| Factual recall | "The capital of France is" | MLPs do most of the work; few attention heads dominate. |
| Greater-than | "The war lasted from the year 1732 to 17" | Early positional heads + deep digit heads. |
| Duplicate token | "The password is SWORDFISH. Again, the password is" | Duplicate-token heads in early layers feed downstream induction heads. |

Project layout

```text
circuit-trace/
├── backend/
│   ├── pyproject.toml                # uv project + deps (torch, transformers, fastapi, ...)
│   ├── .python-version               # 3.11
│   ├── circuit_trace/
│   │   ├── __init__.py
│   │   ├── __main__.py               # `python -m circuit_trace` banner + load
│   │   ├── config.py                 # env-var driven Settings
│   │   ├── logger.py                 # rich-styled logger + banner
│   │   ├── schemas.py                # Pydantic wire format (analyze, patch, stream)
│   │   ├── runner.py                 # orchestrator: prompt → AnalysisResult
│   │   ├── app.py                    # FastAPI app factory + lifespan + CORS
│   │   ├── model/
│   │   │   ├── loader.py             # GPT-2 + tokenizer + device pick + ModelBundle
│   │   │   └── hooks.py              # ActivationCache with forward hooks + interventions
│   │   ├── analysis/
│   │   │   ├── attention.py          # entropy, concentration, head_norm, summary
│   │   │   ├── heads.py              # induction / duplicate / previous classifier
│   │   │   ├── circuits.py           # backward direct-logit-attribution tracing
│   │   │   ├── logit_lens.py         # per-layer top predictions
│   │   │   └── patching.py           # activation patching (per-head causal sweep)
│   │   └── api/
│   │       ├── http.py               # /api/health /api/presets /api/analyze /api/patch
│   │       ├── ws.py                 # /ws/analyze streaming endpoint
│   │       ├── presets.py            # 5 preset prompts with expected circuits
│   │       └── state.py              # singleton model bundle (lazy + warmup)
│   └── tests/                        # pytest: attention primitives, head classifier, runner
│
├── frontend/
│   ├── package.json
│   ├── vite.config.ts                # dev proxy /api + /ws → :8000
│   ├── tailwind.config.js
│   ├── index.html
│   └── src/
│       ├── main.tsx
│       ├── App.tsx                   # 3-row dashboard layout
│       ├── index.css                 # Tailwind + dark theme tokens
│       ├── types.ts                  # TS mirror of Pydantic schemas
│       ├── lib/{api,ws,colors,format}.ts
│       ├── store/useAnalysisStore.ts # Zustand store
│       ├── hooks/{useAnalysis,useResizeObserver}.ts
│       └── components/
│           ├── Header.tsx            # logo + model badges + status pill
│           ├── PromptPanel.tsx       # prompt textarea + presets + tokens
│           ├── PredictionBar.tsx     # top-k bar chart
│           ├── ModelGraph3D.tsx      # Three.js layer × head grid + circuit edges
│           ├── AttentionHeatmap.tsx  # D3 [seq×seq] matrix
│           ├── HeadInspector.tsx     # selected-head metrics + top pairs
│           ├── CircuitTrace.tsx      # D3 layers×heads contribution map + path
│           ├── TokenFlow.tsx         # D3 arc diagram
│           ├── LogitLensTable.tsx    # per-layer top guesses
│           ├── PatchingPanel.tsx     # clean/corrupted prompts + causal bar chart
│           └── ui/{Panel,Spinner,Badge}.tsx
│
├── docs/screenshot.png
└── scripts/{dev,check}.sh
```

Development

```bash
./scripts/check.sh   # backend pytest + frontend tsc + frontend build
```

Backend tests:

```bash
cd backend && uv run pytest -v
# tests/test_attention.py  ... 8 passed
# tests/test_heads.py      ... 5 passed
# tests/test_runner.py     ... 1 passed (loads GPT-2 once)
```

Configuration

Environment variables (all optional):

| Var | Default | Purpose |
|---|---|---|
| CT_MODEL | gpt2 | Any HF causal LM with the GPT-2 module layout (e.g. distilgpt2, gpt2-medium). |
| CT_DEVICE | auto | cuda / mps / cpu / auto. |
| CT_PORT | 8000 | Backend HTTP + WebSocket port. |
| CT_HOST | 127.0.0.1 | Bind host. |
| CT_MAX_TOKENS | 64 | Max prompt tokens (memory cap for the attention cube). |
| CT_CORS | http://localhost:5173 | Allowed CORS origin for the frontend. |
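A sketch of env-driven settings matching the table above. The class shape, the validation, and the cuda → mps → cpu preference order for auto are assumptions, not a copy of config.py; in a PyTorch app the availability flags would typically come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

VALID_DEVICES = {"auto", "cuda", "mps", "cpu"}


@dataclass(frozen=True)
class Settings:
    model: str = "gpt2"
    device: str = "auto"
    port: int = 8000
    host: str = "127.0.0.1"
    max_tokens: int = 64
    cors: str = "http://localhost:5173"

    @classmethod
    def from_env(cls, env: Optional[Mapping[str, str]] = None) -> "Settings":
        """Read CT_* variables, falling back to the documented defaults."""
        env = os.environ if env is None else env
        device = env.get("CT_DEVICE", cls.device)
        if device not in VALID_DEVICES:
            raise ValueError(f"CT_DEVICE must be one of {sorted(VALID_DEVICES)}, got {device!r}")
        return cls(
            model=env.get("CT_MODEL", cls.model),
            device=device,
            port=int(env.get("CT_PORT", cls.port)),
            host=env.get("CT_HOST", cls.host),
            max_tokens=int(env.get("CT_MAX_TOKENS", cls.max_tokens)),
            cors=env.get("CT_CORS", cls.cors),
        )


def resolve_device(requested: str, cuda_available: bool, mps_available: bool) -> str:
    """Map "auto" to the best available backend; explicit choices pass through."""
    if requested != "auto":
        return requested
    if cuda_available:
        return "cuda"
    return "mps" if mps_available else "cpu"


s = Settings.from_env({"CT_PORT": "9000", "CT_DEVICE": "auto"})
print(s.port, resolve_device(s.device, cuda_available=False, mps_available=True))  # 9000 mps
```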

Limitations

  • Defaults to GPT-2 small (124M). Larger GPT-2-layout models will work, but the streamed attention cube grows as O(L · H · seq²).
  • Head classification thresholds are heuristic: the included unit tests use synthetic patterns, and on noisy real prompts a head can drift between previous and duplicate for short sequences.
  • Activation patching is single-head; multi-head intersections are out of scope here.
  • The model is treated as fixed; no fine-tuning or training UI.
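To put the O(L · H · seq²) growth in numbers, here is the raw float32 size of the full attention cube (JSON on the wire is larger). The function is a back-of-the-envelope helper, not part of the project:

```python
def attention_cube_bytes(n_layers: int, n_heads: int, seq: int, bytes_per_value: int = 4) -> int:
    """Raw size of every attention pattern in one forward pass: L * H * seq^2 values."""
    return n_layers * n_heads * seq * seq * bytes_per_value


print(attention_cube_bytes(12, 12, 64))  # 2359296  (GPT-2 small at CT_MAX_TOKENS=64)
print(attention_cube_bytes(24, 16, 64))  # 6291456  (gpt2-medium: 24 layers x 16 heads)
```

At the default CT_MAX_TOKENS of 64 that is about 2.25 MiB for GPT-2 small and 6 MiB for gpt2-medium, and doubling the prompt length quadruples either figure.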

License

MIT.

Acknowledgements

The interpretability primitives implemented here come straight from these papers:
