diff --git a/.env.example b/.env.example
index 30865d0..aa1dc70 100644
--- a/.env.example
+++ b/.env.example
@@ -1,2 +1,4 @@
WANDB_API_KEY=your_api_key
-WANDB_ENTITY=your_entity_name
\ No newline at end of file
+WANDB_ENTITY=your_entity_name
+OPENAI_API_KEY=your_api_key # For autointerp
+NEURONPEDIA_API_KEY=your_api_key # For Neuronpedia
\ No newline at end of file
diff --git a/e2e_sae/scripts/__init__.py b/e2e_sae/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py
new file mode 100644
index 0000000..46be270
--- /dev/null
+++ b/e2e_sae/scripts/autointerp.py
@@ -0,0 +1,642 @@
+"""Run and analyze autointerp for SAEs.
+
+NOTE: Running this script currently requires installing ApolloResearch's fork of neuron-explainer.
+https://github.com/ApolloResearch/automated-interpretability. This has been updated to work with
+gpt-4-turbo-2024-04-09 and fixes an OPENAI_API_KEY issue.
+
+This script requires the following environment variables:
+- OPENAI_API_KEY: OpenAI API key.
+- NEURONPEDIA_API_KEY: Neuronpedia API key.
+These can be set in .env in the root of the repository (see .env.example).
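+
+Typical usage (a sketch): run the file directly,
+    python e2e_sae/scripts/autointerp.py
+which executes the `__main__` block at the bottom: explain and score random features for
+the SAEs listed there, then build the violin plots and p-value comparisons.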
+
+"""
+import asyncio
+import glob
+import json
+import os
+import random
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Literal, TypeVar
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import requests
+import seaborn as sns
+import statsmodels.stats.api as sms
+from dotenv import load_dotenv
+from neuron_explainer.activations.activation_records import calculate_max_activation
+from neuron_explainer.activations.activations import ActivationRecord
+from neuron_explainer.explanations.calibrated_simulator import UncalibratedNeuronSimulator
+from neuron_explainer.explanations.explainer import ContextSize, TokenActivationPairExplainer
+from neuron_explainer.explanations.explanations import ScoredSimulation
+from neuron_explainer.explanations.few_shot_examples import FewShotExampleSet
+from neuron_explainer.explanations.prompt_builder import PromptFormat
+from neuron_explainer.explanations.scoring import (
+ _simulate_and_score_sequence,
+ aggregate_scored_sequence_simulations,
+)
+from neuron_explainer.explanations.simulator import (
+ LogprobFreeExplanationTokenSimulator,
+ NeuronSimulator,
+)
+from numpy.typing import ArrayLike
+from scipy.stats import bootstrap
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+from tqdm import tqdm
+
+load_dotenv() # Required before importing UncalibratedNeuronSimulator
+NEURONPEDIA_DOMAIN = "https://neuronpedia.org"
+POSITIVE_INF_REPLACEMENT = 9999
+NEGATIVE_INF_REPLACEMENT = -9999
+NAN_REPLACEMENT = 0
+OTHER_INVALID_REPLACEMENT = -99999
+
+
+def NanAndInfReplacer(value: str):
+ """Replace NaNs and Infs in outputs."""
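+    # json.loads only calls `parse_constant` for the bare literals "Infinity",
+    # "-Infinity" and "NaN", e.g. (illustrative):
+    #   json.loads('{"x": NaN}', parse_constant=NanAndInfReplacer)  # -> {"x": 0.0}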
+ replacements = {
+ "-Infinity": NEGATIVE_INF_REPLACEMENT,
+ "Infinity": POSITIVE_INF_REPLACEMENT,
+ "NaN": NAN_REPLACEMENT,
+ }
+ if value in replacements:
+ replaced_value = replacements[value]
+ return float(replaced_value)
+ else:
+ return NAN_REPLACEMENT
+
+
+def get_neuronpedia_feature(
+ feature: int, layer: int, model: str = "gpt2-small", dataset: str = "res-jb"
+) -> dict[str, Any]:
+ """Fetch a feature from Neuronpedia API."""
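+    # Example (illustrative values): feature=123, layer=6 with the defaults yields the URL
+    #   https://neuronpedia.org/api/feature/gpt2-small/6-res-jb/123
+    # and the JSON response includes fields such as "modelId", "index" and "activations".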
+ url = f"{NEURONPEDIA_DOMAIN}/api/feature/{model}/{layer}-{dataset}/{feature}"
+ result = requests.get(url).json()
+ if "index" in result:
+ result["index"] = int(result["index"])
+ else:
+ raise Exception(f"Feature {model}@{layer}-{dataset}:{feature} does not exist.")
+ return result
+
+
+def test_key(api_key: str):
+ """Test the validity of the Neuronpedia API key."""
+ url = f"{NEURONPEDIA_DOMAIN}/api/test"
+ body = {"apiKey": api_key}
+ response = requests.post(url, json=body)
+ if response.status_code != 200:
+ raise Exception("Neuronpedia API key is not valid.")
+
+
+class NeuronpediaActivation:
+ """Represents an activation from Neuronpedia."""
+
+ def __init__(self, id: str, tokens: list[str], act_values: list[float]):
+ self.id = id
+ self.tokens = tokens
+ self.act_values = act_values
+
+
+class NeuronpediaFeature:
+ """Represents a feature from Neuronpedia."""
+
+ def __init__(
+ self,
+ modelId: str,
+ layer: int,
+ dataset: str,
+ feature: int,
+ description: str = "",
+ activations: list[NeuronpediaActivation] | None = None,
+ autointerp_explanation: str = "",
+ autointerp_explanation_score: float = 0.0,
+ ):
+ self.modelId = modelId
+ self.layer = layer
+ self.dataset = dataset
+ self.feature = feature
+ self.description = description
+ self.activations = activations or []
+ self.autointerp_explanation = autointerp_explanation
+ self.autointerp_explanation_score = autointerp_explanation_score
+
+ def has_activating_text(self) -> bool:
+ """Check if the feature has activating text."""
+ return any(max(activation.act_values) > 0 for activation in self.activations)
+
+
+T = TypeVar("T")
+
+
+@retry(wait=wait_random_exponential(min=1, max=500), stop=stop_after_attempt(10))
+def sleep_identity(x: T) -> T:
+ """Dummy function for retrying."""
+ return x
+
+
+@retry(wait=wait_random_exponential(min=1, max=500), stop=stop_after_attempt(10))
+async def simulate_and_score(
+ simulator: NeuronSimulator, activation_records: list[ActivationRecord]
+) -> ScoredSimulation:
+ """Score an explanation of a neuron by how well it predicts activations on the given text sequences."""
+ scored_sequence_simulations = await asyncio.gather(
+ *[
+ sleep_identity(
+ _simulate_and_score_sequence(
+ simulator,
+ activation_record,
+ )
+ )
+ for activation_record in activation_records
+ ]
+ )
+ return aggregate_scored_sequence_simulations(scored_sequence_simulations)
+
+
+async def autointerp_neuronpedia_features(
+ features: list[NeuronpediaFeature],
+ openai_api_key: str,
+ autointerp_retry_attempts: int = 3,
+ autointerp_score_max_concurrent: int = 20,
+ neuronpedia_api_key: str = "",
+ do_score: bool = True,
+ output_dir: str = "neuronpedia_outputs/autointerp",
+ num_activations_to_use: int = 20,
+ max_explanation_activation_records: int = 20,
+ upload_to_neuronpedia: bool = True,
+ autointerp_explainer_model_name: Literal[
+ "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-turbo-2024-04-09"
+ ] = "gpt-4-1106-preview",
+ autointerp_scorer_model_name: Literal[
+ "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-turbo-2024-04-09"
+ ] = "gpt-3.5-turbo",
+):
+ """
+ Autointerp Neuronpedia features.
+
+ Args:
+ features: List of NeuronpediaFeature objects.
+ openai_api_key: OpenAI API key.
+ autointerp_retry_attempts: Number of retry attempts for autointerp.
+ autointerp_score_max_concurrent: Maximum number of concurrent requests for autointerp scoring.
+ neuronpedia_api_key: Neuronpedia API key.
+ do_score: Whether to score the features.
+ output_dir: Output directory for saving the results.
+ num_activations_to_use: Number of activations to use.
+ max_explanation_activation_records: Maximum number of activation records for explanation.
+ upload_to_neuronpedia: Whether to upload the results to Neuronpedia.
+ autointerp_explainer_model_name: Model name for autointerp explainer.
+ autointerp_scorer_model_name: Model name for autointerp scorer.
+
+ Returns:
+ None
+ """
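+    # Driven by `run_autointerp` below via asyncio.run(); a standalone call would look
+    # roughly like (illustrative arguments):
+    #   asyncio.run(autointerp_neuronpedia_features(
+    #       features=[NeuronpediaFeature("gpt2-small", 6, "res_scl-ajt", 123)],
+    #       openai_api_key=os.environ["OPENAI_API_KEY"],
+    #       neuronpedia_api_key=os.environ["NEURONPEDIA_API_KEY"]))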
+ print("\n\n")
+
+ os.environ["OPENAI_API_KEY"] = openai_api_key
+
+ if upload_to_neuronpedia and not neuronpedia_api_key:
+ raise Exception(
+ "You need to provide a Neuronpedia API key to upload the results to Neuronpedia."
+ )
+
+ test_key(neuronpedia_api_key)
+
+ print("\n\n=== Step 1) Fetching features from Neuronpedia")
+ for feature in features:
+ feature_data = get_neuronpedia_feature(
+ feature=feature.feature,
+ layer=feature.layer,
+ model=feature.modelId,
+ dataset=feature.dataset,
+ )
+
+ if "modelId" not in feature_data:
+ raise Exception(
+ f"Feature {feature.feature} in layer {feature.layer} of model {feature.modelId} and dataset {feature.dataset} does not exist."
+ )
+
+ if "activations" not in feature_data or len(feature_data["activations"]) == 0:
+ raise Exception(
+ f"Feature {feature.feature} in layer {feature.layer} of model {feature.modelId} and dataset {feature.dataset} does not have activations."
+ )
+
+ activations = feature_data["activations"]
+ activations_to_add = []
+ for activation in activations:
+ if len(activations_to_add) < num_activations_to_use:
+ activations_to_add.append(
+ NeuronpediaActivation(
+ id=activation["id"],
+ tokens=activation["tokens"],
+ act_values=activation["values"],
+ )
+ )
+ feature.activations = activations_to_add
+
+ if not feature.has_activating_text():
+ raise Exception(
+ f"Feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature} appears dead - it does not have activating text."
+ )
+
+ for iteration_num, feature in enumerate(features):
+ start_time = datetime.now()
+
+ print(
+ f"\n========== Feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature} ({iteration_num + 1} of {len(features)} Features) =========="
+ )
+ print(
+ f"\n=== Step 2) Explaining feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}"
+ )
+
+ activation_records = [
+ ActivationRecord(tokens=activation.tokens, activations=activation.act_values)
+ for activation in feature.activations
+ ]
+
+ activation_records_explaining = activation_records[:max_explanation_activation_records]
+
+ explainer = TokenActivationPairExplainer(
+ model_name=autointerp_explainer_model_name,
+ prompt_format=PromptFormat.HARMONY_V4,
+ context_size=ContextSize.SIXTEEN_K,
+ max_concurrent=1,
+ )
+
+ explanations = []
+ for _ in range(autointerp_retry_attempts):
+ try:
+ explanations = await explainer.generate_explanations(
+ all_activation_records=activation_records_explaining,
+ max_activation=calculate_max_activation(activation_records_explaining),
+ num_samples=1,
+ )
+ except Exception as e:
+ print(f"ERROR, RETRYING: {e}")
+ else:
+ break
+ else:
+ print(
+ f"ERROR: Failed to explain feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}"
+ )
+
+ assert len(explanations) == 1
+ explanation = explanations[0].rstrip(".")
+ print(f"===== {autointerp_explainer_model_name}'s explanation: {explanation}")
+ feature.autointerp_explanation = explanation
+
+ if do_score:
+ print(
+ f"\n=== Step 3) Scoring feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}"
+ )
+ print("=== This can take up to 30 seconds.")
+
+ temp_activation_records = [
+ ActivationRecord(
+ tokens=[
+ token.replace("<|endoftext|>", "<|not_endoftext|>")
+ .replace(" 55", "_55")
+ .encode("ascii", errors="backslashreplace")
+ .decode("ascii")
+ for token in activation_record.tokens
+ ],
+ activations=activation_record.activations,
+ )
+ for activation_record in activation_records
+ ]
+
+ score = None
+ scored_simulation = None
+ for _ in range(autointerp_retry_attempts):
+ try:
+ simulator = UncalibratedNeuronSimulator(
+ LogprobFreeExplanationTokenSimulator(
+ autointerp_scorer_model_name,
+ explanation,
+ json_mode=True,
+ max_concurrent=autointerp_score_max_concurrent,
+ few_shot_example_set=FewShotExampleSet.JL_FINE_TUNED,
+ prompt_format=PromptFormat.HARMONY_V4,
+ )
+ )
+ scored_simulation = await simulate_and_score(simulator, temp_activation_records)
+ score = scored_simulation.get_preferred_score()
+ except Exception as e:
+ print(f"ERROR, RETRYING: {e}")
+ else:
+ break
+
+ if (
+ score is None
+ or scored_simulation is None
+ or len(scored_simulation.scored_sequence_simulations) != num_activations_to_use
+ ):
+ print(
+ f"ERROR: Failed to score feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}. Skipping it."
+ )
+ continue
+ feature.autointerp_explanation_score = score
+ print(f"===== {autointerp_scorer_model_name}'s score: {(score * 100):.0f}")
+
+ output_file = (
+ f"{output_dir}/{feature.layer}-{feature.dataset}_feature-{feature.feature}_score-"
+ f"{feature.autointerp_explanation_score}_time-{datetime.now().strftime('%Y%m%d-%H%M%S')}.jsonl"
+ )
+ os.makedirs(output_dir, exist_ok=True)
+ print(f"===== Your results will be saved to: {output_file} =====")
+
+ output_data = json.dumps(
+ {
+ "apiKey": neuronpedia_api_key,
+ "feature": {
+ "modelId": feature.modelId,
+ "layer": f"{feature.layer}-{feature.dataset}",
+ "index": feature.feature,
+ "activations": feature.activations,
+ "explanation": feature.autointerp_explanation,
+ "explanationScore": feature.autointerp_explanation_score,
+ "explanationModel": autointerp_explainer_model_name,
+ "autointerpModel": autointerp_scorer_model_name,
+ "simulatedActivations": scored_simulation.scored_sequence_simulations,
+ },
+ },
+ default=vars,
+ )
+ output_data_json = json.loads(output_data, parse_constant=NanAndInfReplacer)
+ output_data_str = json.dumps(output_data)
+
+ print(f"\n=== Step 4) Saving feature to {output_file}")
+ with open(output_file, "a") as f:
+ f.write(output_data_str)
+ f.write("\n")
+
+ if upload_to_neuronpedia:
+ print(
+ f"\n=== Step 5) Uploading feature to Neuronpedia: {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}"
+ )
+ url = f"{NEURONPEDIA_DOMAIN}/api/upload-explanation"
+ body = output_data_json
+ response = requests.post(url, json=body)
+ if response.status_code != 200:
+ print(f"ERROR: Couldn't upload explanation to Neuronpedia: {response.text}")
+
+ end_time = datetime.now()
+ print(f"\n========== Time Spent for Feature: {end_time - start_time}\n")
+
+ print("\n\n========== Generation and Upload Complete ==========\n\n")
+
+
+def run_autointerp(
+ saes: list[str],
+ n_random_features: int,
+ dict_size: int,
+ feature_model_id: str,
+ autointerp_explainer_model_name: Literal[
+ "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-turbo-2024-04-09"
+ ],
+ autointerp_scorer_model_name: Literal[
+ "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-turbo-2024-04-09"
+ ],
+ out_dir: str | Path = "neuronpedia_outputs/autointerp",
+):
+ """Explain and score random features across SAEs and upload to Neuronpedia.
+
+ Args:
+        saes: List of SAE identifiers of the form "<layer>-<dataset>", e.g. "2-res_scl-ajt".
+ n_random_features: Number of random features to explain and score for each SAE.
+ dict_size: Size of the SAE dictionary.
+ feature_model_id: Model ID that the SAE is attached to.
+ autointerp_explainer_model_name: Model name for autointerp explainer.
+        autointerp_scorer_model_name: Model name for autointerp scorer (much more expensive).
+        out_dir: Directory in which the per-feature autointerp results are saved.
+ """
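+    # SAE identifiers take the form "<layer>-<dataset>", e.g. "2-res_scl-ajt" is parsed
+    # below into layer=2 and dataset="res_scl-ajt".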
+ for i in tqdm(range(n_random_features), desc="random features", total=n_random_features):
+ for sae in tqdm(saes, desc="sae", total=len(saes)):
+ layer = int(sae.split("-")[0])
+ dataset = "-".join(sae.split("-")[1:])
+ feature_exists = False
+ while not feature_exists:
+                feature = random.randint(0, dict_size - 1)
+ feature_data = get_neuronpedia_feature(
+ feature=feature,
+ layer=layer,
+ model=feature_model_id,
+ dataset=dataset,
+ )
+ if "activations" in feature_data and len(feature_data["activations"]) >= 20:
+ feature_exists = True
+ print(f"sae: {sae}, feature: {feature}")
+
+ features = [
+ NeuronpediaFeature(
+ modelId=feature_model_id,
+ layer=layer,
+ dataset=dataset,
+ feature=feature,
+ ),
+ ]
+
+ asyncio.run(
+ autointerp_neuronpedia_features(
+ features=features,
+ openai_api_key=os.getenv("OPENAI_API_KEY", ""),
+ neuronpedia_api_key=os.getenv("NEURONPEDIA_API_KEY", ""),
+ autointerp_explainer_model_name=autointerp_explainer_model_name,
+ autointerp_scorer_model_name=autointerp_scorer_model_name,
+ num_activations_to_use=20,
+ max_explanation_activation_records=5,
+ output_dir=str(out_dir),
+ )
+ )
+
+ feature_url = (
+ f"https://neuronpedia.org/{feature_model_id}/{layer}-{dataset}/{feature}"
+ )
+ print(f"Your feature is at: {feature_url}")
+
+
+def get_autointerp_results_df(out_dir: Path):
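+    """Collect the per-feature autointerp JSONL outputs under `out_dir` into a DataFrame."""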
+ autointerp_files = glob.glob(f"{out_dir}/**/*_feature-*_score-*", recursive=True)
+ stats = {
+ "layer": [],
+ "sae": [],
+ "feature": [],
+ "explanationScore": [],
+ "explanationModel": [],
+ "autointerpModel": [],
+ "explanation": [],
+ "sae_type": [],
+ }
+ for autointerp_file in tqdm(autointerp_files, "constructing stats from json autointerp files"):
+ with open(autointerp_file) as f:
+ json_data = json.loads(json.load(f))
+ if "feature" in json_data:
+ json_data = json_data["feature"]
+ stats["layer"].append(int(json_data["layer"].split("-")[0]))
+ stats["sae"].append("-".join(json_data["layer"].split("-")[1:]))
+ stats["feature"].append(json_data["index"])
+ stats["explanationScore"].append(json_data["explanationScore"])
+ stats["autointerpModel"].append(json_data["autointerpModel"])
+            stats["explanationModel"].append(
+                json_data.get("explanationModel", "gpt-4-1106-preview")
+            )
+ stats["explanation"].append(json_data["explanation"])
+ stats["sae_type"].append(
+ {"e": "e2e", "l": "local", "r": "downstream"}[json_data["layer"].split("-")[1][-1]]
+ )
+ df_stats = pd.DataFrame(stats)
+ assert not df_stats.empty
+ df_stats.dropna(inplace=True)
+ return df_stats
+
+
+def pair_violin_plot(df_stats: pd.DataFrame, pairs: dict[int, dict[str, str]], out_file: Path):
+ fig, axs = plt.subplots(1, 3, sharey=True, figsize=(7, 3.5))
+
+ for ax, layer in zip(axs, pairs.keys(), strict=True):
+ sae_ids = [pairs[layer]["local"], pairs[layer]["downstream"]]
+ layer_data = df_stats[df_stats["layer"] == layer]
+ grouped_scores = layer_data.groupby("sae")["explanationScore"]
+ mean_explanationScores = [grouped_scores.get_group(sae_id).mean() for sae_id in sae_ids]
+ confidence_intervals = [
+ sms.DescrStatsW(grouped_scores.get_group(sae_id)).tconfint_mean() for sae_id in sae_ids
+ ]
+
+ data = layer_data.loc[layer_data.sae.isin(sae_ids)]
+ colors = {
+ "local": "#f0a70a",
+ "e2e": "#518c31",
+ "downstream": plt.get_cmap("tab20b").colors[2], # type: ignore[reportAttributeAccessIssue]
+ }
+        sns.violinplot(
+ data=data,
+ hue="sae_type",
+ x=[1 if sae_type == "downstream" else 0 for sae_type in data.sae_type],
+ y="explanationScore",
+ palette=colors,
+ native_scale=True,
+ orient="v",
+ ax=ax,
+ inner=None,
+ cut=0,
+ bw_adjust=0.7,
+ hue_order=["local", "downstream"],
+ alpha=0.8,
+ legend=False,
+ )
+ ax.set_xticks([0, 1], ["Local", "Downstream"])
+ ax.errorbar(
+ x=(0, 1),
+ y=mean_explanationScores,
+ yerr=[(c[1] - c[0]) / 2 for c in confidence_intervals],
+ fmt="o",
+ color="black",
+ capsize=5,
+ )
+ ax.set_ylim((None, 1))
+ ax.set_title(label=f"Layer {layer}")
+
+    axs[0].set_ylabel("Auto-interpretability score")
+
+ plt.tight_layout()
+ plt.savefig(out_file, bbox_inches="tight")
+ print(f"Saved to {out_file}")
+
+
+def bootstrap_p_value(sample_a: ArrayLike, sample_b: ArrayLike) -> float:
+ """
+    Computes a two-sided bootstrap p-value for the null hypothesis that the two
+    sample means are equal.
+ """
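+    # Illustrative check: bootstrap_p_value(np.zeros(100), np.ones(100)) should give ~0,
+    # while two identical samples give a value near 1.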
+ sample_a, sample_b = np.asarray(sample_a), np.asarray(sample_b)
+ mean_diff = np.mean(sample_a) - np.mean(sample_b)
+ n_bootstraps = 100_000
+ bootstrapped_diffs = []
+ combined = np.concatenate([sample_a, sample_b])
+ for _ in range(n_bootstraps):
+ boot_a = np.random.choice(combined, len(sample_a), replace=True)
+ boot_b = np.random.choice(combined, len(sample_b), replace=True)
+ boot_diff = np.mean(boot_a) - np.mean(boot_b)
+ bootstrapped_diffs.append(boot_diff)
+ bootstrapped_diffs = np.array(bootstrapped_diffs)
+ p_value = np.mean(np.abs(bootstrapped_diffs) > np.abs(mean_diff))
+ return p_value
+
+
+def bootstrap_mean_diff(sample_a: ArrayLike, sample_b: ArrayLike) -> tuple[float, Any]:
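+    """Bootstrap the difference in means; returns (observed diff, confidence interval)."""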
+ def mean_diff(resample_a, resample_b): # type: ignore
+ return np.mean(resample_a) - np.mean(resample_b)
+
+ result = bootstrap((sample_a, sample_b), n_resamples=100_000, statistic=mean_diff)
+ diff: float = np.mean(sample_a) - np.mean(sample_b) # type: ignore
+ return diff, result.confidence_interval
+
+
+def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]):
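+    """Print, for each layer, the downstream-vs-local mean difference (with bootstrap CI)
+    and bootstrap p-value, formatted as LaTeX table rows."""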
+ score_groups = df.groupby(["layer", "sae"]).explanationScore
+ for layer, sae_dict in pairs.items():
+ sae_local, sae_downstream = sae_dict["local"], sae_dict["downstream"]
+ pval = bootstrap_p_value(
+ score_groups.get_group((layer, sae_downstream)).to_numpy(),
+ score_groups.get_group((layer, sae_local)).to_numpy(),
+ )
+ # print(f"L{layer}, {sae_downstream} vs {sae_local}: p={pval}")
+ diff, ci = bootstrap_mean_diff(
+ score_groups.get_group((layer, sae_downstream)).to_numpy(),
+ score_groups.get_group((layer, sae_local)).to_numpy(),
+ )
+ print(f"{layer}&${diff:.2f}\\ [{ci.low:.2f},{ci.high:.2f}]$&{pval:.2g}")
+
+
+if __name__ == "__main__":
+ plot_out_dir = Path(__file__).parent / "out/autointerp/"
+ score_out_dir = Path(__file__).parent / "out/autointerp/"
+ score_out_dir.mkdir(parents=True, exist_ok=True)
+ ## Running autointerp
+    # Get runs with similar CE loss and similar L0 for e2e+downstream and local SAEs
+    # Note that "10-res_slefr-ajt" does not exist; we use 10-res_scefr-ajt for the similar-L0 comparison too
+ saes = [
+ "2-res_scefr-ajt",
+ "2-res_slefr-ajt",
+ "2-res_sll-ajt",
+ "2-res_scl-ajt",
+ "6-res_scefr-ajt",
+ "6-res_slefr-ajt",
+ "6-res_sll-ajt",
+ "6-res_scl-ajt",
+ "10-res_scefr-ajt",
+ "10-res_sll-ajt",
+ "10-res_scl-ajt",
+ ]
+ run_autointerp(
+ saes=saes,
+ n_random_features=150,
+ dict_size=768 * 60,
+ feature_model_id="gpt2-small",
+ autointerp_explainer_model_name="gpt-4-turbo-2024-04-09",
+ autointerp_scorer_model_name="gpt-3.5-turbo",
+ out_dir=score_out_dir,
+ )
+
+ df = get_autointerp_results_df(score_out_dir)
+
+ ## Analysis of autointerp results
+
+ const_l0_pairs = {
+ 2: {"local": "res_sll-ajt", "downstream": "res_slefr-ajt"},
+ 6: {"local": "res_sll-ajt", "downstream": "res_slefr-ajt"},
+ 10: {"local": "res_sll-ajt", "downstream": "res_scefr-ajt"},
+ }
+
+ const_ce_pairs = {
+ 2: {"local": "res_scl-ajt", "downstream": "res_scefr-ajt"},
+ 6: {"local": "res_scl-ajt", "downstream": "res_scefr-ajt"},
+ 10: {"local": "res_scl-ajt", "downstream": "res_scefr-ajt"},
+ }
+
+ # compare_autointerp_results(df)
+ # compare_across_saes(df)
+ pair_violin_plot(df, const_l0_pairs, plot_out_dir / "l0_violin.png")
+ pair_violin_plot(df, const_ce_pairs, plot_out_dir / "ce_violin.png")
+ print("SAME L0")
+ compute_p_values(df, const_l0_pairs)
+ print("\nSAME CE")
+ compute_p_values(df, const_ce_pairs)
diff --git a/e2e_sae/scripts/generate_dashboards.py b/e2e_sae/scripts/generate_dashboards.py
deleted file mode 100644
index 14b95b3..0000000
--- a/e2e_sae/scripts/generate_dashboards.py
+++ /dev/null
@@ -1,1214 +0,0 @@
-"""Script for generating HTML feature dashboards
-Usage:
- $ python generate_dashboards.py
- (Generates dashboards for the SAEs in )
- or
- $ python generate_dashboards.py
- (Requires that a path to the sae.pt file is provided in dashboards_config.pretrained_sae_paths)
-
-dashboard HTML files be saved in dashboards_config.save_dir
-
-Two types of dashboards can be created:
- feature-centric:
- These are individual dashboards for each feature specified in
- dashboards_config.feature_indices, showing that feature's max activating examples, facts
- about the distribution of when it is active, and what it promotes through the logit lens.
- feature-centric dashboards will also be generated for all of the top features which apppear
- in the prompt-centric dashboards. Saved in dashboards_config.save_dir/dashboards_{sae_name}
- prompt-centric:
- Given a prompt and a specific token position within it, find the most important features
- active at that position, and make a dashboard showing where they all activate. There are
- three ways of measuring the importance of features: "act_size" (show the features which
- activated most strongly), "act_quantile" (show the features which activated much more than
- they usually do), and "loss_effect" (show the features with the biggest logit-lens ablation
- effect for predicting the correct next token - default).
- Saved in dashboards_config.save_dir/prompt_dashboards
-
-This script currently relies on an old commit of Callum McDouglal's sae_vis package:
-https://github.com/callummcdougall/sae_vis/commit/b28a0f7c7e936f4bea05528d952dfcd438533cce
-"""
-import math
-from collections.abc import Iterable
-from pathlib import Path
-from typing import Annotated, Literal
-
-import fire
-import numpy as np
-import torch
-from eindex import eindex
-from einops import einsum, rearrange
-from jaxtyping import Float, Int
-from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, NonNegativeInt, PositiveInt
-from sae_vis.data_fetching_fns import get_sequences_data
-from sae_vis.data_storing_fns import (
- FeatureData,
- FeatureVisParams,
- HistogramData,
- MiddlePlotsData,
- MultiFeatureData,
- MultiPromptData,
- PromptData,
- SequenceData,
- SequenceMultiGroupData,
-)
-from sae_vis.utils_fns import QuantileCalculator, TopK, process_str_tok
-from torch import Tensor
-from torch.utils.data.dataset import IterableDataset
-from tqdm import tqdm
-from transformers import (
- AutoTokenizer,
- PreTrainedTokenizer,
- PreTrainedTokenizerBase,
- PreTrainedTokenizerFast,
-)
-
-from e2e_sae.data import DatasetConfig, create_data_loader
-from e2e_sae.loader import load_pretrained_saes, load_tlens_model
-from e2e_sae.log import logger
-from e2e_sae.models.transformers import SAETransformer
-from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import Config
-from e2e_sae.types import RootPath
-from e2e_sae.utils import filter_names, load_config, to_numpy
-
-FeatureIndicesType = dict[str, list[int]] | dict[str, Int[Tensor, "some_feats"]] # noqa: F821 (jaxtyping/pyright doesn't like single dimensions)
-StrScoreType = Literal["act_size", "act_quantile", "loss_effect"]
-
-
-class PromptDashboardsConfig(BaseModel):
- model_config = ConfigDict(extra="forbid", frozen=True)
- n_random_prompt_dashboards: NonNegativeInt = Field(
- default=50,
- description="The number of random prompts to generate prompt-centric dashboards for."
- "A feature-centric dashboard will be generated for random token positions in each prompt.",
- )
- data: DatasetConfig | None = Field(
- default=None,
- description="DatasetConfig for getting random prompts."
- "If None, then DashboardsConfig.data will be used",
- )
- prompts: list[str] | None = Field(
- default=None,
- description="Specific prompts on which to generate prompt-centric feature dashboards. "
- "A feature-centric dashboard will be generated for every token position in each prompt.",
- )
- str_score: StrScoreType = Field(
- default="loss_effect",
- description="The ordering metric for which features are most important in prompt-centric "
- "dashboards. Can be one of 'act_size', 'act_quantile', or 'loss_effect'",
- )
- num_top_features: PositiveInt = Field(
- default=10,
- description="How many of the most relevant features to show for each prompt"
- " in the prompt-centric dashboards",
- )
-
-
-class DashboardsConfig(BaseModel):
- model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True, frozen=True)
- pretrained_sae_paths: Annotated[
- list[RootPath] | None, BeforeValidator(lambda x: [x] if isinstance(x, str | Path) else x)
- ] = Field(None, description="Paths of the pretrained SAEs to load")
- sae_config_path: RootPath | None = Field(
- default=None,
- description="Path to the config file used to train the SAEs"
- " (if null, we'll assume it's at pretrained_sae_paths[0].parent / 'config.yaml')",
- )
- n_samples: PositiveInt | None = None
- batch_size: PositiveInt
- minibatch_size_features: PositiveInt | None = Field(
- default=256,
- description="Num features in each batch of calculations (i.e. we break up the features to "
- "avoid OOM errors).",
- )
- data: DatasetConfig = Field(
- description="DatasetConfig for the data which will be used to generate the dashboards",
- )
- save_dir: RootPath | None = Field(
- default=None,
- description="The directory for saving the HTML feature dashboard files",
- )
- sae_positions: Annotated[
- list[str] | None, BeforeValidator(lambda x: [x] if isinstance(x, str) else x)
- ] = Field(
- None,
- description="The names of the SAE positions to generate dashboards for. "
- "e.g. 'blocks.2.hook_resid_post'. If None, then all positions will be generated",
- )
- feature_indices: FeatureIndicesType | list[int] | None = Field(
- default=None,
- description="The features for which to generate dashboards on each SAE. If none, then "
- "we'll generate dashbaords for every feature.",
- )
- prompt_centric: PromptDashboardsConfig | None = Field(
- default=None,
- description="Used to generate prompt-centric (rather than feature-centric) dashboards."
- " Feature-centric dashboards will also be generated for every feature appaearing in these",
- )
- seed: NonNegativeInt = 0
-
-
-def compute_feature_acts(
- model: SAETransformer,
- tokens: Int[Tensor, "batch pos"],
- raw_sae_positions: list[str] | None = None,
- feature_indices: FeatureIndicesType | None = None,
- stop_at_layer: int = -1,
-) -> tuple[dict[str, Float[Tensor, "... some_feats"]], Float[Tensor, "... dim"]]:
- """Compute the activations of the SAEs in the model given a tensor of input tokens
-
- Args:
- model: The SAETransformer containing the SAEs and the tlens_model
- tokens: The inputs to the tlens_model
- raw_sae_positions: The names of the SAEs we're interested in
- feature_indices: The indices of the features we're interested in for each SAE
- stop_at_layer: Where to stop the forward pass. final_resid_acts will be returned from here
-
- Returns:
- - A dict of feature activations for each SAE.
- feature_acts[sae_position_name] = the feature activations of that SAE
- shape: batch pos some_feats
- - The residual stream activations of the model at the final layer (or at stop_at_layer)
- """
- if raw_sae_positions is None:
- raw_sae_positions = model.raw_sae_positions
- # Run model without SAEs
- final_resid_acts, orig_acts = model.tlens_model.run_with_cache(
- tokens,
- names_filter=raw_sae_positions,
- return_cache_object=False,
- stop_at_layer=stop_at_layer,
- )
- assert isinstance(final_resid_acts, Tensor)
- feature_acts: dict[str, Float[Tensor, "... some_feats"]] = {}
- # Run the activations through the SAEs
- for hook_name in orig_acts:
- sae = model.saes[hook_name.replace(".", "-")]
- output, feature_acts[hook_name] = sae(orig_acts[hook_name])
- del output
- feature_acts[hook_name] = feature_acts[hook_name].to("cpu")
- if feature_indices is not None:
- feature_acts[hook_name] = feature_acts[hook_name][..., feature_indices[hook_name]]
- return feature_acts, final_resid_acts
-
-
-def compute_feature_acts_on_distribution(
- model: SAETransformer,
- dataset_config: DatasetConfig,
- batch_size: PositiveInt,
- n_samples: PositiveInt | None = None,
- raw_sae_positions: list[str] | None = None,
- feature_indices: FeatureIndicesType | None = None,
- stop_at_layer: int = -1,
-) -> tuple[
- dict[str, Float[Tensor, "... some_feats"]], Float[Tensor, "... d_resid"], Int[Tensor, "..."]
-]:
- """Compute the activations of the SAEs in the model on a dataset of input tokens
-
- Args:
- model: The SAETransformer containing the SAEs and the tlens_model
- dataset_config: The DatasetConfig used to get the data loader for the tokens.
- batch_size: The batch size of data run through the model when calculating the feature acts
- n_samples: The number of batches of data to use for calculating the feature dashboard data
- raw_sae_positions: The names of the SAEs we're interested in. If none, do all SAEs.
- feature_indices: The indices of the features we're interested in for each SAE. If none, do
- all features.
- stop_at_layer: Where to stop the forward pass. final_resid_acts will be returned from here
-
- Returns:
- - a dict of SAE inputs, activations, and outputs for each SAE.
- feature_acts[sae_position_name] = the feature activations of that SAE
- shape: batch pos feats (or # feature_indices)
- - The residual stream activations of the model at the final layer (or at stop_at_layer)
- - The tokens used as input to the model
- """
- data_loader, _ = create_data_loader(
- dataset_config, batch_size=batch_size, buffer_size=batch_size
- )
- if raw_sae_positions is None:
- raw_sae_positions = model.raw_sae_positions
- assert raw_sae_positions is not None
- device = model.saes[raw_sae_positions[0].replace(".", "-")].device
- if n_samples is None:
- # If streaming (i.e. if the dataset is an IterableDataset), we don't know the length
- n_batches = None if isinstance(data_loader.dataset, IterableDataset) else len(data_loader)
- else:
- n_batches = math.ceil(n_samples / batch_size)
- if not isinstance(data_loader.dataset, IterableDataset):
- n_batches = min(n_batches, len(data_loader))
-
- total_samples = 0
- feature_acts_lists: dict[str, list[Float[Tensor, "... some_feats"]]] = {
- sae_name: [] for sae_name in raw_sae_positions
- }
- final_resid_acts_list: list[Float[Tensor, "... d_resid"]] = []
- tokens_list: list[Int[Tensor, "..."]] = []
- for batch in tqdm(data_loader, total=n_batches, desc="Computing feature acts"):
- batch_tokens: Int[Tensor, "..."] = batch[dataset_config.column_name].to(device=device)
- batch_feature_acts, batch_final_resid_acts = compute_feature_acts(
- model=model,
- tokens=batch_tokens,
- raw_sae_positions=raw_sae_positions,
- feature_indices=feature_indices,
- stop_at_layer=stop_at_layer,
- )
- for sae_name in raw_sae_positions:
- feature_acts_lists[sae_name].append(batch_feature_acts[sae_name])
- final_resid_acts_list.append(batch_final_resid_acts)
- tokens_list.append(batch_tokens)
- total_samples += batch_tokens.shape[0]
- if n_samples is not None and total_samples > n_samples:
- break
- final_resid_acts: Float[Tensor, "... d_resid"] = torch.cat(final_resid_acts_list, dim=0)
- tokens: Int[Tensor, "..."] = torch.cat(tokens_list, dim=0)
- feature_acts: dict[str, Float[Tensor, "... some_feats"]] = {}
- for sae_name in raw_sae_positions:
- feature_acts[sae_name] = torch.cat(tensors=feature_acts_lists[sae_name], dim=0)
- return feature_acts, final_resid_acts, tokens
-
-
-def create_vocab_dict(tokenizer: PreTrainedTokenizerBase) -> dict[int, str]:
- """
- Creates a vocab dict suitable for dashboards by replacing all the special tokens with their
- HTML representations. This function is adapted from sae_vis.create_vocab_dict()
- """
- vocab_dict: dict[str, int] = tokenizer.get_vocab()
- vocab_dict_processed: dict[int, str] = {v: process_str_tok(k) for k, v in vocab_dict.items()}
- return vocab_dict_processed
-
-
-@torch.inference_mode()
-def parse_activation_data(
- tokens: Int[Tensor, "batch pos"],
- feature_acts: Float[Tensor, "... some_feats"],
- final_resid_acts: Float[Tensor, "... d_resid"],
- feature_resid_dirs: Float[Tensor, "some_feats dim"],
- feature_indices_list: Iterable[int],
- W_U: Float[Tensor, "dim d_vocab"],
- vocab_dict: dict[int, str],
- fvp: FeatureVisParams,
-) -> MultiFeatureData:
- """Convert generic activation data into a MultiFeatureData object, which can be used to create
- the feature-centric visualisation.
- Adapted from sae_vis.data_fetching_fns._get_feature_data()
- final_resid_acts + W_U are used for the logit lens.
-
- Args:
- tokens: The inputs to the model
- feature_acts: The activations values of the features
- final_resid_acts: The activations of the final layer of the model
- feature_resid_dirs: The directions that each feature writes to the logit output
- feature_indices_list: The indices of the features we're interested in
- W_U: The unembed weights for the logit lens
- vocab_dict: A dictionary mapping vocab indices to strings
- fvp: FeatureVisParams, containing a bunch of settings. See the FeatureVisParams docstring in
- sae_vis for more information.
-
- Returns:
- A MultiFeatureData containing data for creating each feature's visualization,
- as well as data for rank-ordering the feature visualizations when it comes time
- to make the prompt-centric view (the `feature_act_quantiles` attribute).
- Use MultiFeatureData[feature_idx].get_html() to generate the HTML dashboard for a
- particular feature (returns a string of HTML).
-
- """
- device = W_U.device
- feature_acts.to(device)
- sequence_data_dict: dict[int, SequenceMultiGroupData] = {}
- middle_plots_data_dict: dict[int, MiddlePlotsData] = {}
- features_data: dict[int, FeatureData] = {}
- # Calculate all data for the right-hand visualisations, i.e. the sequences
- for i, feat in enumerate(feature_indices_list):
- # Add this feature's sequence data to the list
- sequence_data_dict[feat] = get_sequences_data(
- tokens=tokens,
- feat_acts=feature_acts[..., i],
- resid_post=final_resid_acts,
- feature_resid_dir=feature_resid_dirs[i],
- W_U=W_U,
- fvp=fvp,
- )
-
- # Get the logits of all features (i.e. the directions this feature writes to the logit output)
- logits = einsum(
- feature_resid_dirs,
- W_U,
- "feats d_model, d_model d_vocab -> feats d_vocab",
- )
- for i, (feat, logit) in enumerate(zip(feature_indices_list, logits, strict=True)):
- # Get data for logits (the histogram, and the table)
- logits_histogram_data = HistogramData(logit, 40, "5 ticks")
- top10_logits = TopK(logit, k=15, largest=True)
- bottom10_logits = TopK(logit, k=15, largest=False)
-
- # Get data for feature activations histogram (the title, and the histogram)
- feat_acts = feature_acts[..., i]
- nonzero_feat_acts = feat_acts[feat_acts > 0]
- frac_nonzero = nonzero_feat_acts.numel() / feat_acts.numel()
- freq_histogram_data = HistogramData(nonzero_feat_acts, 40, "ints")
-
- # Create a MiddlePlotsData object from this, and add it to the dict
- middle_plots_data_dict[feat] = MiddlePlotsData(
- bottom10_logits=bottom10_logits,
- top10_logits=top10_logits,
- logits_histogram_data=logits_histogram_data,
- freq_histogram_data=freq_histogram_data,
- frac_nonzero=frac_nonzero,
- )
-
- # Return the output, as a dict of FeatureData items
- for i, feat in enumerate(feature_indices_list):
- features_data[feat] = FeatureData(
- # Data-containing inputs (for the feature-centric visualisation)
- sequence_data=sequence_data_dict[feat],
- middle_plots_data=middle_plots_data_dict[feat],
- left_tables_data=None,
- # Non data-containing inputs
- feature_idx=feat,
- vocab_dict=vocab_dict,
- fvp=fvp,
- )
-
- # Also get the quantiles, which will be useful for the prompt-centric visualisation
-
- feature_act_quantiles = QuantileCalculator(
- data=rearrange(feature_acts, "... feats -> feats (...)")
- )
- return MultiFeatureData(features_data, feature_act_quantiles)
-
-
-def feature_indices_to_tensordict(
- feature_indices_in: FeatureIndicesType | list[int] | None,
- raw_sae_positions: list[str],
- model: SAETransformer,
-) -> dict[str, Tensor]:
- """ "Convert feature indices to a dict of tensor indices"""
- if feature_indices_in is None:
- feature_indices = {}
- for sae_name in raw_sae_positions:
- feature_indices[sae_name] = torch.arange(
- end=model.saes[sae_name.replace(".", "-")].n_dict_components
- )
- # Otherwise make sure that feature_indices is a dict of Int[Tensor]
- elif not isinstance(feature_indices_in, dict):
- feature_indices = {
- sae_name: Tensor(feature_indices_in).to("cpu").to(torch.int)
- for sae_name in raw_sae_positions
- }
- else:
- feature_indices: dict[str, Tensor] = {
- sae_name: Tensor(feature_indices_in[sae_name]).to("cpu").to(torch.int)
- for sae_name in raw_sae_positions
- }
- return feature_indices
-
-
-@torch.inference_mode()
-def get_dashboards_data(
- model: SAETransformer,
- dataset_config: DatasetConfig | None = None,
- tokens: Int[Tensor, "batch pos"] | None = None,
- sae_positions: list[str] | None = None,
- feature_indices: FeatureIndicesType | list[int] | None = None,
- n_samples: PositiveInt | None = None,
- batch_size: PositiveInt | None = None,
- minibatch_size_features: PositiveInt | None = None,
- fvp: FeatureVisParams | None = None,
- vocab_dict: dict[int, str] | None = None,
-) -> dict[str, MultiFeatureData]:
- """Gets data that needed to create the sequences in the feature-centric HTML visualisation
-
- Adapted from sae_vis.data_fetching_fns._get_feature_data()
-
- Args:
- model:
- The model (with SAEs) we'll be using to get the feature activations.
- dataset_config: [Only used if tokens is None]
- The DatasetConfig which will be used to get the data loader. If None, then tokens must
- be supplied.
- tokens:
- The tokens we'll be using to get the feature activations. If None, then we'll use the
- distribution from the dataset_config.
- sae_positions:
- The names of the SAEs we want to calculate feature dashboards for,
- eg. ['blocks.0.hook_resid_pre']. If none, then we'll do all of them.
- feature_indices:
- The features we're actually computing for each SAE. These might just be a subset of
- each SAE's full features. If None, then we'll do all of them.
- n_samples: [Only used if tokens is None]
- The number of batches of data to use for calculating the feature dashboard data when
- using dataset_config.
- batch_size: [Only used if tokens is None]
- The number of batches of data to use for calculating the feature dashboard data when
- using dataset_config
- minibatch_size_features:
- Num features in each batch of calculations (break up the features to avoid OOM errors).
- fvp:
- Feature visualization parameters, containing a bunch of other stuff. See the
- FeatureVisParams docstring in sae_vis for more information.
- vocab_dict:
- vocab dict suitable for dashboards with all the special tokens replaced with their
- HTML representations. If None then it will be created using create_vocab_dict(tokenizer)
-
- Returns:
- A dict of [sae_position_name: MultiFeatureData]. Each MultiFeatureData contains data for
- creating each feature's visualization, as well as data for rank-ordering the feature
- visualizations when it comes time to make the prompt-centric view
- (the `feature_act_quantiles` attribute).
- Use dashboards_data[sae_name][feature_idx].get_html() to generate the HTML
- dashboard for a particular feature (returns a string of HTML)
- """
- # Get the vocab dict, which we'll use at the end
- if vocab_dict is None:
- assert (
- model.tlens_model.tokenizer is not None
- ), "If voacab_dict is not supplied, the model must have a tokenizer"
- vocab_dict = create_vocab_dict(model.tlens_model.tokenizer)
-
- if fvp is None:
- fvp = FeatureVisParams(include_left_tables=False)
-
- if sae_positions is None:
- raw_sae_positions: list[str] = model.raw_sae_positions
- else:
- raw_sae_positions: list[str] = filter_names(
- list(model.tlens_model.hook_dict.keys()), sae_positions
- )
- # If we haven't supplied any feature indicies, assume that we want all of them
- feature_indices_tensors = feature_indices_to_tensordict(
- feature_indices_in=feature_indices,
- raw_sae_positions=raw_sae_positions,
- model=model,
- )
- for sae_name in raw_sae_positions:
- assert (
- feature_indices_tensors[sae_name].max().item()
- < model.saes[sae_name.replace(".", "-")].n_dict_components
- ), "Error: Some feature indices are greater than the number of SAE features"
-
- device = model.saes[raw_sae_positions[0].replace(".", "-")].device
- # Get the SAE feature activations (as well as their resudual stream inputs and outputs)
- if tokens is None:
- assert dataset_config is not None, "If no tokens are supplied, then config must be supplied"
- assert (
- batch_size is not None
- ), "If no tokens are supplied, then a batch_size must be supplied"
- feature_acts, final_resid_acts, tokens = compute_feature_acts_on_distribution(
- model=model,
- dataset_config=dataset_config,
- batch_size=batch_size,
- raw_sae_positions=raw_sae_positions,
- feature_indices=feature_indices_tensors,
- n_samples=n_samples,
- )
- else:
- tokens.to(device)
- feature_acts, final_resid_acts = compute_feature_acts(
- model=model,
- tokens=tokens,
- raw_sae_positions=raw_sae_positions,
- feature_indices=feature_indices_tensors,
- )
-
- # Filter out the never active features:
- for sae_name in raw_sae_positions:
- acts_sum = einsum(feature_acts[sae_name], "... some_feats -> some_feats").to("cpu")
- feature_acts[sae_name] = feature_acts[sae_name][..., acts_sum > 0]
- feature_indices_tensors[sae_name] = feature_indices_tensors[sae_name][acts_sum > 0]
- del acts_sum
-
- dashboards_data: dict[str, MultiFeatureData] = {
- name: MultiFeatureData() for name in raw_sae_positions
- }
-
- for sae_name in raw_sae_positions:
- sae = model.saes[sae_name.replace(".", "-")]
- W_dec: Float[Tensor, "feats dim"] = sae.decoder.weight.T
- feature_resid_dirs: Float[Tensor, "some_feats dim"] = W_dec[
- feature_indices_tensors[sae_name]
- ]
- W_U = model.tlens_model.W_U
-
- # Break up the features into batches
- if minibatch_size_features is None:
- feature_acts_batches = [feature_acts[sae_name]]
- feature_batches = [feature_indices_tensors[sae_name].tolist()]
- feature_resid_dir_batches = [feature_resid_dirs]
- else:
- feature_acts_batches = feature_acts[sae_name].split(minibatch_size_features, dim=-1)
- feature_batches = [
- x.tolist() for x in feature_indices_tensors[sae_name].split(minibatch_size_features)
- ]
- feature_resid_dir_batches = feature_resid_dirs.split(minibatch_size_features)
- for i in tqdm(iterable=range(len(feature_batches)), desc="Parsing activation data"):
- new_feature_data = parse_activation_data(
- tokens=tokens,
- feature_acts=feature_acts_batches[i].to_dense().to(device),
- final_resid_acts=final_resid_acts,
- feature_resid_dirs=feature_resid_dir_batches[i],
- feature_indices_list=feature_batches[i],
- W_U=W_U,
- vocab_dict=vocab_dict,
- fvp=fvp,
- )
- dashboards_data[sae_name].update(new_feature_data)
-
- return dashboards_data
-
-
-@torch.inference_mode()
-def parse_prompt_data(
- tokens: Int[Tensor, "batch pos"],
- str_tokens: list[str],
- features_data: MultiFeatureData,
- feature_acts: Float[Tensor, "seq some_feats"],
- final_resid_acts: Float[Tensor, "seq d_resid"],
- feature_resid_dirs: Float[Tensor, "some_feats dim"],
- feature_indices_list: list[int],
- W_U: Float[Tensor, "dim d_vocab"],
- num_top_features: int = 10,
-) -> MultiPromptData:
- """Gets data needed to create the sequences in the prompt-centric HTML visualisation.
-
- This visualization displays dashboards for the most relevant features on a prompt.
- Adapted from sae_vis.data_fetching_fns.get_prompt_data().
-
- Args:
- tokens: The input prompt to the model as tokens
- str_tokens: The input prompt to the model as a list of strings (one string per token)
- features_data: A MultiFeatureData containing information required to plot the features.
- feature_acts: The activations values of the features
- final_resid_acts: The activations of the final layer of the model
- feature_resid_dirs: The directions that each feature writes to the logit output
- feature_indices_list: The indices of the features we're interested in
- W_U: The unembed weights for the logit lens
- num_top_features: The number of top features to display in this view, for any given metric.
- Returns:
- A MultiPromptData object containing data for visualizing the most relevant features
- given the prompt.
-
- Similar to parse_feature_data, except it just gets the data relevant for a particular
- sequence (i.e. a custom one that the user inputs on their own).
-
- The ordering metric for relevant features is set by the str_score parameter in the
- MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect"
- """
- torch.cuda.empty_cache()
- device = W_U.device
- n_feats = len(feature_indices_list)
- batch, seq_len = tokens.shape
- feats_contribution_to_loss = torch.empty(size=(n_feats, seq_len - 1), device=device)
-
- # Some logit computations which we only need to do once
- correct_token_unembeddings = W_U[:, tokens[0, 1:]] # [d_model seq]
- orig_logits = (
- final_resid_acts / final_resid_acts.std(dim=-1, keepdim=True)
- ) @ W_U # [seq d_vocab]
-
- sequence_data_dict: dict[int, SequenceData] = {}
-
- for i, feat in enumerate(feature_indices_list):
- # Calculate all data for the sequences
- # (this is the only truly 'new' bit of calculation we need to do)
-
- # Get this feature's output vector, using an outer product over feature acts for all tokens
- final_resid_acts_feature_effect = einsum(
- feature_acts[..., i].to_dense().to(device),
- feature_resid_dirs[i],
- "seq, d_model -> seq d_model",
- )
-
- # Ablate the output vector from the residual stream, and get logits post-ablation
- new_final_resid_acts = final_resid_acts - final_resid_acts_feature_effect
- new_logits = (new_final_resid_acts / new_final_resid_acts.std(dim=-1, keepdim=True)) @ W_U
-
- # Get the top5 & bottom5 changes in logits
- contribution_to_logprobs = orig_logits.log_softmax(dim=-1) - new_logits.log_softmax(dim=-1)
- top5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5)
- bottom5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5, largest=False)
-
- # Get the change in loss (which is negative of change of logprobs for correct token)
- contribution_to_loss = eindex(-contribution_to_logprobs[:-1], tokens[0, 1:], "seq [seq]")
- feats_contribution_to_loss[i, :] = contribution_to_loss
-
- # Store the sequence data
- sequence_data_dict[feat] = SequenceData(
- token_ids=tokens.squeeze(0).tolist(),
- feat_acts=feature_acts[..., i].tolist(),
- contribution_to_loss=[0.0] + contribution_to_loss.tolist(),
- top5_token_ids=top5_contribution_to_logits.indices.tolist(),
- top5_logit_contributions=top5_contribution_to_logits.values.tolist(),
- bottom5_token_ids=bottom5_contribution_to_logits.indices.tolist(),
- bottom5_logit_contributions=bottom5_contribution_to_logits.values.tolist(),
- )
-
- # Get the logits for the correct tokens
- logits_for_correct_tokens = einsum(
- feature_resid_dirs[i], correct_token_unembeddings, "d_model, d_model seq -> seq"
- )
-
- # Add the annotations data (feature activations and logit effect) to the histograms
- freq_line_posn = feature_acts[..., i].tolist()
- freq_line_text = [
-            f"\\'{str_tok}\\'<br>{act:.3f}"
- for str_tok, act in zip(str_tokens[1:], freq_line_posn, strict=False)
- ]
- middle_plots_data = features_data[feat].middle_plots_data
- assert middle_plots_data is not None
- middle_plots_data.freq_histogram_data.line_posn = freq_line_posn
- middle_plots_data.freq_histogram_data.line_text = freq_line_text # type: ignore (due to typing bug in sae_vis)
- logits_line_posn = logits_for_correct_tokens.tolist()
- logits_line_text = [
-            f"\\'{str_tok}\\'<br>{logits:.3f}"
- for str_tok, logits in zip(str_tokens[1:], logits_line_posn, strict=False)
- ]
- middle_plots_data.logits_histogram_data.line_posn = logits_line_posn
- middle_plots_data.logits_histogram_data.line_text = logits_line_text # type: ignore (due to typing bug in sae_vis)
-
- # Lastly, use the criteria (act size, act quantile, loss effect) to find top-scoring features
-
- # Construct a scores dict, which maps from things like ("act_quantile", seq_pos)
- # to a list of the top-scoring features
- scores_dict: dict[tuple[str, str], tuple[TopK, list[str]]] = {}
-
- for seq_pos in range(len(str_tokens)):
- # Filter the feature activations, since we only need the ones that are non-zero
- feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos] > 0)
- feat_acts_nonzero_locations = np.nonzero(feat_acts_nonzero_filter)[0].tolist()
- _feature_acts = (
- feature_acts[seq_pos, feat_acts_nonzero_filter].to_dense().to(device)
- ) # [feats_filtered,]
- _feature_indices_list = np.array(feature_indices_list)[feat_acts_nonzero_filter]
-
- if feat_acts_nonzero_filter.sum() > 0:
- k = min(num_top_features, _feature_acts.numel())
-
- # Get the "act_size" scores (we return it as a TopK object)
- act_size_topk = TopK(_feature_acts, k=k, largest=True)
- # Replace the indices with feature indices (these are different when
- # feature_indices_list argument is not [0, 1, 2, ...])
- act_size_topk.indices[:] = _feature_indices_list[act_size_topk.indices]
- scores_dict[("act_size", seq_pos)] = (act_size_topk, ".3f") # type: ignore (due to typing bug in sae_vis)
-
- # Get the "act_quantile" scores, which is just the fraction of cached feat acts that it
- # is larger than
- act_quantile, act_precision = features_data.feature_act_quantiles.get_quantile(
- _feature_acts, feat_acts_nonzero_locations
- )
- act_quantile_topk = TopK(act_quantile, k=k, largest=True)
- act_formatting_topk = [f".{act_precision[i]-2}%" for i in act_quantile_topk.indices]
- # Replace the indices with feature indices (these are different when
- # feature_indices_list argument is not [0, 1, 2, ...])
- act_quantile_topk.indices[:] = _feature_indices_list[act_quantile_topk.indices]
- scores_dict[("act_quantile", seq_pos)] = (act_quantile_topk, act_formatting_topk) # type: ignore (due to typing bug in sae_vis)
-
- # We don't measure loss effect on the first token
- if seq_pos == 0:
- continue
-
- # Filter the loss effects, since we only need the ones which have non-zero feature acts on
- # the tokens before them
- prev_feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos - 1] > 0)
- _contribution_to_loss = feats_contribution_to_loss[
- prev_feat_acts_nonzero_filter, seq_pos - 1
- ] # [feats_filtered,]
- _feature_indices_list_prev = np.array(feature_indices_list)[prev_feat_acts_nonzero_filter]
-
- if prev_feat_acts_nonzero_filter.sum() > 0:
- k = min(num_top_features, _contribution_to_loss.numel())
-
- # Get the "loss_effect" scores, which are just the min of features' contributions to
- # loss (min because we're looking for helpful features, not harmful ones)
- contribution_to_loss_topk = TopK(_contribution_to_loss, k=k, largest=False)
- # Replace the indices with feature indices (these are different when
- # feature_indices_list argument is not [0, 1, 2, ...])
- contribution_to_loss_topk.indices[:] = _feature_indices_list_prev[
- contribution_to_loss_topk.indices
- ]
- scores_dict[("loss_effect", seq_pos)] = (contribution_to_loss_topk, ".3f") # type: ignore (due to typing bug in sae_vis)
-
- # Get all the features which are required (i.e. all the sequence position indices)
- feature_indices_list_required = set()
- for score_topk, _ in scores_dict.values():
- feature_indices_list_required.update(set(score_topk.indices.tolist()))
- prompt_data_dict = {}
- for feat in feature_indices_list_required:
- middle_plots_data = features_data[feat].middle_plots_data
- assert middle_plots_data is not None
- prompt_data_dict[feat] = PromptData(
- prompt_data=sequence_data_dict[feat],
- sequence_data=features_data[feat].sequence_data[0],
- middle_plots_data=middle_plots_data,
- )
-
- return MultiPromptData(
- prompt_str_toks=str_tokens,
- prompt_data_dict=prompt_data_dict,
- scores_dict=scores_dict,
- )
-
-
-@torch.inference_mode()
-def get_prompt_data(
- model: SAETransformer,
- tokens: Int[Tensor, "batch pos"],
- str_tokens: list[str],
- dashboards_data: dict[str, MultiFeatureData],
- sae_positions: list[str] | None = None,
- num_top_features: PositiveInt = 10,
-) -> dict[str, MultiPromptData]:
- """Gets data needed to create the sequences in the prompt-centric HTML visualisation.
-
- This visualization displays dashboards for the most relevant features on a prompt.
- Adapted from sae_vis.data_fetching_fns.get_prompt_data()
-
- Args:
- model:
- The model (with SAEs) we'll be using to get the feature activations.
- tokens:
- The input prompt to the model as tokens
- str_tokens:
- The input prompt to the model as a list of strings (one string per token)
- dashboards_data:
- For each SAE, a MultiFeatureData containing information required to plot its features.
- sae_positions:
- The names of the SAEs we want to find relevant features in.
- eg. ['blocks.0.hook_resid_pre']. If none, then we'll do all of them.
- num_top_features: int
- The number of top features to display in this view, for any given metric.
-
- Returns:
- A dict of [sae_position_name: MultiPromptData]. Each MultiPromptData contains data for
- visualizing the most relevant features in that SAE given the prompt.
- Similar to get_feature_data, except it just gets the data relevant for a particular
- sequence (i.e. a custom one that the user inputs on their own).
-
- The ordering metric for relevant features is set by the str_score parameter in the
- MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect"
- """
- assert tokens.shape[-1] == len(
- str_tokens
- ), "Error: the number of tokens does not equal the number of str_tokens"
- if sae_positions is None:
- raw_sae_positions: list[str] = model.raw_sae_positions
- else:
- raw_sae_positions: list[str] = filter_names(
- list(model.tlens_model.hook_dict.keys()), sae_positions
- )
- feature_indices: dict[str, list[int]] = {}
- for sae_name in raw_sae_positions:
- feature_indices[sae_name] = list(dashboards_data[sae_name].feature_data_dict.keys())
-
- feature_acts, final_resid_acts = compute_feature_acts(
- model=model,
- tokens=tokens,
- raw_sae_positions=raw_sae_positions,
- feature_indices=feature_indices,
- )
- final_resid_acts = final_resid_acts.squeeze(dim=0)
-
- prompt_data: dict[str, MultiPromptData] = {}
-
- for sae_name in raw_sae_positions:
- sae = model.saes[sae_name.replace(".", "-")]
- feature_act_dir: Float[Tensor, "dim some_feats"] = sae.encoder[0].weight.T[
- :, feature_indices[sae_name]
- ] # [d_in feats]
- feature_resid_dirs: Float[Tensor, "some_feats dim"] = sae.decoder.weight.T[
- feature_indices[sae_name]
- ] # [feats d_in]
- assert (
- feature_act_dir.T.shape
- == feature_resid_dirs.shape
- == (len(feature_indices[sae_name]), sae.input_size)
- )
-
- prompt_data[sae_name] = parse_prompt_data(
- tokens=tokens,
- str_tokens=str_tokens,
- features_data=dashboards_data[sae_name],
- feature_acts=feature_acts[sae_name].squeeze(dim=0),
- final_resid_acts=final_resid_acts,
- feature_resid_dirs=feature_resid_dirs,
- feature_indices_list=feature_indices[sae_name],
- W_U=model.tlens_model.W_U,
- num_top_features=num_top_features,
- )
- return prompt_data
-
-
-@torch.inference_mode()
-def generate_feature_dashboard_html_files(
- dashboards_data: dict[str, MultiFeatureData],
- feature_indices: FeatureIndicesType | dict[str, set[int]] | None,
- save_dir: str | Path = "",
-):
- """Generates viewable HTML dashboards for every feature in every SAE in dashboards_data"""
- if feature_indices is None:
- feature_indices = {name: dashboards_data[name].keys() for name in dashboards_data}
- save_dir = Path(save_dir)
- save_dir.mkdir(parents=True, exist_ok=True)
- for sae_name in feature_indices:
- logger.info(f"Saving HTML feature dashboards for the SAE at {sae_name}:")
- folder = save_dir / Path(f"dashboards_{sae_name}")
- folder.mkdir(parents=True, exist_ok=True)
- for feature_idx in tqdm(feature_indices[sae_name], desc="Dashboard HTML files"):
- feature_idx = (
- int(feature_idx.item()) if isinstance(feature_idx, Tensor) else feature_idx
- )
- if feature_idx in dashboards_data[sae_name].keys():
- html_str = dashboards_data[sae_name][feature_idx].get_html()
- filepath = folder / Path(f"feature-{feature_idx}.html")
- with open(filepath, "w") as f:
- f.write(html_str)
- logger.info(f"Saved HTML feature dashboards in {folder}")
-
-
-@torch.inference_mode()
-def generate_prompt_dashboard_html_files(
- model: SAETransformer,
- tokens: Int[Tensor, "batch pos"],
- str_tokens: list[str],
- dashboards_data: dict[str, MultiFeatureData],
- seq_pos: int | list[int] | None = None,
- vocab_dict: dict[int, str] | None = None,
- str_score: StrScoreType = "loss_effect",
- save_dir: str | Path = "",
-) -> dict[str, set[int]]:
- """Generates viewable HTML dashboards for the most relevant features (measured by str_score) for
- every SAE in dashboards_data.
-
- Returns the set of feature indices which were active"""
- assert tokens.shape[-1] == len(
- str_tokens
- ), "Error: the number of tokens does not equal the number of str_tokens"
- str_tokens = [s.replace("Ġ", " ") for s in str_tokens]
- if isinstance(seq_pos, int):
- seq_pos = [seq_pos]
- if seq_pos is None: # Generate a dashboard for every position if none is specified
- seq_pos = list(range(2, len(str_tokens) - 2))
- if vocab_dict is None:
- assert (
- model.tlens_model.tokenizer is not None
- ), "If voacab_dict is not supplied, the model must have a tokenizer"
- vocab_dict = create_vocab_dict(model.tlens_model.tokenizer)
- prompt_data = get_prompt_data(
- model=model, tokens=tokens, str_tokens=str_tokens, dashboards_data=dashboards_data
- )
- prompt = "".join(str_tokens)
- # Use the beginning of the prompt for the filename, but make sure that it's safe for a filename
- str_tokens_safe_for_filenames = [
- "".join(c for c in token if c.isalpha() or c.isdigit() or c == " ")
- .rstrip()
- .replace(" ", "-")
- for token in str_tokens
- ]
- filename_from_prompt = "".join(str_tokens_safe_for_filenames)
- filename_from_prompt = filename_from_prompt[:50]
- save_dir = Path(save_dir)
- save_dir.mkdir(parents=True, exist_ok=True)
- used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data}
- for sae_name in dashboards_data:
- seq_pos_with_scores: set[int] = {
- int(x[1])
- for x in prompt_data["blocks.1.hook_resid_post"].scores_dict
- if x[0] == str_score
- }
- for seq_pos_i in seq_pos_with_scores.intersection(seq_pos):
- # Find the most relevant features (by {str_score}) for the token
- # '{str_tokens[seq_pos_i]}' in the prompt '{prompt}:
- html_str = prompt_data[sae_name].get_html(seq_pos_i, str_score, vocab_dict)
- # Insert a title
- title: str = (
- f"