diff --git a/.env.example b/.env.example index 30865d0..aa1dc70 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,4 @@ WANDB_API_KEY=your_api_key -WANDB_ENTITY=your_entity_name \ No newline at end of file +WANDB_ENTITY=your_entity_name +OPENAI_API_KEY=your_api_key # For autointerp +NEURONPEDIA_API_KEY=your_api_key # For neuronpedia \ No newline at end of file diff --git a/e2e_sae/scripts/__init__.py b/e2e_sae/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py new file mode 100644 index 0000000..46be270 --- /dev/null +++ b/e2e_sae/scripts/autointerp.py @@ -0,0 +1,642 @@ +"""Run and analyze autointerp for SAEs. + +NOTE: Running this script currently requires installing ApolloResearch's fork of neuron-explainer: +https://github.com/ApolloResearch/automated-interpretability. This fork has been updated to work with +gpt-4-turbo-2024-04-09 and fixes an OPENAI_API_KEY issue. + +This script requires the following environment variables: +- OPENAI_API_KEY: OpenAI API key. +- NEURONPEDIA_API_KEY: Neuronpedia API key. +These can be set in .env in the root of the repository (see .env.example). + +""" +import asyncio +import glob +import json +import os +import random +from datetime import datetime +from pathlib import Path +from typing import Any, Literal, TypeVar + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import requests +import seaborn as sns +import statsmodels.stats.api as sms +from dotenv import load_dotenv +from neuron_explainer.activations.activation_records import calculate_max_activation +from neuron_explainer.activations.activations import ActivationRecord +from neuron_explainer.explanations.calibrated_simulator import UncalibratedNeuronSimulator +from neuron_explainer.explanations.explainer import ContextSize, TokenActivationPairExplainer +from neuron_explainer.explanations.explanations import ScoredSimulation +from neuron_explainer.explanations.few_shot_examples import FewShotExampleSet +from neuron_explainer.explanations.prompt_builder import PromptFormat +from neuron_explainer.explanations.scoring import ( + _simulate_and_score_sequence, + aggregate_scored_sequence_simulations, +) +from neuron_explainer.explanations.simulator import ( + LogprobFreeExplanationTokenSimulator, + NeuronSimulator, +) +from numpy.typing import ArrayLike +from scipy.stats import bootstrap +from tenacity import retry, stop_after_attempt, wait_random_exponential +from tqdm import tqdm + +load_dotenv() # Required before importing UncalibratedNeuronSimulator +NEURONPEDIA_DOMAIN = "https://neuronpedia.org" +POSITIVE_INF_REPLACEMENT = 9999 +NEGATIVE_INF_REPLACEMENT = -9999 +NAN_REPLACEMENT = 0 +OTHER_INVALID_REPLACEMENT = -99999 + + +def NanAndInfReplacer(value: str): + """Replace NaNs and Infs in outputs.""" + replacements = { + "-Infinity": NEGATIVE_INF_REPLACEMENT, + "Infinity": POSITIVE_INF_REPLACEMENT, + "NaN": NAN_REPLACEMENT, + } + if value in replacements: + replaced_value = replacements[value] + return float(replaced_value) + else: + return NAN_REPLACEMENT + + +def get_neuronpedia_feature( + feature: int, layer: int, model: str = "gpt2-small", dataset: str = "res-jb" +) -> dict[str, Any]: + """Fetch a feature from Neuronpedia API.""" + url = f"{NEURONPEDIA_DOMAIN}/api/feature/{model}/{layer}-{dataset}/{feature}" + result = requests.get(url).json() + if "index" in result: + result["index"] = int(result["index"]) + else: + raise Exception(f"Feature {model}@{layer}-{dataset}:{feature} does 
not exist.") + return result + + +def test_key(api_key: str): + """Test the validity of the Neuronpedia API key.""" + url = f"{NEURONPEDIA_DOMAIN}/api/test" + body = {"apiKey": api_key} + response = requests.post(url, json=body) + if response.status_code != 200: + raise Exception("Neuronpedia API key is not valid.") + + +class NeuronpediaActivation: + """Represents an activation from Neuronpedia.""" + + def __init__(self, id: str, tokens: list[str], act_values: list[float]): + self.id = id + self.tokens = tokens + self.act_values = act_values + + +class NeuronpediaFeature: + """Represents a feature from Neuronpedia.""" + + def __init__( + self, + modelId: str, + layer: int, + dataset: str, + feature: int, + description: str = "", + activations: list[NeuronpediaActivation] | None = None, + autointerp_explanation: str = "", + autointerp_explanation_score: float = 0.0, + ): + self.modelId = modelId + self.layer = layer + self.dataset = dataset + self.feature = feature + self.description = description + self.activations = activations or [] + self.autointerp_explanation = autointerp_explanation + self.autointerp_explanation_score = autointerp_explanation_score + + def has_activating_text(self) -> bool: + """Check if the feature has activating text.""" + return any(max(activation.act_values) > 0 for activation in self.activations) + + +T = TypeVar("T") + + +@retry(wait=wait_random_exponential(min=1, max=500), stop=stop_after_attempt(10)) +def sleep_identity(x: T) -> T: + """Dummy function for retrying.""" + return x + + +@retry(wait=wait_random_exponential(min=1, max=500), stop=stop_after_attempt(10)) +async def simulate_and_score( + simulator: NeuronSimulator, activation_records: list[ActivationRecord] +) -> ScoredSimulation: + """Score an explanation of a neuron by how well it predicts activations on the given text sequences.""" + scored_sequence_simulations = await asyncio.gather( + *[ + sleep_identity( + _simulate_and_score_sequence( + simulator, + activation_record, + ) + ) + for activation_record in activation_records + ] + ) + return aggregate_scored_sequence_simulations(scored_sequence_simulations) + + +async def autointerp_neuronpedia_features( + features: list[NeuronpediaFeature], + openai_api_key: str, + autointerp_retry_attempts: int = 3, + autointerp_score_max_concurrent: int = 20, + neuronpedia_api_key: str = "", + do_score: bool = True, + output_dir: str = "neuronpedia_outputs/autointerp", + num_activations_to_use: int = 20, + max_explanation_activation_records: int = 20, + upload_to_neuronpedia: bool = True, + autointerp_explainer_model_name: Literal[ + "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-turbo-2024-04-09" + ] = "gpt-4-1106-preview", + autointerp_scorer_model_name: Literal[ + "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-turbo-2024-04-09" + ] = "gpt-3.5-turbo", +): + """ + Autointerp Neuronpedia features. + + Args: + features: List of NeuronpediaFeature objects. + openai_api_key: OpenAI API key. + autointerp_retry_attempts: Number of retry attempts for autointerp. + autointerp_score_max_concurrent: Maximum number of concurrent requests for autointerp scoring. + neuronpedia_api_key: Neuronpedia API key. + do_score: Whether to score the features. + output_dir: Output directory for saving the results. + num_activations_to_use: Number of activations to use. + max_explanation_activation_records: Maximum number of activation records for explanation. + upload_to_neuronpedia: Whether to upload the results to Neuronpedia. 
+ autointerp_explainer_model_name: Model name for autointerp explainer. + autointerp_scorer_model_name: Model name for autointerp scorer. + + Returns: + None + """ + print("\n\n") + + os.environ["OPENAI_API_KEY"] = openai_api_key + + if upload_to_neuronpedia and not neuronpedia_api_key: + raise Exception( + "You need to provide a Neuronpedia API key to upload the results to Neuronpedia." + ) + + test_key(neuronpedia_api_key) + + print("\n\n=== Step 1) Fetching features from Neuronpedia") + for feature in features: + feature_data = get_neuronpedia_feature( + feature=feature.feature, + layer=feature.layer, + model=feature.modelId, + dataset=feature.dataset, + ) + + if "modelId" not in feature_data: + raise Exception( + f"Feature {feature.feature} in layer {feature.layer} of model {feature.modelId} and dataset {feature.dataset} does not exist." + ) + + if "activations" not in feature_data or len(feature_data["activations"]) == 0: + raise Exception( + f"Feature {feature.feature} in layer {feature.layer} of model {feature.modelId} and dataset {feature.dataset} does not have activations." + ) + + activations = feature_data["activations"] + activations_to_add = [] + for activation in activations: + if len(activations_to_add) < num_activations_to_use: + activations_to_add.append( + NeuronpediaActivation( + id=activation["id"], + tokens=activation["tokens"], + act_values=activation["values"], + ) + ) + feature.activations = activations_to_add + + if not feature.has_activating_text(): + raise Exception( + f"Feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature} appears dead - it does not have activating text." + ) + + for iteration_num, feature in enumerate(features): + start_time = datetime.now() + + print( + f"\n========== Feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature} ({iteration_num + 1} of {len(features)} Features) ==========" + ) + print( + f"\n=== Step 2) Explaining feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}" + ) + + activation_records = [ + ActivationRecord(tokens=activation.tokens, activations=activation.act_values) + for activation in feature.activations + ] + + activation_records_explaining = activation_records[:max_explanation_activation_records] + + explainer = TokenActivationPairExplainer( + model_name=autointerp_explainer_model_name, + prompt_format=PromptFormat.HARMONY_V4, + context_size=ContextSize.SIXTEEN_K, + max_concurrent=1, + ) + + explanations = [] + for _ in range(autointerp_retry_attempts): + try: + explanations = await explainer.generate_explanations( + all_activation_records=activation_records_explaining, + max_activation=calculate_max_activation(activation_records_explaining), + num_samples=1, + ) + except Exception as e: + print(f"ERROR, RETRYING: {e}") + else: + break + else: + print( + f"ERROR: Failed to explain feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}" + ) + + assert len(explanations) == 1 + explanation = explanations[0].rstrip(".") + print(f"===== {autointerp_explainer_model_name}'s explanation: {explanation}") + feature.autointerp_explanation = explanation + + if do_score: + print( + f"\n=== Step 3) Scoring feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}" + ) + print("=== This can take up to 30 seconds.") + + temp_activation_records = [ + ActivationRecord( + tokens=[ + token.replace("<|endoftext|>", "<|not_endoftext|>") + .replace(" 55", "_55") + .encode("ascii", errors="backslashreplace") + .decode("ascii") + 
for token in activation_record.tokens + ], + activations=activation_record.activations, + ) + for activation_record in activation_records + ] + + score = None + scored_simulation = None + for _ in range(autointerp_retry_attempts): + try: + simulator = UncalibratedNeuronSimulator( + LogprobFreeExplanationTokenSimulator( + autointerp_scorer_model_name, + explanation, + json_mode=True, + max_concurrent=autointerp_score_max_concurrent, + few_shot_example_set=FewShotExampleSet.JL_FINE_TUNED, + prompt_format=PromptFormat.HARMONY_V4, + ) + ) + scored_simulation = await simulate_and_score(simulator, temp_activation_records) + score = scored_simulation.get_preferred_score() + except Exception as e: + print(f"ERROR, RETRYING: {e}") + else: + break + + if ( + score is None + or scored_simulation is None + or len(scored_simulation.scored_sequence_simulations) != num_activations_to_use + ): + print( + f"ERROR: Failed to score feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}. Skipping it." + ) + continue + feature.autointerp_explanation_score = score + print(f"===== {autointerp_scorer_model_name}'s score: {(score * 100):.0f}") + + output_file = ( + f"{output_dir}/{feature.layer}-{feature.dataset}_feature-{feature.feature}_score-" + f"{feature.autointerp_explanation_score}_time-{datetime.now().strftime('%Y%m%d-%H%M%S')}.jsonl" + ) + os.makedirs(output_dir, exist_ok=True) + print(f"===== Your results will be saved to: {output_file} =====") + + output_data = json.dumps( + { + "apiKey": neuronpedia_api_key, + "feature": { + "modelId": feature.modelId, + "layer": f"{feature.layer}-{feature.dataset}", + "index": feature.feature, + "activations": feature.activations, + "explanation": feature.autointerp_explanation, + "explanationScore": feature.autointerp_explanation_score, + "explanationModel": autointerp_explainer_model_name, + "autointerpModel": autointerp_scorer_model_name, + "simulatedActivations": scored_simulation.scored_sequence_simulations, + }, + }, + default=vars, + ) + output_data_json = json.loads(output_data, parse_constant=NanAndInfReplacer) + output_data_str = json.dumps(output_data) + + print(f"\n=== Step 4) Saving feature to {output_file}") + with open(output_file, "a") as f: + f.write(output_data_str) + f.write("\n") + + if upload_to_neuronpedia: + print( + f"\n=== Step 5) Uploading feature to Neuronpedia: {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}" + ) + url = f"{NEURONPEDIA_DOMAIN}/api/upload-explanation" + body = output_data_json + response = requests.post(url, json=body) + if response.status_code != 200: + print(f"ERROR: Couldn't upload explanation to Neuronpedia: {response.text}") + + end_time = datetime.now() + print(f"\n========== Time Spent for Feature: {end_time - start_time}\n") + + print("\n\n========== Generation and Upload Complete ==========\n\n") + + +def run_autointerp( + saes: list[str], + n_random_features: int, + dict_size: int, + feature_model_id: str, + autointerp_explainer_model_name: Literal[ + "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-turbo-2024-04-09" + ], + autointerp_scorer_model_name: Literal[ + "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-turbo-2024-04-09" + ], + out_dir: str | Path = "neuronpedia_outputs/autointerp", +): + """Explain and score random features across SAEs and upload to Neuronpedia. + + Args: + sae_sets: List of SAE sets. + n_random_features: Number of random features to explain and score for each SAE. + dict_size: Size of the SAE dictionary. 
+ feature_model_id: Model ID that the SAE is attached to. + autointerp_explainer_model_name: Model name for autointerp explainer. + autointerp_scorer_model_name: Model name for autointerp scorer (much more expensive). + """ + for i in tqdm(range(n_random_features), desc="random features", total=n_random_features): + for sae in tqdm(saes, desc="sae", total=len(saes)): + layer = int(sae.split("-")[0]) + dataset = "-".join(sae.split("-")[1:]) + feature_exists = False + while not feature_exists: + feature = random.randint(0, dict_size) + feature_data = get_neuronpedia_feature( + feature=feature, + layer=layer, + model=feature_model_id, + dataset=dataset, + ) + if "activations" in feature_data and len(feature_data["activations"]) >= 20: + feature_exists = True + print(f"sae: {sae}, feature: {feature}") + + features = [ + NeuronpediaFeature( + modelId=feature_model_id, + layer=layer, + dataset=dataset, + feature=feature, + ), + ] + + asyncio.run( + autointerp_neuronpedia_features( + features=features, + openai_api_key=os.getenv("OPENAI_API_KEY", ""), + neuronpedia_api_key=os.getenv("NEURONPEDIA_API_KEY", ""), + autointerp_explainer_model_name=autointerp_explainer_model_name, + autointerp_scorer_model_name=autointerp_scorer_model_name, + num_activations_to_use=20, + max_explanation_activation_records=5, + output_dir=str(out_dir), + ) + ) + + feature_url = ( + f"https://neuronpedia.org/{feature_model_id}/{layer}-{dataset}/{feature}" + ) + print(f"Your feature is at: {feature_url}") + + +def get_autointerp_results_df(out_dir: Path): + autointerp_files = glob.glob(f"{out_dir}/**/*_feature-*_score-*", recursive=True) + stats = { + "layer": [], + "sae": [], + "feature": [], + "explanationScore": [], + "explanationModel": [], + "autointerpModel": [], + "explanation": [], + "sae_type": [], + } + for autointerp_file in tqdm(autointerp_files, "constructing stats from json autointerp files"): + with open(autointerp_file) as f: + json_data = json.loads(json.load(f)) + if "feature" in json_data: + json_data = json_data["feature"] + stats["layer"].append(int(json_data["layer"].split("-")[0])) + stats["sae"].append("-".join(json_data["layer"].split("-")[1:])) + stats["feature"].append(json_data["index"]) + stats["explanationScore"].append(json_data["explanationScore"]) + stats["autointerpModel"].append(json_data["autointerpModel"]) + stats["explanationModel"].append("gpt-4-1106-preview") + stats["explanation"].append(json_data["explanation"]) + stats["sae_type"].append( + {"e": "e2e", "l": "local", "r": "downstream"}[json_data["layer"].split("-")[1][-1]] + ) + df_stats = pd.DataFrame(stats) + assert not df_stats.empty + df_stats.dropna(inplace=True) + return df_stats + + +def pair_violin_plot(df_stats: pd.DataFrame, pairs: dict[int, dict[str, str]], out_file: Path): + fig, axs = plt.subplots(1, 3, sharey=True, figsize=(7, 3.5)) + + for ax, layer in zip(axs, pairs.keys(), strict=True): + sae_ids = [pairs[layer]["local"], pairs[layer]["downstream"]] + layer_data = df_stats[df_stats["layer"] == layer] + grouped_scores = layer_data.groupby("sae")["explanationScore"] + mean_explanationScores = [grouped_scores.get_group(sae_id).mean() for sae_id in sae_ids] + confidence_intervals = [ + sms.DescrStatsW(grouped_scores.get_group(sae_id)).tconfint_mean() for sae_id in sae_ids + ] + + data = layer_data.loc[layer_data.sae.isin(sae_ids)] + colors = { + "local": "#f0a70a", + "e2e": "#518c31", + "downstream": plt.get_cmap("tab20b").colors[2], # type: ignore[reportAttributeAccessIssue] + } + l = sns.violinplot( + data=data, + 
hue="sae_type", + x=[1 if sae_type == "downstream" else 0 for sae_type in data.sae_type], + y="explanationScore", + palette=colors, + native_scale=True, + orient="v", + ax=ax, + inner=None, + cut=0, + bw_adjust=0.7, + hue_order=["local", "downstream"], + alpha=0.8, + legend=False, + ) + ax.set_xticks([0, 1], ["Local", "Downstream"]) + ax.errorbar( + x=(0, 1), + y=mean_explanationScores, + yerr=[(c[1] - c[0]) / 2 for c in confidence_intervals], + fmt="o", + color="black", + capsize=5, + ) + ax.set_ylim((None, 1)) + ax.set_title(label=f"Layer {layer}") + + axs[0].set_ylabel("Auto-intepretability score") + + plt.tight_layout() + plt.savefig(out_file, bbox_inches="tight") + print(f"Saved to {out_file}") + + +def bootstrap_p_value(sample_a: ArrayLike, sample_b: ArrayLike) -> float: + """ + Computes 2 sided p-value, for null hypothesis that means are the same + """ + sample_a, sample_b = np.asarray(sample_a), np.asarray(sample_b) + mean_diff = np.mean(sample_a) - np.mean(sample_b) + n_bootstraps = 100_000 + bootstrapped_diffs = [] + combined = np.concatenate([sample_a, sample_b]) + for _ in range(n_bootstraps): + boot_a = np.random.choice(combined, len(sample_a), replace=True) + boot_b = np.random.choice(combined, len(sample_b), replace=True) + boot_diff = np.mean(boot_a) - np.mean(boot_b) + bootstrapped_diffs.append(boot_diff) + bootstrapped_diffs = np.array(bootstrapped_diffs) + p_value = np.mean(np.abs(bootstrapped_diffs) > np.abs(mean_diff)) + return p_value + + +def bootstrap_mean_diff(sample_a: ArrayLike, sample_b: ArrayLike) -> tuple[float, Any]: + def mean_diff(resample_a, resample_b): # type: ignore + return np.mean(resample_a) - np.mean(resample_b) + + result = bootstrap((sample_a, sample_b), n_resamples=100_000, statistic=mean_diff) + diff: float = np.mean(sample_a) - np.mean(sample_b) # type: ignore + return diff, result.confidence_interval + + +def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): + score_groups = df.groupby(["layer", "sae"]).explanationScore + for layer, sae_dict in pairs.items(): + sae_local, sae_downstream = sae_dict["local"], sae_dict["downstream"] + pval = bootstrap_p_value( + score_groups.get_group((layer, sae_downstream)).to_numpy(), + score_groups.get_group((layer, sae_local)).to_numpy(), + ) + # print(f"L{layer}, {sae_downstream} vs {sae_local}: p={pval}") + diff, ci = bootstrap_mean_diff( + score_groups.get_group((layer, sae_downstream)).to_numpy(), + score_groups.get_group((layer, sae_local)).to_numpy(), + ) + print(f"{layer}&${diff:.2f}\\ [{ci.low:.2f},{ci.high:.2f}]$&{pval:.2g}") + + +if __name__ == "__main__": + plot_out_dir = Path(__file__).parent / "out/autointerp/" + score_out_dir = Path(__file__).parent / "out/autointerp/" + score_out_dir.mkdir(parents=True, exist_ok=True) + ## Running autointerp + # Get runs for similar CE and similar L0 for e2e+Downstream and local + # Note that "10-res_slefr-ajt" does not exist, we use 10-res_scefr-ajt for similar l0 too + saes = [ + "2-res_scefr-ajt", + "2-res_slefr-ajt", + "2-res_sll-ajt", + "2-res_scl-ajt", + "6-res_scefr-ajt", + "6-res_slefr-ajt", + "6-res_sll-ajt", + "6-res_scl-ajt", + "10-res_scefr-ajt", + "10-res_sll-ajt", + "10-res_scl-ajt", + ] + run_autointerp( + saes=saes, + n_random_features=150, + dict_size=768 * 60, + feature_model_id="gpt2-small", + autointerp_explainer_model_name="gpt-4-turbo-2024-04-09", + autointerp_scorer_model_name="gpt-3.5-turbo", + out_dir=score_out_dir, + ) + + df = get_autointerp_results_df(score_out_dir) + + ## Analysis of autointerp results + + 
const_l0_pairs = { + 2: {"local": "res_sll-ajt", "downstream": "res_slefr-ajt"}, + 6: {"local": "res_sll-ajt", "downstream": "res_slefr-ajt"}, + 10: {"local": "res_sll-ajt", "downstream": "res_scefr-ajt"}, + } + + const_ce_pairs = { + 2: {"local": "res_scl-ajt", "downstream": "res_scefr-ajt"}, + 6: {"local": "res_scl-ajt", "downstream": "res_scefr-ajt"}, + 10: {"local": "res_scl-ajt", "downstream": "res_scefr-ajt"}, + } + + # compare_autointerp_results(df) + # compare_across_saes(df) + pair_violin_plot(df, const_l0_pairs, plot_out_dir / "l0_violin.png") + pair_violin_plot(df, const_ce_pairs, plot_out_dir / "ce_violin.png") + print("SAME L0") + compute_p_values(df, const_l0_pairs) + print("\nSAME CE") + compute_p_values(df, const_ce_pairs) diff --git a/e2e_sae/scripts/generate_dashboards.py b/e2e_sae/scripts/generate_dashboards.py deleted file mode 100644 index 14b95b3..0000000 --- a/e2e_sae/scripts/generate_dashboards.py +++ /dev/null @@ -1,1214 +0,0 @@ -"""Script for generating HTML feature dashboards -Usage: - $ python generate_dashboards.py - (Generates dashboards for the SAEs in ) - or - $ python generate_dashboards.py - (Requires that a path to the sae.pt file is provided in dashboards_config.pretrained_sae_paths) - -dashboard HTML files be saved in dashboards_config.save_dir - -Two types of dashboards can be created: - feature-centric: - These are individual dashboards for each feature specified in - dashboards_config.feature_indices, showing that feature's max activating examples, facts - about the distribution of when it is active, and what it promotes through the logit lens. - feature-centric dashboards will also be generated for all of the top features which apppear - in the prompt-centric dashboards. Saved in dashboards_config.save_dir/dashboards_{sae_name} - prompt-centric: - Given a prompt and a specific token position within it, find the most important features - active at that position, and make a dashboard showing where they all activate. There are - three ways of measuring the importance of features: "act_size" (show the features which - activated most strongly), "act_quantile" (show the features which activated much more than - they usually do), and "loss_effect" (show the features with the biggest logit-lens ablation - effect for predicting the correct next token - default). 
- Saved in dashboards_config.save_dir/prompt_dashboards - -This script currently relies on an old commit of Callum McDouglal's sae_vis package: -https://github.com/callummcdougall/sae_vis/commit/b28a0f7c7e936f4bea05528d952dfcd438533cce -""" -import math -from collections.abc import Iterable -from pathlib import Path -from typing import Annotated, Literal - -import fire -import numpy as np -import torch -from eindex import eindex -from einops import einsum, rearrange -from jaxtyping import Float, Int -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, NonNegativeInt, PositiveInt -from sae_vis.data_fetching_fns import get_sequences_data -from sae_vis.data_storing_fns import ( - FeatureData, - FeatureVisParams, - HistogramData, - MiddlePlotsData, - MultiFeatureData, - MultiPromptData, - PromptData, - SequenceData, - SequenceMultiGroupData, -) -from sae_vis.utils_fns import QuantileCalculator, TopK, process_str_tok -from torch import Tensor -from torch.utils.data.dataset import IterableDataset -from tqdm import tqdm -from transformers import ( - AutoTokenizer, - PreTrainedTokenizer, - PreTrainedTokenizerBase, - PreTrainedTokenizerFast, -) - -from e2e_sae.data import DatasetConfig, create_data_loader -from e2e_sae.loader import load_pretrained_saes, load_tlens_model -from e2e_sae.log import logger -from e2e_sae.models.transformers import SAETransformer -from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import Config -from e2e_sae.types import RootPath -from e2e_sae.utils import filter_names, load_config, to_numpy - -FeatureIndicesType = dict[str, list[int]] | dict[str, Int[Tensor, "some_feats"]] # noqa: F821 (jaxtyping/pyright doesn't like single dimensions) -StrScoreType = Literal["act_size", "act_quantile", "loss_effect"] - - -class PromptDashboardsConfig(BaseModel): - model_config = ConfigDict(extra="forbid", frozen=True) - n_random_prompt_dashboards: NonNegativeInt = Field( - default=50, - description="The number of random prompts to generate prompt-centric dashboards for." - "A feature-centric dashboard will be generated for random token positions in each prompt.", - ) - data: DatasetConfig | None = Field( - default=None, - description="DatasetConfig for getting random prompts." - "If None, then DashboardsConfig.data will be used", - ) - prompts: list[str] | None = Field( - default=None, - description="Specific prompts on which to generate prompt-centric feature dashboards. " - "A feature-centric dashboard will be generated for every token position in each prompt.", - ) - str_score: StrScoreType = Field( - default="loss_effect", - description="The ordering metric for which features are most important in prompt-centric " - "dashboards. 
Can be one of 'act_size', 'act_quantile', or 'loss_effect'", - ) - num_top_features: PositiveInt = Field( - default=10, - description="How many of the most relevant features to show for each prompt" - " in the prompt-centric dashboards", - ) - - -class DashboardsConfig(BaseModel): - model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True, frozen=True) - pretrained_sae_paths: Annotated[ - list[RootPath] | None, BeforeValidator(lambda x: [x] if isinstance(x, str | Path) else x) - ] = Field(None, description="Paths of the pretrained SAEs to load") - sae_config_path: RootPath | None = Field( - default=None, - description="Path to the config file used to train the SAEs" - " (if null, we'll assume it's at pretrained_sae_paths[0].parent / 'config.yaml')", - ) - n_samples: PositiveInt | None = None - batch_size: PositiveInt - minibatch_size_features: PositiveInt | None = Field( - default=256, - description="Num features in each batch of calculations (i.e. we break up the features to " - "avoid OOM errors).", - ) - data: DatasetConfig = Field( - description="DatasetConfig for the data which will be used to generate the dashboards", - ) - save_dir: RootPath | None = Field( - default=None, - description="The directory for saving the HTML feature dashboard files", - ) - sae_positions: Annotated[ - list[str] | None, BeforeValidator(lambda x: [x] if isinstance(x, str) else x) - ] = Field( - None, - description="The names of the SAE positions to generate dashboards for. " - "e.g. 'blocks.2.hook_resid_post'. If None, then all positions will be generated", - ) - feature_indices: FeatureIndicesType | list[int] | None = Field( - default=None, - description="The features for which to generate dashboards on each SAE. If none, then " - "we'll generate dashbaords for every feature.", - ) - prompt_centric: PromptDashboardsConfig | None = Field( - default=None, - description="Used to generate prompt-centric (rather than feature-centric) dashboards." - " Feature-centric dashboards will also be generated for every feature appaearing in these", - ) - seed: NonNegativeInt = 0 - - -def compute_feature_acts( - model: SAETransformer, - tokens: Int[Tensor, "batch pos"], - raw_sae_positions: list[str] | None = None, - feature_indices: FeatureIndicesType | None = None, - stop_at_layer: int = -1, -) -> tuple[dict[str, Float[Tensor, "... some_feats"]], Float[Tensor, "... dim"]]: - """Compute the activations of the SAEs in the model given a tensor of input tokens - - Args: - model: The SAETransformer containing the SAEs and the tlens_model - tokens: The inputs to the tlens_model - raw_sae_positions: The names of the SAEs we're interested in - feature_indices: The indices of the features we're interested in for each SAE - stop_at_layer: Where to stop the forward pass. final_resid_acts will be returned from here - - Returns: - - A dict of feature activations for each SAE. - feature_acts[sae_position_name] = the feature activations of that SAE - shape: batch pos some_feats - - The residual stream activations of the model at the final layer (or at stop_at_layer) - """ - if raw_sae_positions is None: - raw_sae_positions = model.raw_sae_positions - # Run model without SAEs - final_resid_acts, orig_acts = model.tlens_model.run_with_cache( - tokens, - names_filter=raw_sae_positions, - return_cache_object=False, - stop_at_layer=stop_at_layer, - ) - assert isinstance(final_resid_acts, Tensor) - feature_acts: dict[str, Float[Tensor, "... 
some_feats"]] = {} - # Run the activations through the SAEs - for hook_name in orig_acts: - sae = model.saes[hook_name.replace(".", "-")] - output, feature_acts[hook_name] = sae(orig_acts[hook_name]) - del output - feature_acts[hook_name] = feature_acts[hook_name].to("cpu") - if feature_indices is not None: - feature_acts[hook_name] = feature_acts[hook_name][..., feature_indices[hook_name]] - return feature_acts, final_resid_acts - - -def compute_feature_acts_on_distribution( - model: SAETransformer, - dataset_config: DatasetConfig, - batch_size: PositiveInt, - n_samples: PositiveInt | None = None, - raw_sae_positions: list[str] | None = None, - feature_indices: FeatureIndicesType | None = None, - stop_at_layer: int = -1, -) -> tuple[ - dict[str, Float[Tensor, "... some_feats"]], Float[Tensor, "... d_resid"], Int[Tensor, "..."] -]: - """Compute the activations of the SAEs in the model on a dataset of input tokens - - Args: - model: The SAETransformer containing the SAEs and the tlens_model - dataset_config: The DatasetConfig used to get the data loader for the tokens. - batch_size: The batch size of data run through the model when calculating the feature acts - n_samples: The number of batches of data to use for calculating the feature dashboard data - raw_sae_positions: The names of the SAEs we're interested in. If none, do all SAEs. - feature_indices: The indices of the features we're interested in for each SAE. If none, do - all features. - stop_at_layer: Where to stop the forward pass. final_resid_acts will be returned from here - - Returns: - - a dict of SAE inputs, activations, and outputs for each SAE. - feature_acts[sae_position_name] = the feature activations of that SAE - shape: batch pos feats (or # feature_indices) - - The residual stream activations of the model at the final layer (or at stop_at_layer) - - The tokens used as input to the model - """ - data_loader, _ = create_data_loader( - dataset_config, batch_size=batch_size, buffer_size=batch_size - ) - if raw_sae_positions is None: - raw_sae_positions = model.raw_sae_positions - assert raw_sae_positions is not None - device = model.saes[raw_sae_positions[0].replace(".", "-")].device - if n_samples is None: - # If streaming (i.e. if the dataset is an IterableDataset), we don't know the length - n_batches = None if isinstance(data_loader.dataset, IterableDataset) else len(data_loader) - else: - n_batches = math.ceil(n_samples / batch_size) - if not isinstance(data_loader.dataset, IterableDataset): - n_batches = min(n_batches, len(data_loader)) - - total_samples = 0 - feature_acts_lists: dict[str, list[Float[Tensor, "... some_feats"]]] = { - sae_name: [] for sae_name in raw_sae_positions - } - final_resid_acts_list: list[Float[Tensor, "... d_resid"]] = [] - tokens_list: list[Int[Tensor, "..."]] = [] - for batch in tqdm(data_loader, total=n_batches, desc="Computing feature acts"): - batch_tokens: Int[Tensor, "..."] = batch[dataset_config.column_name].to(device=device) - batch_feature_acts, batch_final_resid_acts = compute_feature_acts( - model=model, - tokens=batch_tokens, - raw_sae_positions=raw_sae_positions, - feature_indices=feature_indices, - stop_at_layer=stop_at_layer, - ) - for sae_name in raw_sae_positions: - feature_acts_lists[sae_name].append(batch_feature_acts[sae_name]) - final_resid_acts_list.append(batch_final_resid_acts) - tokens_list.append(batch_tokens) - total_samples += batch_tokens.shape[0] - if n_samples is not None and total_samples > n_samples: - break - final_resid_acts: Float[Tensor, "... 
d_resid"] = torch.cat(final_resid_acts_list, dim=0) - tokens: Int[Tensor, "..."] = torch.cat(tokens_list, dim=0) - feature_acts: dict[str, Float[Tensor, "... some_feats"]] = {} - for sae_name in raw_sae_positions: - feature_acts[sae_name] = torch.cat(tensors=feature_acts_lists[sae_name], dim=0) - return feature_acts, final_resid_acts, tokens - - -def create_vocab_dict(tokenizer: PreTrainedTokenizerBase) -> dict[int, str]: - """ - Creates a vocab dict suitable for dashboards by replacing all the special tokens with their - HTML representations. This function is adapted from sae_vis.create_vocab_dict() - """ - vocab_dict: dict[str, int] = tokenizer.get_vocab() - vocab_dict_processed: dict[int, str] = {v: process_str_tok(k) for k, v in vocab_dict.items()} - return vocab_dict_processed - - -@torch.inference_mode() -def parse_activation_data( - tokens: Int[Tensor, "batch pos"], - feature_acts: Float[Tensor, "... some_feats"], - final_resid_acts: Float[Tensor, "... d_resid"], - feature_resid_dirs: Float[Tensor, "some_feats dim"], - feature_indices_list: Iterable[int], - W_U: Float[Tensor, "dim d_vocab"], - vocab_dict: dict[int, str], - fvp: FeatureVisParams, -) -> MultiFeatureData: - """Convert generic activation data into a MultiFeatureData object, which can be used to create - the feature-centric visualisation. - Adapted from sae_vis.data_fetching_fns._get_feature_data() - final_resid_acts + W_U are used for the logit lens. - - Args: - tokens: The inputs to the model - feature_acts: The activations values of the features - final_resid_acts: The activations of the final layer of the model - feature_resid_dirs: The directions that each feature writes to the logit output - feature_indices_list: The indices of the features we're interested in - W_U: The unembed weights for the logit lens - vocab_dict: A dictionary mapping vocab indices to strings - fvp: FeatureVisParams, containing a bunch of settings. See the FeatureVisParams docstring in - sae_vis for more information. - - Returns: - A MultiFeatureData containing data for creating each feature's visualization, - as well as data for rank-ordering the feature visualizations when it comes time - to make the prompt-centric view (the `feature_act_quantiles` attribute). - Use MultiFeatureData[feature_idx].get_html() to generate the HTML dashboard for a - particular feature (returns a string of HTML). - - """ - device = W_U.device - feature_acts.to(device) - sequence_data_dict: dict[int, SequenceMultiGroupData] = {} - middle_plots_data_dict: dict[int, MiddlePlotsData] = {} - features_data: dict[int, FeatureData] = {} - # Calculate all data for the right-hand visualisations, i.e. the sequences - for i, feat in enumerate(feature_indices_list): - # Add this feature's sequence data to the list - sequence_data_dict[feat] = get_sequences_data( - tokens=tokens, - feat_acts=feature_acts[..., i], - resid_post=final_resid_acts, - feature_resid_dir=feature_resid_dirs[i], - W_U=W_U, - fvp=fvp, - ) - - # Get the logits of all features (i.e. 
the directions this feature writes to the logit output) - logits = einsum( - feature_resid_dirs, - W_U, - "feats d_model, d_model d_vocab -> feats d_vocab", - ) - for i, (feat, logit) in enumerate(zip(feature_indices_list, logits, strict=True)): - # Get data for logits (the histogram, and the table) - logits_histogram_data = HistogramData(logit, 40, "5 ticks") - top10_logits = TopK(logit, k=15, largest=True) - bottom10_logits = TopK(logit, k=15, largest=False) - - # Get data for feature activations histogram (the title, and the histogram) - feat_acts = feature_acts[..., i] - nonzero_feat_acts = feat_acts[feat_acts > 0] - frac_nonzero = nonzero_feat_acts.numel() / feat_acts.numel() - freq_histogram_data = HistogramData(nonzero_feat_acts, 40, "ints") - - # Create a MiddlePlotsData object from this, and add it to the dict - middle_plots_data_dict[feat] = MiddlePlotsData( - bottom10_logits=bottom10_logits, - top10_logits=top10_logits, - logits_histogram_data=logits_histogram_data, - freq_histogram_data=freq_histogram_data, - frac_nonzero=frac_nonzero, - ) - - # Return the output, as a dict of FeatureData items - for i, feat in enumerate(feature_indices_list): - features_data[feat] = FeatureData( - # Data-containing inputs (for the feature-centric visualisation) - sequence_data=sequence_data_dict[feat], - middle_plots_data=middle_plots_data_dict[feat], - left_tables_data=None, - # Non data-containing inputs - feature_idx=feat, - vocab_dict=vocab_dict, - fvp=fvp, - ) - - # Also get the quantiles, which will be useful for the prompt-centric visualisation - - feature_act_quantiles = QuantileCalculator( - data=rearrange(feature_acts, "... feats -> feats (...)") - ) - return MultiFeatureData(features_data, feature_act_quantiles) - - -def feature_indices_to_tensordict( - feature_indices_in: FeatureIndicesType | list[int] | None, - raw_sae_positions: list[str], - model: SAETransformer, -) -> dict[str, Tensor]: - """ "Convert feature indices to a dict of tensor indices""" - if feature_indices_in is None: - feature_indices = {} - for sae_name in raw_sae_positions: - feature_indices[sae_name] = torch.arange( - end=model.saes[sae_name.replace(".", "-")].n_dict_components - ) - # Otherwise make sure that feature_indices is a dict of Int[Tensor] - elif not isinstance(feature_indices_in, dict): - feature_indices = { - sae_name: Tensor(feature_indices_in).to("cpu").to(torch.int) - for sae_name in raw_sae_positions - } - else: - feature_indices: dict[str, Tensor] = { - sae_name: Tensor(feature_indices_in[sae_name]).to("cpu").to(torch.int) - for sae_name in raw_sae_positions - } - return feature_indices - - -@torch.inference_mode() -def get_dashboards_data( - model: SAETransformer, - dataset_config: DatasetConfig | None = None, - tokens: Int[Tensor, "batch pos"] | None = None, - sae_positions: list[str] | None = None, - feature_indices: FeatureIndicesType | list[int] | None = None, - n_samples: PositiveInt | None = None, - batch_size: PositiveInt | None = None, - minibatch_size_features: PositiveInt | None = None, - fvp: FeatureVisParams | None = None, - vocab_dict: dict[int, str] | None = None, -) -> dict[str, MultiFeatureData]: - """Gets data that needed to create the sequences in the feature-centric HTML visualisation - - Adapted from sae_vis.data_fetching_fns._get_feature_data() - - Args: - model: - The model (with SAEs) we'll be using to get the feature activations. - dataset_config: [Only used if tokens is None] - The DatasetConfig which will be used to get the data loader. 
If None, then tokens must - be supplied. - tokens: - The tokens we'll be using to get the feature activations. If None, then we'll use the - distribution from the dataset_config. - sae_positions: - The names of the SAEs we want to calculate feature dashboards for, - eg. ['blocks.0.hook_resid_pre']. If none, then we'll do all of them. - feature_indices: - The features we're actually computing for each SAE. These might just be a subset of - each SAE's full features. If None, then we'll do all of them. - n_samples: [Only used if tokens is None] - The number of batches of data to use for calculating the feature dashboard data when - using dataset_config. - batch_size: [Only used if tokens is None] - The number of batches of data to use for calculating the feature dashboard data when - using dataset_config - minibatch_size_features: - Num features in each batch of calculations (break up the features to avoid OOM errors). - fvp: - Feature visualization parameters, containing a bunch of other stuff. See the - FeatureVisParams docstring in sae_vis for more information. - vocab_dict: - vocab dict suitable for dashboards with all the special tokens replaced with their - HTML representations. If None then it will be created using create_vocab_dict(tokenizer) - - Returns: - A dict of [sae_position_name: MultiFeatureData]. Each MultiFeatureData contains data for - creating each feature's visualization, as well as data for rank-ordering the feature - visualizations when it comes time to make the prompt-centric view - (the `feature_act_quantiles` attribute). - Use dashboards_data[sae_name][feature_idx].get_html() to generate the HTML - dashboard for a particular feature (returns a string of HTML) - """ - # Get the vocab dict, which we'll use at the end - if vocab_dict is None: - assert ( - model.tlens_model.tokenizer is not None - ), "If voacab_dict is not supplied, the model must have a tokenizer" - vocab_dict = create_vocab_dict(model.tlens_model.tokenizer) - - if fvp is None: - fvp = FeatureVisParams(include_left_tables=False) - - if sae_positions is None: - raw_sae_positions: list[str] = model.raw_sae_positions - else: - raw_sae_positions: list[str] = filter_names( - list(model.tlens_model.hook_dict.keys()), sae_positions - ) - # If we haven't supplied any feature indicies, assume that we want all of them - feature_indices_tensors = feature_indices_to_tensordict( - feature_indices_in=feature_indices, - raw_sae_positions=raw_sae_positions, - model=model, - ) - for sae_name in raw_sae_positions: - assert ( - feature_indices_tensors[sae_name].max().item() - < model.saes[sae_name.replace(".", "-")].n_dict_components - ), "Error: Some feature indices are greater than the number of SAE features" - - device = model.saes[raw_sae_positions[0].replace(".", "-")].device - # Get the SAE feature activations (as well as their resudual stream inputs and outputs) - if tokens is None: - assert dataset_config is not None, "If no tokens are supplied, then config must be supplied" - assert ( - batch_size is not None - ), "If no tokens are supplied, then a batch_size must be supplied" - feature_acts, final_resid_acts, tokens = compute_feature_acts_on_distribution( - model=model, - dataset_config=dataset_config, - batch_size=batch_size, - raw_sae_positions=raw_sae_positions, - feature_indices=feature_indices_tensors, - n_samples=n_samples, - ) - else: - tokens.to(device) - feature_acts, final_resid_acts = compute_feature_acts( - model=model, - tokens=tokens, - raw_sae_positions=raw_sae_positions, - 
feature_indices=feature_indices_tensors, - ) - - # Filter out the never active features: - for sae_name in raw_sae_positions: - acts_sum = einsum(feature_acts[sae_name], "... some_feats -> some_feats").to("cpu") - feature_acts[sae_name] = feature_acts[sae_name][..., acts_sum > 0] - feature_indices_tensors[sae_name] = feature_indices_tensors[sae_name][acts_sum > 0] - del acts_sum - - dashboards_data: dict[str, MultiFeatureData] = { - name: MultiFeatureData() for name in raw_sae_positions - } - - for sae_name in raw_sae_positions: - sae = model.saes[sae_name.replace(".", "-")] - W_dec: Float[Tensor, "feats dim"] = sae.decoder.weight.T - feature_resid_dirs: Float[Tensor, "some_feats dim"] = W_dec[ - feature_indices_tensors[sae_name] - ] - W_U = model.tlens_model.W_U - - # Break up the features into batches - if minibatch_size_features is None: - feature_acts_batches = [feature_acts[sae_name]] - feature_batches = [feature_indices_tensors[sae_name].tolist()] - feature_resid_dir_batches = [feature_resid_dirs] - else: - feature_acts_batches = feature_acts[sae_name].split(minibatch_size_features, dim=-1) - feature_batches = [ - x.tolist() for x in feature_indices_tensors[sae_name].split(minibatch_size_features) - ] - feature_resid_dir_batches = feature_resid_dirs.split(minibatch_size_features) - for i in tqdm(iterable=range(len(feature_batches)), desc="Parsing activation data"): - new_feature_data = parse_activation_data( - tokens=tokens, - feature_acts=feature_acts_batches[i].to_dense().to(device), - final_resid_acts=final_resid_acts, - feature_resid_dirs=feature_resid_dir_batches[i], - feature_indices_list=feature_batches[i], - W_U=W_U, - vocab_dict=vocab_dict, - fvp=fvp, - ) - dashboards_data[sae_name].update(new_feature_data) - - return dashboards_data - - -@torch.inference_mode() -def parse_prompt_data( - tokens: Int[Tensor, "batch pos"], - str_tokens: list[str], - features_data: MultiFeatureData, - feature_acts: Float[Tensor, "seq some_feats"], - final_resid_acts: Float[Tensor, "seq d_resid"], - feature_resid_dirs: Float[Tensor, "some_feats dim"], - feature_indices_list: list[int], - W_U: Float[Tensor, "dim d_vocab"], - num_top_features: int = 10, -) -> MultiPromptData: - """Gets data needed to create the sequences in the prompt-centric HTML visualisation. - - This visualization displays dashboards for the most relevant features on a prompt. - Adapted from sae_vis.data_fetching_fns.get_prompt_data(). - - Args: - tokens: The input prompt to the model as tokens - str_tokens: The input prompt to the model as a list of strings (one string per token) - features_data: A MultiFeatureData containing information required to plot the features. - feature_acts: The activations values of the features - final_resid_acts: The activations of the final layer of the model - feature_resid_dirs: The directions that each feature writes to the logit output - feature_indices_list: The indices of the features we're interested in - W_U: The unembed weights for the logit lens - num_top_features: The number of top features to display in this view, for any given metric. - Returns: - A MultiPromptData object containing data for visualizing the most relevant features - given the prompt. - - Similar to parse_feature_data, except it just gets the data relevant for a particular - sequence (i.e. a custom one that the user inputs on their own). 
- - The ordering metric for relevant features is set by the str_score parameter in the - MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect" - """ - torch.cuda.empty_cache() - device = W_U.device - n_feats = len(feature_indices_list) - batch, seq_len = tokens.shape - feats_contribution_to_loss = torch.empty(size=(n_feats, seq_len - 1), device=device) - - # Some logit computations which we only need to do once - correct_token_unembeddings = W_U[:, tokens[0, 1:]] # [d_model seq] - orig_logits = ( - final_resid_acts / final_resid_acts.std(dim=-1, keepdim=True) - ) @ W_U # [seq d_vocab] - - sequence_data_dict: dict[int, SequenceData] = {} - - for i, feat in enumerate(feature_indices_list): - # Calculate all data for the sequences - # (this is the only truly 'new' bit of calculation we need to do) - - # Get this feature's output vector, using an outer product over feature acts for all tokens - final_resid_acts_feature_effect = einsum( - feature_acts[..., i].to_dense().to(device), - feature_resid_dirs[i], - "seq, d_model -> seq d_model", - ) - - # Ablate the output vector from the residual stream, and get logits post-ablation - new_final_resid_acts = final_resid_acts - final_resid_acts_feature_effect - new_logits = (new_final_resid_acts / new_final_resid_acts.std(dim=-1, keepdim=True)) @ W_U - - # Get the top5 & bottom5 changes in logits - contribution_to_logprobs = orig_logits.log_softmax(dim=-1) - new_logits.log_softmax(dim=-1) - top5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5) - bottom5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5, largest=False) - - # Get the change in loss (which is negative of change of logprobs for correct token) - contribution_to_loss = eindex(-contribution_to_logprobs[:-1], tokens[0, 1:], "seq [seq]") - feats_contribution_to_loss[i, :] = contribution_to_loss - - # Store the sequence data - sequence_data_dict[feat] = SequenceData( - token_ids=tokens.squeeze(0).tolist(), - feat_acts=feature_acts[..., i].tolist(), - contribution_to_loss=[0.0] + contribution_to_loss.tolist(), - top5_token_ids=top5_contribution_to_logits.indices.tolist(), - top5_logit_contributions=top5_contribution_to_logits.values.tolist(), - bottom5_token_ids=bottom5_contribution_to_logits.indices.tolist(), - bottom5_logit_contributions=bottom5_contribution_to_logits.values.tolist(), - ) - - # Get the logits for the correct tokens - logits_for_correct_tokens = einsum( - feature_resid_dirs[i], correct_token_unembeddings, "d_model, d_model seq -> seq" - ) - - # Add the annotations data (feature activations and logit effect) to the histograms - freq_line_posn = feature_acts[..., i].tolist() - freq_line_text = [ - f"\\'{str_tok}\\'
{act:.3f}" - for str_tok, act in zip(str_tokens[1:], freq_line_posn, strict=False) - ] - middle_plots_data = features_data[feat].middle_plots_data - assert middle_plots_data is not None - middle_plots_data.freq_histogram_data.line_posn = freq_line_posn - middle_plots_data.freq_histogram_data.line_text = freq_line_text # type: ignore (due to typing bug in sae_vis) - logits_line_posn = logits_for_correct_tokens.tolist() - logits_line_text = [ - f"\\'{str_tok}\\'
{logits:.3f}" - for str_tok, logits in zip(str_tokens[1:], logits_line_posn, strict=False) - ] - middle_plots_data.logits_histogram_data.line_posn = logits_line_posn - middle_plots_data.logits_histogram_data.line_text = logits_line_text # type: ignore (due to typing bug in sae_vis) - - # Lastly, use the criteria (act size, act quantile, loss effect) to find top-scoring features - - # Construct a scores dict, which maps from things like ("act_quantile", seq_pos) - # to a list of the top-scoring features - scores_dict: dict[tuple[str, str], tuple[TopK, list[str]]] = {} - - for seq_pos in range(len(str_tokens)): - # Filter the feature activations, since we only need the ones that are non-zero - feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos] > 0) - feat_acts_nonzero_locations = np.nonzero(feat_acts_nonzero_filter)[0].tolist() - _feature_acts = ( - feature_acts[seq_pos, feat_acts_nonzero_filter].to_dense().to(device) - ) # [feats_filtered,] - _feature_indices_list = np.array(feature_indices_list)[feat_acts_nonzero_filter] - - if feat_acts_nonzero_filter.sum() > 0: - k = min(num_top_features, _feature_acts.numel()) - - # Get the "act_size" scores (we return it as a TopK object) - act_size_topk = TopK(_feature_acts, k=k, largest=True) - # Replace the indices with feature indices (these are different when - # feature_indices_list argument is not [0, 1, 2, ...]) - act_size_topk.indices[:] = _feature_indices_list[act_size_topk.indices] - scores_dict[("act_size", seq_pos)] = (act_size_topk, ".3f") # type: ignore (due to typing bug in sae_vis) - - # Get the "act_quantile" scores, which is just the fraction of cached feat acts that it - # is larger than - act_quantile, act_precision = features_data.feature_act_quantiles.get_quantile( - _feature_acts, feat_acts_nonzero_locations - ) - act_quantile_topk = TopK(act_quantile, k=k, largest=True) - act_formatting_topk = [f".{act_precision[i]-2}%" for i in act_quantile_topk.indices] - # Replace the indices with feature indices (these are different when - # feature_indices_list argument is not [0, 1, 2, ...]) - act_quantile_topk.indices[:] = _feature_indices_list[act_quantile_topk.indices] - scores_dict[("act_quantile", seq_pos)] = (act_quantile_topk, act_formatting_topk) # type: ignore (due to typing bug in sae_vis) - - # We don't measure loss effect on the first token - if seq_pos == 0: - continue - - # Filter the loss effects, since we only need the ones which have non-zero feature acts on - # the tokens before them - prev_feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos - 1] > 0) - _contribution_to_loss = feats_contribution_to_loss[ - prev_feat_acts_nonzero_filter, seq_pos - 1 - ] # [feats_filtered,] - _feature_indices_list_prev = np.array(feature_indices_list)[prev_feat_acts_nonzero_filter] - - if prev_feat_acts_nonzero_filter.sum() > 0: - k = min(num_top_features, _contribution_to_loss.numel()) - - # Get the "loss_effect" scores, which are just the min of features' contributions to - # loss (min because we're looking for helpful features, not harmful ones) - contribution_to_loss_topk = TopK(_contribution_to_loss, k=k, largest=False) - # Replace the indices with feature indices (these are different when - # feature_indices_list argument is not [0, 1, 2, ...]) - contribution_to_loss_topk.indices[:] = _feature_indices_list_prev[ - contribution_to_loss_topk.indices - ] - scores_dict[("loss_effect", seq_pos)] = (contribution_to_loss_topk, ".3f") # type: ignore (due to typing bug in sae_vis) - - # Get all the features which are required 
(i.e. all the sequence position indices) - feature_indices_list_required = set() - for score_topk, _ in scores_dict.values(): - feature_indices_list_required.update(set(score_topk.indices.tolist())) - prompt_data_dict = {} - for feat in feature_indices_list_required: - middle_plots_data = features_data[feat].middle_plots_data - assert middle_plots_data is not None - prompt_data_dict[feat] = PromptData( - prompt_data=sequence_data_dict[feat], - sequence_data=features_data[feat].sequence_data[0], - middle_plots_data=middle_plots_data, - ) - - return MultiPromptData( - prompt_str_toks=str_tokens, - prompt_data_dict=prompt_data_dict, - scores_dict=scores_dict, - ) - - -@torch.inference_mode() -def get_prompt_data( - model: SAETransformer, - tokens: Int[Tensor, "batch pos"], - str_tokens: list[str], - dashboards_data: dict[str, MultiFeatureData], - sae_positions: list[str] | None = None, - num_top_features: PositiveInt = 10, -) -> dict[str, MultiPromptData]: - """Gets data needed to create the sequences in the prompt-centric HTML visualisation. - - This visualization displays dashboards for the most relevant features on a prompt. - Adapted from sae_vis.data_fetching_fns.get_prompt_data() - - Args: - model: - The model (with SAEs) we'll be using to get the feature activations. - tokens: - The input prompt to the model as tokens - str_tokens: - The input prompt to the model as a list of strings (one string per token) - dashboards_data: - For each SAE, a MultiFeatureData containing information required to plot its features. - sae_positions: - The names of the SAEs we want to find relevant features in. - eg. ['blocks.0.hook_resid_pre']. If none, then we'll do all of them. - num_top_features: int - The number of top features to display in this view, for any given metric. - - Returns: - A dict of [sae_position_name: MultiPromptData]. Each MultiPromptData contains data for - visualizing the most relevant features in that SAE given the prompt. - Similar to get_feature_data, except it just gets the data relevant for a particular - sequence (i.e. a custom one that the user inputs on their own). 
- - The ordering metric for relevant features is set by the str_score parameter in the - MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect" - """ - assert tokens.shape[-1] == len( - str_tokens - ), "Error: the number of tokens does not equal the number of str_tokens" - if sae_positions is None: - raw_sae_positions: list[str] = model.raw_sae_positions - else: - raw_sae_positions: list[str] = filter_names( - list(model.tlens_model.hook_dict.keys()), sae_positions - ) - feature_indices: dict[str, list[int]] = {} - for sae_name in raw_sae_positions: - feature_indices[sae_name] = list(dashboards_data[sae_name].feature_data_dict.keys()) - - feature_acts, final_resid_acts = compute_feature_acts( - model=model, - tokens=tokens, - raw_sae_positions=raw_sae_positions, - feature_indices=feature_indices, - ) - final_resid_acts = final_resid_acts.squeeze(dim=0) - - prompt_data: dict[str, MultiPromptData] = {} - - for sae_name in raw_sae_positions: - sae = model.saes[sae_name.replace(".", "-")] - feature_act_dir: Float[Tensor, "dim some_feats"] = sae.encoder[0].weight.T[ - :, feature_indices[sae_name] - ] # [d_in feats] - feature_resid_dirs: Float[Tensor, "some_feats dim"] = sae.decoder.weight.T[ - feature_indices[sae_name] - ] # [feats d_in] - assert ( - feature_act_dir.T.shape - == feature_resid_dirs.shape - == (len(feature_indices[sae_name]), sae.input_size) - ) - - prompt_data[sae_name] = parse_prompt_data( - tokens=tokens, - str_tokens=str_tokens, - features_data=dashboards_data[sae_name], - feature_acts=feature_acts[sae_name].squeeze(dim=0), - final_resid_acts=final_resid_acts, - feature_resid_dirs=feature_resid_dirs, - feature_indices_list=feature_indices[sae_name], - W_U=model.tlens_model.W_U, - num_top_features=num_top_features, - ) - return prompt_data - - -@torch.inference_mode() -def generate_feature_dashboard_html_files( - dashboards_data: dict[str, MultiFeatureData], - feature_indices: FeatureIndicesType | dict[str, set[int]] | None, - save_dir: str | Path = "", -): - """Generates viewable HTML dashboards for every feature in every SAE in dashboards_data""" - if feature_indices is None: - feature_indices = {name: dashboards_data[name].keys() for name in dashboards_data} - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - for sae_name in feature_indices: - logger.info(f"Saving HTML feature dashboards for the SAE at {sae_name}:") - folder = save_dir / Path(f"dashboards_{sae_name}") - folder.mkdir(parents=True, exist_ok=True) - for feature_idx in tqdm(feature_indices[sae_name], desc="Dashboard HTML files"): - feature_idx = ( - int(feature_idx.item()) if isinstance(feature_idx, Tensor) else feature_idx - ) - if feature_idx in dashboards_data[sae_name].keys(): - html_str = dashboards_data[sae_name][feature_idx].get_html() - filepath = folder / Path(f"feature-{feature_idx}.html") - with open(filepath, "w") as f: - f.write(html_str) - logger.info(f"Saved HTML feature dashboards in {folder}") - - -@torch.inference_mode() -def generate_prompt_dashboard_html_files( - model: SAETransformer, - tokens: Int[Tensor, "batch pos"], - str_tokens: list[str], - dashboards_data: dict[str, MultiFeatureData], - seq_pos: int | list[int] | None = None, - vocab_dict: dict[int, str] | None = None, - str_score: StrScoreType = "loss_effect", - save_dir: str | Path = "", -) -> dict[str, set[int]]: - """Generates viewable HTML dashboards for the most relevant features (measured by str_score) for - every SAE in dashboards_data. 
- - Returns the set of feature indices which were active""" - assert tokens.shape[-1] == len( - str_tokens - ), "Error: the number of tokens does not equal the number of str_tokens" - str_tokens = [s.replace("Ġ", " ") for s in str_tokens] - if isinstance(seq_pos, int): - seq_pos = [seq_pos] - if seq_pos is None: # Generate a dashboard for every position if none is specified - seq_pos = list(range(2, len(str_tokens) - 2)) - if vocab_dict is None: - assert ( - model.tlens_model.tokenizer is not None - ), "If vocab_dict is not supplied, the model must have a tokenizer" - vocab_dict = create_vocab_dict(model.tlens_model.tokenizer) - prompt_data = get_prompt_data( - model=model, tokens=tokens, str_tokens=str_tokens, dashboards_data=dashboards_data - ) - prompt = "".join(str_tokens) - # Use the beginning of the prompt for the filename, but make sure that it's safe for a filename - str_tokens_safe_for_filenames = [ - "".join(c for c in token if c.isalpha() or c.isdigit() or c == " ") - .rstrip() - .replace(" ", "-") - for token in str_tokens - ] - filename_from_prompt = "".join(str_tokens_safe_for_filenames) - filename_from_prompt = filename_from_prompt[:50] - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data} - for sae_name in dashboards_data: - seq_pos_with_scores: set[int] = { - int(x[1]) - for x in prompt_data[sae_name].scores_dict - if x[0] == str_score - } - for seq_pos_i in seq_pos_with_scores.intersection(seq_pos): - # Find the most relevant features (by {str_score}) for the token - # '{str_tokens[seq_pos_i]}' in the prompt '{prompt}': - html_str = prompt_data[sae_name].get_html(seq_pos_i, str_score, vocab_dict) - # Insert a title - title: str = ( - f"

<h3>The most relevant features from {sae_name}, " - f"measured by {str_score} on the '{str_tokens[seq_pos_i].replace('Ġ',' ')}' " - f"token (token number {seq_pos_i}) in the prompt '{prompt}':</h3>" - ) - substr = "<div class='grid-container'>"  # assumed anchor tag; the original HTML literal was lost in extraction - html_str = html_str.replace( - substr, "<hr>" + title + "
\n" + substr - ) - filepath = save_dir / Path( - f"prompt-{filename_from_prompt}_token-{seq_pos_i}-" - f"{str_tokens_safe_for_filenames[seq_pos_i]}_-{str_score.replace('_','-')}_" - f"sae-{sae_name}.html" - ) - with open(filepath, "w") as f: - f.write(html_str) - scores = prompt_data[sae_name].scores_dict[(str_score, seq_pos_i)][0] # type: ignore - used_features[sae_name] = used_features[sae_name].union(scores.indices.tolist()) - return used_features - - -@torch.inference_mode() -def generate_random_prompt_dashboards( - model: SAETransformer, - dashboards_data: dict[str, MultiFeatureData], - dashboards_config: DashboardsConfig, - use_model_tokenizer: bool = False, - save_dir: RootPath | None = None, -) -> dict[str, set[int]]: - """Generates prompt-centric HTML dashboards for prompts from the training distribution. - - A data_loader is created using the dashboards_config.prompt_centric.data if it exists, - otherwise using the dashboards_config.data config. - For each random prompt, dashboards are generated for three consecutive sequence positions.""" - np.random.seed(dashboards_config.seed) - if save_dir is None: - save_dir = dashboards_config.save_dir - assert save_dir is not None, ( - "generate_random_prompt_dashboards() saves HTML files, but no save_dir was specified in" - + " the dashboards_config or given as input" - ) - assert dashboards_config.prompt_centric is not None, ( - "generate_random_prompt_dashboards() makes prompt-centric dashboards: " - + "the dashboards_config.prompt_centric config must exist" - ) - dataset_config = ( - dashboards_config.prompt_centric.data - if dashboards_config.prompt_centric.data - else dashboards_config.data - ) - data_loader, _ = create_data_loader(dataset_config=dataset_config, batch_size=1, buffer_size=1) - assert model.tlens_model.tokenizer is not None, "The model must have a tokenizer" - if use_model_tokenizer: - tokenizer = model.tlens_model.tokenizer - assert isinstance(tokenizer, PreTrainedTokenizer | PreTrainedTokenizerFast) - else: - tokenizer = AutoTokenizer.from_pretrained(dashboards_config.data.tokenizer_name) - vocab_dict = create_vocab_dict(tokenizer) - if dashboards_config.sae_positions is None: - raw_sae_positions: list[str] = model.raw_sae_positions - else: - raw_sae_positions: list[str] = filter_names( - list(model.tlens_model.hook_dict.keys()), dashboards_config.sae_positions - ) - - used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data} - device = model.saes[raw_sae_positions[0].replace(".", "-")].device - n_prompts = (dashboards_config.prompt_centric.n_random_prompt_dashboards + 2) // 3 - for prompt_idx, batch in tqdm( - enumerate(data_loader), - total=n_prompts, - desc="Random prompt dashboards", - ): - batch_tokens: Int[Tensor, "1 pos"] = batch[dashboards_config.data.column_name].to( - device=device - ) - assert len(batch_tokens.shape) == 2 and batch_tokens.shape[0] == 1 - # Use the tokens from the first <|endoftext|> token to the next - bos_inds = torch.argwhere(batch_tokens == tokenizer.bos_token_id)[:, 1] - if len(bos_inds) > 1: - batch_tokens = batch_tokens[:, bos_inds[0] : bos_inds[1]] - str_tokens = tokenizer.convert_ids_to_tokens(batch_tokens.squeeze(dim=0).tolist()) - assert isinstance(str_tokens, list) - seq_len: int = batch_tokens.shape[1] - # Generate dashboards for three consecutive positions in the prompt, chosen randomly - if seq_len > 4: # Ensure the prompt is long enough for three positions + next token effect - seq_pos_c = np.random.randint(1, seq_len - 3) - seq_pos = [seq_pos_c 
- 1, seq_pos_c, seq_pos_c + 1] - used_features_now = generate_prompt_dashboard_html_files( - model=model, - tokens=batch_tokens, - str_tokens=str_tokens, - dashboards_data=dashboards_data, - seq_pos=seq_pos, - vocab_dict=vocab_dict, - str_score=dashboards_config.prompt_centric.str_score, - save_dir=save_dir, - ) - for sae_name in used_features: - used_features[sae_name] = used_features[sae_name].union(used_features_now[sae_name]) - - if prompt_idx > n_prompts: - break - return used_features - - -@torch.inference_mode() -def generate_dashboards( - model: SAETransformer, dashboards_config: DashboardsConfig, save_dir: RootPath | None = None -) -> None: - """Generate HTML feature dashboards for an SAETransformer and save them. - - First the data for the dashboards are created using dashboards_data = get_dashboards_data(), - then prompt-centric HTML dashboards are created (if dashboards_config.prompt_centric exists), - then feature-centric HTML dashboards are created for any features in - dashboards_config.feature_indices (all features if this is None), or any features which - appeared in prompt-centric dashboards. - Dashboards are saved in dashboards_config.save_dir - """ - if save_dir is None: - save_dir = dashboards_config.save_dir - assert save_dir is not None, ( - "generate_dashboards() saves HTML files, but no save_dir was specified in the" - + " dashboards_config or given as input" - ) - # Deal with the possible input types of sae_positions - if dashboards_config.sae_positions is None: - raw_sae_positions = model.raw_sae_positions - else: - raw_sae_positions = filter_names( - list(model.tlens_model.hook_dict.keys()), dashboards_config.sae_positions - ) - # Deal with the possible input types of feature_indices - feature_indices = feature_indices_to_tensordict( - dashboards_config.feature_indices, raw_sae_positions, model - ) - - # Get the data used in the dashboards - dashboards_data: dict[str, MultiFeatureData] = get_dashboards_data( - model=model, - dataset_config=dashboards_config.data, - sae_positions=raw_sae_positions, - # We need data for every feature if we're generating prompt-centric dashboards: - feature_indices=None if dashboards_config.prompt_centric else feature_indices, - n_samples=dashboards_config.n_samples, - batch_size=dashboards_config.batch_size, - minibatch_size_features=dashboards_config.minibatch_size_features, - ) - - # Generate the prompt-centric dashboards and record which features were active on them - used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data} - if dashboards_config.prompt_centric: - prompt_dashboard_saving_folder = save_dir / Path("prompt-dashboards") - prompt_dashboard_saving_folder.mkdir(parents=True, exist_ok=True) - # Generate random prompt-centric dashboards - if dashboards_config.prompt_centric.n_random_prompt_dashboards > 0: - used_features_now = generate_random_prompt_dashboards( - model=model, - dashboards_data=dashboards_data, - dashboards_config=dashboards_config, - save_dir=prompt_dashboard_saving_folder, - ) - for sae_name in used_features: - used_features[sae_name] = used_features[sae_name].union(used_features_now[sae_name]) - - # Generate dashboards for specific prompts - if dashboards_config.prompt_centric.prompts is not None: - tokenizer = AutoTokenizer.from_pretrained(dashboards_config.data.tokenizer_name) - vocab_dict = create_vocab_dict(tokenizer) - for prompt in dashboards_config.prompt_centric.prompts: - tokens = tokenizer(prompt)["input_ids"] - list_tokens = tokens.tolist() if 
isinstance(tokens, Tensor) else tokens - assert isinstance(list_tokens, list) - str_tokens = tokenizer.convert_ids_to_tokens(list_tokens) - assert isinstance(str_tokens, list) - used_features_now = generate_prompt_dashboard_html_files( - model=model, - tokens=torch.Tensor(tokens).to(dtype=torch.int).unsqueeze(dim=0), - str_tokens=str_tokens, - dashboards_data=dashboards_data, - str_score=dashboards_config.prompt_centric.str_score, - vocab_dict=vocab_dict, - save_dir=prompt_dashboard_saving_folder, - ) - for sae_name in used_features: - used_features[sae_name] = used_features[sae_name].union( - used_features_now[sae_name] - ) - - for sae_name in raw_sae_positions: - used_features[sae_name] = used_features[sae_name].union( - set(feature_indices[sae_name].tolist()) - ) - - # Generate the viewable HTML feature dashboard files - dashboard_html_saving_folder = save_dir / Path("feature-dashboards") - dashboard_html_saving_folder.mkdir(parents=True, exist_ok=True) - generate_feature_dashboard_html_files( - dashboards_data=dashboards_data, - feature_indices=used_features if dashboards_config.prompt_centric else feature_indices, - save_dir=dashboard_html_saving_folder, - ) - - -# Load the saved SAEs and the corresponding model -def load_SAETransformer_from_saes_paths( - pretrained_sae_paths: list[RootPath] | list[str] | None, - config_path: RootPath | str | None = None, - sae_positions: list[str] | None = None, -) -> tuple[SAETransformer, Config, list[str]]: - if pretrained_sae_paths is not None: - pretrained_sae_paths = [Path(p) for p in pretrained_sae_paths] - for path in pretrained_sae_paths: - assert path.exists(), f"pretrained_sae_path: {path} does not exist" - assert path.is_file() and ( - path.suffix == ".pt" or path.suffix == ".pth" - ), f"pretrained_sae_path: {path} is not a .pt or .pth file" - - if config_path is None: - assert ( - pretrained_sae_paths is not None - ), "Either config_path or pretrained_sae_paths must be provided" - config_path = pretrained_sae_paths[0].parent / "config.yaml" - config_path = Path(config_path) - assert config_path.exists(), f"config_path: {config_path} does not exist" - assert ( - config_path.is_file() and config_path.suffix == ".yaml" - ), f"config_path: {config_path} does not exist" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - config = load_config(config_path, config_model=Config) - if pretrained_sae_paths is None: - pretrained_sae_paths = config.saes.pretrained_sae_paths - assert pretrained_sae_paths is not None, "pretrained_sae_paths must be given or in config" - logger.info(config) - - tlens_model = load_tlens_model( - tlens_model_name=config.tlens_model_name, tlens_model_path=config.tlens_model_path - ) - assert tlens_model is not None - - if sae_positions is None: - sae_positions = config.saes.sae_positions - - raw_sae_positions = filter_names(list(tlens_model.hook_dict.keys()), sae_positions) - model = SAETransformer( - tlens_model=tlens_model, - raw_sae_positions=raw_sae_positions, - dict_size_to_input_ratio=config.saes.dict_size_to_input_ratio, - init_decoder_orthogonal=False, - ).to(device=device) - - all_param_names = [name for name, _ in model.saes.named_parameters()] - trainable_param_names = load_pretrained_saes( - saes=model.saes, - pretrained_sae_paths=pretrained_sae_paths, - all_param_names=all_param_names, - retrain_saes=config.saes.retrain_saes, - ) - return model, config, trainable_param_names - - -def main( - config_path_or_obj: Path | str | DashboardsConfig, - pretrained_sae_paths: Path | str | list[Path] 
| list[str] | None, -) -> None: - dashboards_config = load_config(config_path_or_obj, config_model=DashboardsConfig) - logger.info(dashboards_config) - - if pretrained_sae_paths is None: - assert ( - dashboards_config.pretrained_sae_paths is not None - ), "pretrained_sae_paths must be provided, either in the dashboards config or as an input" - pretrained_sae_paths = dashboards_config.pretrained_sae_paths - else: - pretrained_sae_paths = ( - pretrained_sae_paths - if isinstance(pretrained_sae_paths, list) - else [Path(pretrained_sae_paths)] - ) - - logger.info("Loading the model and SAEs") - model, _, _ = load_SAETransformer_from_saes_paths( - pretrained_sae_paths, dashboards_config.sae_config_path - ) - logger.info("done") - - save_dir = dashboards_config.save_dir or Path(pretrained_sae_paths[0]).parent - logger.info(f"The HTML dashboards will be saved in {save_dir}") - generate_dashboards(model, dashboards_config, save_dir=save_dir) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/pyproject.toml b/pyproject.toml index 6e79c78..f5df53d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,15 +12,18 @@ dependencies = [ "wandb~=0.16.2", "fire~=0.5.0", "tqdm~=4.66.1", - "pytest~=8.0.0", + "pytest~=8.1.2", "ipykernel~=6.29.0", "transformer-lens~=1.14.0", "jaxtyping~=0.2.25", "python-dotenv~=1.0.1", "zstandard~=0.22.0", - "matplotlib>=3.5.3", + "matplotlib~=3.5.3", + "seaborn~=0.13.2", + "tenacity~=8.2.3", + "statsmodels~=0.14.2", "eindex-callum@git+https://github.com/callummcdougall/eindex", - "sae_vis@git+https://github.com/callummcdougall/sae_vis.git@b28a0f7c7e936f4bea05528d952dfcd438533cce" + "neuron_explainer@git+https://github.com/ApolloResearch/automated-interpretability.git" ] [project.urls] @@ -29,7 +32,7 @@ repository = "https://github.com/ApolloResearch/e2e_sae" [project.optional-dependencies] dev = [ "ruff~=0.1.14", - "pyright~=1.1.357", + "pyright~=1.1.360", "pre-commit~=3.6.0", ] diff --git a/tests/test_dashboards.py b/tests/test_dashboards.py deleted file mode 100644 index f72beb4..0000000 --- a/tests/test_dashboards.py +++ /dev/null @@ -1,124 +0,0 @@ -from pathlib import Path - -import pytest -import torch -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast - -from e2e_sae.data import DatasetConfig -from e2e_sae.loader import load_tlens_model -from e2e_sae.models.transformers import SAETransformer -from e2e_sae.scripts.generate_dashboards import ( - DashboardsConfig, - compute_feature_acts, - create_vocab_dict, - generate_dashboards, -) -from e2e_sae.utils import set_seed -from tests.utils import get_tinystories_config - -Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast - - -@pytest.fixture(scope="module") -def tinystories_model() -> SAETransformer: - tlens_model = load_tlens_model( - tlens_model_name="roneneldan/TinyStories-1M", tlens_model_path=None - ) - sae_position = "blocks.2.hook_resid_post" - config = get_tinystories_config({"saes": {"sae_positions": sae_position}}) - model = SAETransformer( - tlens_model=tlens_model, - raw_sae_positions=[sae_position], - dict_size_to_input_ratio=config.saes.dict_size_to_input_ratio, - ) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - return model - - -def test_compute_feature_acts(tinystories_model: SAETransformer): - set_seed(0) - prompt = "Once upon" - tokenizer = tinystories_model.tlens_model.tokenizer - assert tokenizer is not None - tokens = tokenizer(prompt, return_tensors="pt")["input_ids"] - assert isinstance(tokens, torch.Tensor) - 
feature_indices = {name: list(range(7)) for name in tinystories_model.raw_sae_positions} - feature_acts, final_resid_acts = compute_feature_acts( - tinystories_model, tokens, feature_indices=feature_indices - ) - for sae_name, acts in feature_acts.items(): - assert acts.shape[0] == 1 # batch size - assert acts.shape[2] == 7 # feature_indices - - -def test_create_vocab_dict(tinystories_model: SAETransformer): - tokenizer = tinystories_model.tlens_model.tokenizer - assert tokenizer is not None - vocab_dict = create_vocab_dict(tokenizer) - assert isinstance(tokenizer, PreTrainedTokenizerFast) - assert len(vocab_dict) == len(tokenizer.vocab) - for token_id, token_str in vocab_dict.items(): - assert isinstance(token_id, int) - assert isinstance(token_str, str) - - -def check_valid_feature_dashboard_htmls(folder: Path): - assert folder.exists() - for html_file in folder.iterdir(): - assert html_file.name.endswith(".html") - assert html_file.exists() - with open(html_file) as f: - html_content = f.read() - assert isinstance(html_content, str) - assert len(html_content) > 100 - assert "Plotly.newPlot('histogram-acts'" in html_content - assert '
'" in html_content - assert "Feature #" in html_content - assert '