From 23f3ee74c081010ef918ac6342e37dd3b1882a8b Mon Sep 17 00:00:00 2001
From: RajdeepGupta07
Date: Thu, 11 Dec 2025 12:47:21 +0530
Subject: [PATCH] Add scripts to print and plot activation tensor shapes

---
 lot_activation_sizes.py | 46 +++++++++++++++++++++++++++++++++++++++++
 print_all_shapes.py     | 19 +++++++++++++++++
 test_tl.py              | 21 +++++++++++++++++++
 3 files changed, 86 insertions(+)
 create mode 100644 lot_activation_sizes.py
 create mode 100644 print_all_shapes.py
 create mode 100644 test_tl.py

diff --git a/lot_activation_sizes.py b/lot_activation_sizes.py
new file mode 100644
index 000000000..632a0b013
--- /dev/null
+++ b/lot_activation_sizes.py
@@ -0,0 +1,46 @@
+from transformer_lens import HookedTransformer
+import torch
+import matplotlib.pyplot as plt
+import numpy as np
+
+model = HookedTransformer.from_pretrained("gpt2-small")
+_, cache = model.run_with_cache("Hello world")
+
+names = []
+dim0 = []
+dim1 = []
+dim2 = []
+
+for name, value in cache.items():
+    if isinstance(value, torch.Tensor):
+        shp = tuple(value.shape)
+        names.append(name)
+        # record up to three dims, use 1 if missing to avoid plotting gaps
+        dim0.append(shp[0] if len(shp) > 0 else 1)
+        dim1.append(shp[1] if len(shp) > 1 else 1)
+        dim2.append(shp[2] if len(shp) > 2 else 1)
+    else:
+        # non-tensor entries get 0 dims
+        names.append(name)
+        dim0.append(0)
+        dim1.append(0)
+        dim2.append(0)
+
+# Keep order stable; you may want to trim long lists for readability
+MAX = 60
+indices = list(range(min(len(names), MAX)))
+
+x = np.arange(len(indices))
+
+plt.figure(figsize=(12, 6))
+plt.plot(x, [dim0[i] for i in indices], marker='o', label='dim0 (batch)')
+plt.plot(x, [dim1[i] for i in indices], marker='o', label='dim1 (seq)')
+plt.plot(x, [dim2[i] for i in indices], marker='o', label='dim2 (channels/heads/...)')
+
+plt.xticks(x, [names[i] for i in indices], rotation=90, fontsize=8)
+plt.xlabel("Activation name (truncated to first {})".format(len(indices)))
+plt.ylabel("Dimension size")
+plt.title("Activation tensor dimensions (first {})".format(len(indices)))
+plt.legend()
+plt.tight_layout()
+plt.show()
diff --git a/print_all_shapes.py b/print_all_shapes.py
new file mode 100644
index 000000000..770313084
--- /dev/null
+++ b/print_all_shapes.py
@@ -0,0 +1,19 @@
+from transformer_lens import HookedTransformer
+import torch
+
+model = HookedTransformer.from_pretrained("gpt2-small")
+logits, cache = model.run_with_cache("Hello world")
+
+print("type(logits) =", type(logits))
+if isinstance(logits, torch.Tensor):
+    print("logits:", logits.shape)
+else:
+    print("logits is not a tensor!")
+
+for name, value in cache.items():
+    if isinstance(value, torch.Tensor):
+        print(f"{name:40s} -> {tuple(value.shape)}")
+    else:
+        print(f"{name:40s} -> NON-TENSOR type: {type(value)}")
+
+print(logits.shape)  # type: ignore[attr-defined]
diff --git a/test_tl.py b/test_tl.py
new file mode 100644
index 000000000..7ea11381a
--- /dev/null
+++ b/test_tl.py
@@ -0,0 +1,21 @@
+from transformer_lens import HookedTransformer
+import torch
+from typing import cast
+
+print("Loading model...")
+model = HookedTransformer.from_pretrained("gpt2-small")
+
+print("Running model...")
+logits, activations = model.run_with_cache("Hello World")
+
+# runtime type + repr
+print("type(logits) =", type(logits))
+print("repr(logits)[:200] =", repr(logits)[:200])
+
+# safe checks and printing shape
+if isinstance(logits, torch.Tensor):
+    print("logits.shape (runtime):", logits.shape)
+else:
+    # cast for type-checkers (see next section)
+    logits = cast(torch.Tensor, logits)
+    print("After cast, logits.shape:", getattr(logits, "shape", None))
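
Reviewer note (not part of the patch): lot_activation_sizes.py records at most the
first three dimensions of each cached tensor, so four-dimensional activations such as
attention patterns only have their first three dimensions plotted. Below is a minimal
sketch of one alternative that plots total element counts instead, which works for any
rank. It reuses the transformer_lens calls from the patch plus standard torch/matplotlib
calls (Tensor.numel, plt.bar, plt.savefig); the "blocks.0." substring filter and the
variable names are illustrative choices, not part of this patch.

    from transformer_lens import HookedTransformer
    import torch
    import matplotlib.pyplot as plt

    model = HookedTransformer.from_pretrained("gpt2-small")
    _, cache = model.run_with_cache("Hello world")

    # Keep only layer-0 activations so the x-axis stays readable (illustrative filter).
    names, sizes = [], []
    for name, value in cache.items():
        if isinstance(value, torch.Tensor) and "blocks.0." in name:
            names.append(name)
            sizes.append(value.numel())  # total element count, independent of rank

    plt.figure(figsize=(12, 4))
    plt.bar(range(len(names)), sizes)
    plt.xticks(range(len(names)), names, rotation=90, fontsize=8)
    plt.ylabel("Number of elements")
    plt.tight_layout()
    plt.savefig("activation_sizes.png")  # write to a file; no display needed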