From 7cbc9326ecc83c66efbcde0bbbda009ebdc94d22 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 29 Apr 2024 11:50:33 +0100 Subject: [PATCH 01/12] Make autointerp script --- e2e_sae/scripts/autointerp.py | 530 +++++++++++ e2e_sae/scripts/generate_dashboards.py | 1214 ------------------------ pyproject.toml | 8 +- tests/test_dashboards.py | 124 --- 4 files changed, 535 insertions(+), 1341 deletions(-) create mode 100644 e2e_sae/scripts/autointerp.py delete mode 100644 e2e_sae/scripts/generate_dashboards.py delete mode 100644 tests/test_dashboards.py diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py new file mode 100644 index 0000000..2cb3866 --- /dev/null +++ b/e2e_sae/scripts/autointerp.py @@ -0,0 +1,530 @@ +import asyncio +import glob +import json +import os +import random +from datetime import datetime +from typing import Any, Literal, TypeVar + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import requests +import seaborn as sns +from dotenv import load_dotenv + +load_dotenv() # Required before importing UncalibratedNeuronSimulator +from neuron_explainer.activations.activation_records import calculate_max_activation +from neuron_explainer.activations.activations import ActivationRecord +from neuron_explainer.explanations.calibrated_simulator import UncalibratedNeuronSimulator +from neuron_explainer.explanations.explainer import ContextSize, TokenActivationPairExplainer +from neuron_explainer.explanations.explanations import ScoredSimulation +from neuron_explainer.explanations.few_shot_examples import FewShotExampleSet +from neuron_explainer.explanations.prompt_builder import PromptFormat +from neuron_explainer.explanations.scoring import ( + _simulate_and_score_sequence, + aggregate_scored_sequence_simulations, +) +from neuron_explainer.explanations.simulator import ( + LogprobFreeExplanationTokenSimulator, + NeuronSimulator, +) +from tenacity import retry, stop_after_attempt, wait_random_exponential +from tqdm import tqdm + +NEURONPEDIA_DOMAIN = "https://neuronpedia.org" +POSITIVE_INF_REPLACEMENT = 9999 +NEGATIVE_INF_REPLACEMENT = -9999 +NAN_REPLACEMENT = 0 +OTHER_INVALID_REPLACEMENT = -99999 + + +def NanAndInfReplacer(value: str): + """Replace NaNs and Infs in outputs.""" + replacements = { + "-Infinity": NEGATIVE_INF_REPLACEMENT, + "Infinity": POSITIVE_INF_REPLACEMENT, + "NaN": NAN_REPLACEMENT, + } + if value in replacements: + replaced_value = replacements[value] + return float(replaced_value) + else: + return NAN_REPLACEMENT + + +def get_neuronpedia_feature( + feature: int, layer: int, model: str = "gpt2-small", dataset: str = "res-jb" +) -> dict[str, Any]: + """Fetch a feature from Neuronpedia API.""" + url = f"{NEURONPEDIA_DOMAIN}/api/feature/{model}/{layer}-{dataset}/{feature}" + result = requests.get(url).json() + result["index"] = int(result["index"]) + return result + + +def test_key(api_key: str): + """Test the validity of the Neuronpedia API key.""" + url = f"{NEURONPEDIA_DOMAIN}/api/test" + body = {"apiKey": api_key} + response = requests.post(url, json=body) + if response.status_code != 200: + raise Exception("Neuronpedia API key is not valid.") + + +class NeuronpediaActivation: + """Represents an activation from Neuronpedia.""" + + def __init__(self, id: str, tokens: list[str], act_values: list[float]): + self.id = id + self.tokens = tokens + self.act_values = act_values + + +class NeuronpediaFeature: + """Represents a feature from Neuronpedia.""" + + def __init__( + self, + modelId: str, + layer: int, + dataset: str, 
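+        # `layer` and `dataset` are combined as "{layer}-{dataset}" to form the
+        # Neuronpedia SAE identifier used in API calls (see get_neuronpedia_feature).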
+ feature: int, + description: str = "", + activations: list[NeuronpediaActivation] | None = None, + autointerp_explanation: str = "", + autointerp_explanation_score: float = 0.0, + ): + self.modelId = modelId + self.layer = layer + self.dataset = dataset + self.feature = feature + self.description = description + self.activations = activations or [] + self.autointerp_explanation = autointerp_explanation + self.autointerp_explanation_score = autointerp_explanation_score + + def has_activating_text(self) -> bool: + """Check if the feature has activating text.""" + return any(max(activation.act_values) > 0 for activation in self.activations) + + +T = TypeVar("T") + + +@retry(wait=wait_random_exponential(min=1, max=500), stop=stop_after_attempt(10)) +def sleep_identity(x: T) -> T: + """Dummy function for retrying.""" + return x + + +@retry(wait=wait_random_exponential(min=1, max=500), stop=stop_after_attempt(10)) +async def simulate_and_score( + simulator: NeuronSimulator, activation_records: list[ActivationRecord] +) -> ScoredSimulation: + """Score an explanation of a neuron by how well it predicts activations on the given text sequences.""" + scored_sequence_simulations = await asyncio.gather( + *[ + sleep_identity( + _simulate_and_score_sequence( + simulator, + activation_record, + ) + ) + for activation_record in activation_records + ] + ) + return aggregate_scored_sequence_simulations(scored_sequence_simulations) + + +async def autointerp_neuronpedia_features( + features: list[NeuronpediaFeature], + openai_api_key: str, + autointerp_retry_attempts: int = 3, + autointerp_score_max_concurrent: int = 20, + neuronpedia_api_key: str = "", + do_score: bool = True, + output_dir: str = "neuronpedia_outputs/autointerp", + num_activations_to_use: int = 20, + max_explanation_activation_records: int = 20, + upload_to_neuronpedia: bool = True, + autointerp_explainer_model_name: Literal[ + "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview" + ] = "gpt-4-1106-preview", + autointerp_scorer_model_name: Literal[ + "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview" + ] = "gpt-3.5-turbo", +): + """ + Autointerp Neuronpedia features. + + Args: + features: List of NeuronpediaFeature objects. + openai_api_key: OpenAI API key. + autointerp_retry_attempts: Number of retry attempts for autointerp. + autointerp_score_max_concurrent: Maximum number of concurrent requests for autointerp scoring. + neuronpedia_api_key: Neuronpedia API key. + do_score: Whether to score the features. + output_dir: Output directory for saving the results. + num_activations_to_use: Number of activations to use. + max_explanation_activation_records: Maximum number of activation records for explanation. + upload_to_neuronpedia: Whether to upload the results to Neuronpedia. + autointerp_explainer_model_name: Model name for autointerp explainer. + autointerp_scorer_model_name: Model name for autointerp scorer. + + Returns: + None + """ + print("\n\n") + + os.environ["OPENAI_API_KEY"] = openai_api_key + + if upload_to_neuronpedia and not neuronpedia_api_key: + raise Exception( + "You need to provide a Neuronpedia API key to upload the results to Neuronpedia." 
+ ) + + test_key(neuronpedia_api_key) + + print("\n\n=== Step 1) Fetching features from Neuronpedia") + for feature in features: + feature_data = get_neuronpedia_feature( + feature=feature.feature, + layer=feature.layer, + model=feature.modelId, + dataset=feature.dataset, + ) + + if "modelId" not in feature_data: + raise Exception( + f"Feature {feature.feature} in layer {feature.layer} of model {feature.modelId} and dataset {feature.dataset} does not exist." + ) + + if "activations" not in feature_data or len(feature_data["activations"]) == 0: + raise Exception( + f"Feature {feature.feature} in layer {feature.layer} of model {feature.modelId} and dataset {feature.dataset} does not have activations." + ) + + activations = feature_data["activations"] + activations_to_add = [] + for activation in activations: + if len(activations_to_add) < num_activations_to_use: + activations_to_add.append( + NeuronpediaActivation( + id=activation["id"], + tokens=activation["tokens"], + act_values=activation["values"], + ) + ) + feature.activations = activations_to_add + + if not feature.has_activating_text(): + raise Exception( + f"Feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature} appears dead - it does not have activating text." + ) + + for iteration_num, feature in enumerate(features): + start_time = datetime.now() + + print( + f"\n========== Feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature} ({iteration_num + 1} of {len(features)} Features) ==========" + ) + print( + f"\n=== Step 2) Explaining feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}" + ) + + activation_records = [ + ActivationRecord(tokens=activation.tokens, activations=activation.act_values) + for activation in feature.activations + ] + + activation_records_explaining = activation_records[:max_explanation_activation_records] + + explainer = TokenActivationPairExplainer( + model_name=autointerp_explainer_model_name, + prompt_format=PromptFormat.HARMONY_V4, + context_size=ContextSize.SIXTEEN_K, + max_concurrent=1, + ) + + explanations = [] + for _ in range(autointerp_retry_attempts): + try: + explanations = await explainer.generate_explanations( + all_activation_records=activation_records_explaining, + max_activation=calculate_max_activation(activation_records_explaining), + num_samples=1, + ) + except Exception as e: + print(f"ERROR, RETRYING: {e}") + else: + break + else: + print( + f"ERROR: Failed to explain feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}" + ) + + assert len(explanations) == 1 + explanation = explanations[0].rstrip(".") + print(f"===== {autointerp_explainer_model_name}'s explanation: {explanation}") + feature.autointerp_explanation = explanation + + if do_score: + print( + f"\n=== Step 3) Scoring feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}" + ) + print("=== This can take up to 30 seconds.") + + temp_activation_records = [ + ActivationRecord( + tokens=[ + token.replace("<|endoftext|>", "<|not_endoftext|>") + .replace(" 55", "_55") + .encode("ascii", errors="backslashreplace") + .decode("ascii") + for token in activation_record.tokens + ], + activations=activation_record.activations, + ) + for activation_record in activation_records + ] + + score = None + scored_simulation = None + for _ in range(autointerp_retry_attempts): + try: + simulator = UncalibratedNeuronSimulator( + LogprobFreeExplanationTokenSimulator( + autointerp_scorer_model_name, + explanation, + json_mode=True, + 
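+                        # json_mode is forwarded to the (logprob-free) simulator; assumed
+                        # here to request JSON-structured completions from the scorer model.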
max_concurrent=autointerp_score_max_concurrent, + few_shot_example_set=FewShotExampleSet.JL_FINE_TUNED, + prompt_format=PromptFormat.HARMONY_V4, + ) + ) + scored_simulation = await simulate_and_score(simulator, temp_activation_records) + score = scored_simulation.get_preferred_score() + except Exception as e: + print(f"ERROR, RETRYING: {e}") + else: + break + + if ( + score is None + or scored_simulation is None + or len(scored_simulation.scored_sequence_simulations) != num_activations_to_use + ): + print( + f"ERROR: Failed to score feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}. Skipping it." + ) + continue + feature.autointerp_explanation_score = score + print(f"===== {autointerp_scorer_model_name}'s score: {(score * 100):.0f}") + + output_file = ( + f"{output_dir}/{feature.layer}-{feature.dataset}_feature-{feature.feature}_score-" + f"{feature.autointerp_explanation_score}_time-{datetime.now().strftime('%Y%m%d-%H%M%S')}.jsonl" + ) + os.makedirs(output_dir, exist_ok=True) + print(f"===== Your results will be saved to: {output_file} =====") + + output_data = json.dumps( + { + "apiKey": neuronpedia_api_key, + "feature": { + "modelId": feature.modelId, + "layer": f"{feature.layer}-{feature.dataset}", + "index": feature.feature, + "activations": feature.activations, + "explanation": feature.autointerp_explanation, + "explanationScore": feature.autointerp_explanation_score, + "explanationModel": autointerp_explainer_model_name, + "autointerpModel": autointerp_scorer_model_name, + "simulatedActivations": scored_simulation.scored_sequence_simulations, + }, + }, + default=vars, + ) + output_data_json = json.loads(output_data, parse_constant=NanAndInfReplacer) + output_data_str = json.dumps(output_data) + + print(f"\n=== Step 4) Saving feature to {output_file}") + with open(output_file, "a") as f: + f.write(output_data_str) + f.write("\n") + + if upload_to_neuronpedia: + print( + f"\n=== Step 5) Uploading feature to Neuronpedia: {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}" + ) + url = f"{NEURONPEDIA_DOMAIN}/api/upload-explanation" + body = output_data_json + response = requests.post(url, json=body) + if response.status_code != 200: + print(f"ERROR: Couldn't upload explanation to Neuronpedia: {response.text}") + + end_time = datetime.now() + print(f"\n========== Time Spent for Feature: {end_time - start_time}\n") + + print("\n\n========== Generation and Upload Complete ==========\n\n") + + +def run_autointerp( + sae_sets: list[list[str]], + n_random_features: int, + dict_size: int, + feature_model_id: str, + autointerp_explainer_model_name: Literal["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"], + autointerp_scorer_model_name: Literal["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"], +): + """Explain and score random features across SAEs and upload to Neuronpedia. + + Args: + sae_sets: List of SAE sets. + n_random_features: Number of random features to explain and score for each SAE. + dict_size: Size of the SAE dictionary. + feature_model_id: Model ID that the SAE is attached to. + autointerp_explainer_model_name: Model name for autointerp explainer. + autointerp_scorer_model_name: Model name for autointerp scorer (much more expensive). 
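+
+    Example (illustrative; mirrors the __main__ block at the bottom of this file, and
+    requires OPENAI_API_KEY and NEURONPEDIA_API_KEY in the environment, loaded via dotenv):
+        run_autointerp(
+            sae_sets=[["6-res_sll-ajt", "6-res_scl-ajt", "6-res_scefr-ajt"]],
+            n_random_features=2,
+            dict_size=768 * 60,
+            feature_model_id="gpt2-small",
+            autointerp_explainer_model_name="gpt-4-1106-preview",
+            autointerp_scorer_model_name="gpt-3.5-turbo",
+        )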
+ """ + for sae_set in tqdm(sae_sets, desc="sae sets"): + for i in tqdm(range(n_random_features), desc="random features"): + for sae in tqdm(sae_set, desc="sae"): + layer = int(sae.split("-")[0]) + dataset = "-".join(sae.split("-")[1:]) + feature_exists = False + while not feature_exists: + feature = random.randint(0, dict_size) + feature_data = get_neuronpedia_feature( + feature=feature, + layer=layer, + model=feature_model_id, + dataset=dataset, + ) + if "activations" in feature_data and len(feature_data["activations"]) >= 20: + feature_exists = True + print(f"sae: {sae}, feature: {feature}") + + features = [ + NeuronpediaFeature( + modelId=feature_model_id, + layer=layer, + dataset=dataset, + feature=feature, + ), + ] + + asyncio.run( + autointerp_neuronpedia_features( + features=features, + openai_api_key=os.getenv("OPENAI_API_KEY", ""), + neuronpedia_api_key=os.getenv("NEURONPEDIA_API_KEY", ""), + autointerp_explainer_model_name=autointerp_explainer_model_name, + autointerp_scorer_model_name=autointerp_scorer_model_name, + num_activations_to_use=20, + max_explanation_activation_records=5, + ) + ) + + feature_url = f"https://neuronpedia.org/{feature_model_id}/{layer}-{dataset}/{feature}" + print(f"Your feature is at: {feature_url}") + + +def compare_autointerp_results(output_dir: str): + """ + Compare autointerp results across SAEs. + + Args: + output_dir: Directory containing the autointerp output files. + + Returns: + None + """ + autointerp_files = glob.glob(f"{output_dir}/*_feature-*_score-*") + stats = { + "layer": [], + "sae": [], + "feature": [], + "explanationScore": [], + "explanationModel": [], + "autointerpModel": [], + "explanation": [], + } + for autointerp_file in tqdm(autointerp_files, "constructing stats from json autointerp files"): + with open(autointerp_file) as f: + json_data = json.loads(json.load(f)) + if "feature" in json_data: + json_data = json_data["feature"] + stats["layer"].append(int(json_data["layer"].split("-")[0])) + stats["sae"].append("-".join(json_data["layer"].split("-")[1:])) + stats["feature"].append(json_data["index"]) + stats["explanationScore"].append(json_data["explanationScore"]) + stats["autointerpModel"].append(json_data["autointerpModel"]) + stats["explanationModel"].append("gpt-4-1106-preview") + stats["explanation"].append(json_data["explanation"]) + df_stats = pd.DataFrame(stats) + + sns.set_theme(style="whitegrid", palette="pastel") + g = sns.catplot( + data=df_stats, + x="sae", + y="explanationScore", + kind="box", + palette="pastel", + hue="layer", + showmeans=True, + meanprops={ + "marker": "o", + "markerfacecolor": "white", + "markeredgecolor": "black", + "markersize": "10", + }, + ) + sns.swarmplot( + data=df_stats, x="sae", y="explanationScore", color="k", size=4, ax=g.ax, legend=False + ) + plt.title("Quality of auto-interpretability explanations across SAEs") + plt.ylabel("Auto-interpretability score") + plt.xlabel("SAE") + plt.savefig("Auto-interpretability_results.png", bbox_inches="tight") + plt.savefig("Auto-interpretability_results.pdf", bbox_inches="tight") + plt.show() + + sns.set_theme(rc={"figure.figsize": (6, 6)}) + b = sns.barplot( + df_stats, + x="sae", + y="explanationScore", + palette="pastel", + hue="layer", + capsize=0.3, + legend=False, + ) + s = sns.swarmplot(data=df_stats, x="sae", y="explanationScore", color="k", alpha=0.25, size=6) + plt.title("Quality of auto-interpretability explanations across SAEs") + plt.ylabel("Auto-interpretability score") + plt.xlabel("SAE") + plt.yticks(np.arange(0, 1, 0.1)) + 
plt.savefig("Auto-interpretability_results_bar.png", bbox_inches="tight") + plt.savefig("Auto-interpretability_results_bar.pdf", bbox_inches="tight") + plt.show() + + +if __name__ == "__main__": + sae_sets = [ + ["6-res_sll-ajt", "6-res_scl-ajt", "6-res_scefr-ajt"], + ["10-res_scefr-ajt", "10-res_scl-ajt", "10-res_sll-ajt"], + ["2-res_scefr-ajt", "2-res_scl-ajt"], + ["2-res_slefr-ajt", "2-res_sll-ajt"], + ] + run_autointerp( + sae_sets=sae_sets, + n_random_features=2, + dict_size=768 * 60, + feature_model_id="gpt2-small", + autointerp_explainer_model_name="gpt-4-1106-preview", + autointerp_scorer_model_name="gpt-3.5-turbo", + ) + + compare_autointerp_results("neuronpedia_outputs/autointerp") diff --git a/e2e_sae/scripts/generate_dashboards.py b/e2e_sae/scripts/generate_dashboards.py deleted file mode 100644 index 14b95b3..0000000 --- a/e2e_sae/scripts/generate_dashboards.py +++ /dev/null @@ -1,1214 +0,0 @@ -"""Script for generating HTML feature dashboards -Usage: - $ python generate_dashboards.py - (Generates dashboards for the SAEs in ) - or - $ python generate_dashboards.py - (Requires that a path to the sae.pt file is provided in dashboards_config.pretrained_sae_paths) - -dashboard HTML files be saved in dashboards_config.save_dir - -Two types of dashboards can be created: - feature-centric: - These are individual dashboards for each feature specified in - dashboards_config.feature_indices, showing that feature's max activating examples, facts - about the distribution of when it is active, and what it promotes through the logit lens. - feature-centric dashboards will also be generated for all of the top features which apppear - in the prompt-centric dashboards. Saved in dashboards_config.save_dir/dashboards_{sae_name} - prompt-centric: - Given a prompt and a specific token position within it, find the most important features - active at that position, and make a dashboard showing where they all activate. There are - three ways of measuring the importance of features: "act_size" (show the features which - activated most strongly), "act_quantile" (show the features which activated much more than - they usually do), and "loss_effect" (show the features with the biggest logit-lens ablation - effect for predicting the correct next token - default). 
- Saved in dashboards_config.save_dir/prompt_dashboards - -This script currently relies on an old commit of Callum McDouglal's sae_vis package: -https://github.com/callummcdougall/sae_vis/commit/b28a0f7c7e936f4bea05528d952dfcd438533cce -""" -import math -from collections.abc import Iterable -from pathlib import Path -from typing import Annotated, Literal - -import fire -import numpy as np -import torch -from eindex import eindex -from einops import einsum, rearrange -from jaxtyping import Float, Int -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, NonNegativeInt, PositiveInt -from sae_vis.data_fetching_fns import get_sequences_data -from sae_vis.data_storing_fns import ( - FeatureData, - FeatureVisParams, - HistogramData, - MiddlePlotsData, - MultiFeatureData, - MultiPromptData, - PromptData, - SequenceData, - SequenceMultiGroupData, -) -from sae_vis.utils_fns import QuantileCalculator, TopK, process_str_tok -from torch import Tensor -from torch.utils.data.dataset import IterableDataset -from tqdm import tqdm -from transformers import ( - AutoTokenizer, - PreTrainedTokenizer, - PreTrainedTokenizerBase, - PreTrainedTokenizerFast, -) - -from e2e_sae.data import DatasetConfig, create_data_loader -from e2e_sae.loader import load_pretrained_saes, load_tlens_model -from e2e_sae.log import logger -from e2e_sae.models.transformers import SAETransformer -from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import Config -from e2e_sae.types import RootPath -from e2e_sae.utils import filter_names, load_config, to_numpy - -FeatureIndicesType = dict[str, list[int]] | dict[str, Int[Tensor, "some_feats"]] # noqa: F821 (jaxtyping/pyright doesn't like single dimensions) -StrScoreType = Literal["act_size", "act_quantile", "loss_effect"] - - -class PromptDashboardsConfig(BaseModel): - model_config = ConfigDict(extra="forbid", frozen=True) - n_random_prompt_dashboards: NonNegativeInt = Field( - default=50, - description="The number of random prompts to generate prompt-centric dashboards for." - "A feature-centric dashboard will be generated for random token positions in each prompt.", - ) - data: DatasetConfig | None = Field( - default=None, - description="DatasetConfig for getting random prompts." - "If None, then DashboardsConfig.data will be used", - ) - prompts: list[str] | None = Field( - default=None, - description="Specific prompts on which to generate prompt-centric feature dashboards. " - "A feature-centric dashboard will be generated for every token position in each prompt.", - ) - str_score: StrScoreType = Field( - default="loss_effect", - description="The ordering metric for which features are most important in prompt-centric " - "dashboards. 
Can be one of 'act_size', 'act_quantile', or 'loss_effect'", - ) - num_top_features: PositiveInt = Field( - default=10, - description="How many of the most relevant features to show for each prompt" - " in the prompt-centric dashboards", - ) - - -class DashboardsConfig(BaseModel): - model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True, frozen=True) - pretrained_sae_paths: Annotated[ - list[RootPath] | None, BeforeValidator(lambda x: [x] if isinstance(x, str | Path) else x) - ] = Field(None, description="Paths of the pretrained SAEs to load") - sae_config_path: RootPath | None = Field( - default=None, - description="Path to the config file used to train the SAEs" - " (if null, we'll assume it's at pretrained_sae_paths[0].parent / 'config.yaml')", - ) - n_samples: PositiveInt | None = None - batch_size: PositiveInt - minibatch_size_features: PositiveInt | None = Field( - default=256, - description="Num features in each batch of calculations (i.e. we break up the features to " - "avoid OOM errors).", - ) - data: DatasetConfig = Field( - description="DatasetConfig for the data which will be used to generate the dashboards", - ) - save_dir: RootPath | None = Field( - default=None, - description="The directory for saving the HTML feature dashboard files", - ) - sae_positions: Annotated[ - list[str] | None, BeforeValidator(lambda x: [x] if isinstance(x, str) else x) - ] = Field( - None, - description="The names of the SAE positions to generate dashboards for. " - "e.g. 'blocks.2.hook_resid_post'. If None, then all positions will be generated", - ) - feature_indices: FeatureIndicesType | list[int] | None = Field( - default=None, - description="The features for which to generate dashboards on each SAE. If none, then " - "we'll generate dashbaords for every feature.", - ) - prompt_centric: PromptDashboardsConfig | None = Field( - default=None, - description="Used to generate prompt-centric (rather than feature-centric) dashboards." - " Feature-centric dashboards will also be generated for every feature appaearing in these", - ) - seed: NonNegativeInt = 0 - - -def compute_feature_acts( - model: SAETransformer, - tokens: Int[Tensor, "batch pos"], - raw_sae_positions: list[str] | None = None, - feature_indices: FeatureIndicesType | None = None, - stop_at_layer: int = -1, -) -> tuple[dict[str, Float[Tensor, "... some_feats"]], Float[Tensor, "... dim"]]: - """Compute the activations of the SAEs in the model given a tensor of input tokens - - Args: - model: The SAETransformer containing the SAEs and the tlens_model - tokens: The inputs to the tlens_model - raw_sae_positions: The names of the SAEs we're interested in - feature_indices: The indices of the features we're interested in for each SAE - stop_at_layer: Where to stop the forward pass. final_resid_acts will be returned from here - - Returns: - - A dict of feature activations for each SAE. - feature_acts[sae_position_name] = the feature activations of that SAE - shape: batch pos some_feats - - The residual stream activations of the model at the final layer (or at stop_at_layer) - """ - if raw_sae_positions is None: - raw_sae_positions = model.raw_sae_positions - # Run model without SAEs - final_resid_acts, orig_acts = model.tlens_model.run_with_cache( - tokens, - names_filter=raw_sae_positions, - return_cache_object=False, - stop_at_layer=stop_at_layer, - ) - assert isinstance(final_resid_acts, Tensor) - feature_acts: dict[str, Float[Tensor, "... 
some_feats"]] = {} - # Run the activations through the SAEs - for hook_name in orig_acts: - sae = model.saes[hook_name.replace(".", "-")] - output, feature_acts[hook_name] = sae(orig_acts[hook_name]) - del output - feature_acts[hook_name] = feature_acts[hook_name].to("cpu") - if feature_indices is not None: - feature_acts[hook_name] = feature_acts[hook_name][..., feature_indices[hook_name]] - return feature_acts, final_resid_acts - - -def compute_feature_acts_on_distribution( - model: SAETransformer, - dataset_config: DatasetConfig, - batch_size: PositiveInt, - n_samples: PositiveInt | None = None, - raw_sae_positions: list[str] | None = None, - feature_indices: FeatureIndicesType | None = None, - stop_at_layer: int = -1, -) -> tuple[ - dict[str, Float[Tensor, "... some_feats"]], Float[Tensor, "... d_resid"], Int[Tensor, "..."] -]: - """Compute the activations of the SAEs in the model on a dataset of input tokens - - Args: - model: The SAETransformer containing the SAEs and the tlens_model - dataset_config: The DatasetConfig used to get the data loader for the tokens. - batch_size: The batch size of data run through the model when calculating the feature acts - n_samples: The number of batches of data to use for calculating the feature dashboard data - raw_sae_positions: The names of the SAEs we're interested in. If none, do all SAEs. - feature_indices: The indices of the features we're interested in for each SAE. If none, do - all features. - stop_at_layer: Where to stop the forward pass. final_resid_acts will be returned from here - - Returns: - - a dict of SAE inputs, activations, and outputs for each SAE. - feature_acts[sae_position_name] = the feature activations of that SAE - shape: batch pos feats (or # feature_indices) - - The residual stream activations of the model at the final layer (or at stop_at_layer) - - The tokens used as input to the model - """ - data_loader, _ = create_data_loader( - dataset_config, batch_size=batch_size, buffer_size=batch_size - ) - if raw_sae_positions is None: - raw_sae_positions = model.raw_sae_positions - assert raw_sae_positions is not None - device = model.saes[raw_sae_positions[0].replace(".", "-")].device - if n_samples is None: - # If streaming (i.e. if the dataset is an IterableDataset), we don't know the length - n_batches = None if isinstance(data_loader.dataset, IterableDataset) else len(data_loader) - else: - n_batches = math.ceil(n_samples / batch_size) - if not isinstance(data_loader.dataset, IterableDataset): - n_batches = min(n_batches, len(data_loader)) - - total_samples = 0 - feature_acts_lists: dict[str, list[Float[Tensor, "... some_feats"]]] = { - sae_name: [] for sae_name in raw_sae_positions - } - final_resid_acts_list: list[Float[Tensor, "... d_resid"]] = [] - tokens_list: list[Int[Tensor, "..."]] = [] - for batch in tqdm(data_loader, total=n_batches, desc="Computing feature acts"): - batch_tokens: Int[Tensor, "..."] = batch[dataset_config.column_name].to(device=device) - batch_feature_acts, batch_final_resid_acts = compute_feature_acts( - model=model, - tokens=batch_tokens, - raw_sae_positions=raw_sae_positions, - feature_indices=feature_indices, - stop_at_layer=stop_at_layer, - ) - for sae_name in raw_sae_positions: - feature_acts_lists[sae_name].append(batch_feature_acts[sae_name]) - final_resid_acts_list.append(batch_final_resid_acts) - tokens_list.append(batch_tokens) - total_samples += batch_tokens.shape[0] - if n_samples is not None and total_samples > n_samples: - break - final_resid_acts: Float[Tensor, "... 
d_resid"] = torch.cat(final_resid_acts_list, dim=0) - tokens: Int[Tensor, "..."] = torch.cat(tokens_list, dim=0) - feature_acts: dict[str, Float[Tensor, "... some_feats"]] = {} - for sae_name in raw_sae_positions: - feature_acts[sae_name] = torch.cat(tensors=feature_acts_lists[sae_name], dim=0) - return feature_acts, final_resid_acts, tokens - - -def create_vocab_dict(tokenizer: PreTrainedTokenizerBase) -> dict[int, str]: - """ - Creates a vocab dict suitable for dashboards by replacing all the special tokens with their - HTML representations. This function is adapted from sae_vis.create_vocab_dict() - """ - vocab_dict: dict[str, int] = tokenizer.get_vocab() - vocab_dict_processed: dict[int, str] = {v: process_str_tok(k) for k, v in vocab_dict.items()} - return vocab_dict_processed - - -@torch.inference_mode() -def parse_activation_data( - tokens: Int[Tensor, "batch pos"], - feature_acts: Float[Tensor, "... some_feats"], - final_resid_acts: Float[Tensor, "... d_resid"], - feature_resid_dirs: Float[Tensor, "some_feats dim"], - feature_indices_list: Iterable[int], - W_U: Float[Tensor, "dim d_vocab"], - vocab_dict: dict[int, str], - fvp: FeatureVisParams, -) -> MultiFeatureData: - """Convert generic activation data into a MultiFeatureData object, which can be used to create - the feature-centric visualisation. - Adapted from sae_vis.data_fetching_fns._get_feature_data() - final_resid_acts + W_U are used for the logit lens. - - Args: - tokens: The inputs to the model - feature_acts: The activations values of the features - final_resid_acts: The activations of the final layer of the model - feature_resid_dirs: The directions that each feature writes to the logit output - feature_indices_list: The indices of the features we're interested in - W_U: The unembed weights for the logit lens - vocab_dict: A dictionary mapping vocab indices to strings - fvp: FeatureVisParams, containing a bunch of settings. See the FeatureVisParams docstring in - sae_vis for more information. - - Returns: - A MultiFeatureData containing data for creating each feature's visualization, - as well as data for rank-ordering the feature visualizations when it comes time - to make the prompt-centric view (the `feature_act_quantiles` attribute). - Use MultiFeatureData[feature_idx].get_html() to generate the HTML dashboard for a - particular feature (returns a string of HTML). - - """ - device = W_U.device - feature_acts.to(device) - sequence_data_dict: dict[int, SequenceMultiGroupData] = {} - middle_plots_data_dict: dict[int, MiddlePlotsData] = {} - features_data: dict[int, FeatureData] = {} - # Calculate all data for the right-hand visualisations, i.e. the sequences - for i, feat in enumerate(feature_indices_list): - # Add this feature's sequence data to the list - sequence_data_dict[feat] = get_sequences_data( - tokens=tokens, - feat_acts=feature_acts[..., i], - resid_post=final_resid_acts, - feature_resid_dir=feature_resid_dirs[i], - W_U=W_U, - fvp=fvp, - ) - - # Get the logits of all features (i.e. 
the directions this feature writes to the logit output) - logits = einsum( - feature_resid_dirs, - W_U, - "feats d_model, d_model d_vocab -> feats d_vocab", - ) - for i, (feat, logit) in enumerate(zip(feature_indices_list, logits, strict=True)): - # Get data for logits (the histogram, and the table) - logits_histogram_data = HistogramData(logit, 40, "5 ticks") - top10_logits = TopK(logit, k=15, largest=True) - bottom10_logits = TopK(logit, k=15, largest=False) - - # Get data for feature activations histogram (the title, and the histogram) - feat_acts = feature_acts[..., i] - nonzero_feat_acts = feat_acts[feat_acts > 0] - frac_nonzero = nonzero_feat_acts.numel() / feat_acts.numel() - freq_histogram_data = HistogramData(nonzero_feat_acts, 40, "ints") - - # Create a MiddlePlotsData object from this, and add it to the dict - middle_plots_data_dict[feat] = MiddlePlotsData( - bottom10_logits=bottom10_logits, - top10_logits=top10_logits, - logits_histogram_data=logits_histogram_data, - freq_histogram_data=freq_histogram_data, - frac_nonzero=frac_nonzero, - ) - - # Return the output, as a dict of FeatureData items - for i, feat in enumerate(feature_indices_list): - features_data[feat] = FeatureData( - # Data-containing inputs (for the feature-centric visualisation) - sequence_data=sequence_data_dict[feat], - middle_plots_data=middle_plots_data_dict[feat], - left_tables_data=None, - # Non data-containing inputs - feature_idx=feat, - vocab_dict=vocab_dict, - fvp=fvp, - ) - - # Also get the quantiles, which will be useful for the prompt-centric visualisation - - feature_act_quantiles = QuantileCalculator( - data=rearrange(feature_acts, "... feats -> feats (...)") - ) - return MultiFeatureData(features_data, feature_act_quantiles) - - -def feature_indices_to_tensordict( - feature_indices_in: FeatureIndicesType | list[int] | None, - raw_sae_positions: list[str], - model: SAETransformer, -) -> dict[str, Tensor]: - """ "Convert feature indices to a dict of tensor indices""" - if feature_indices_in is None: - feature_indices = {} - for sae_name in raw_sae_positions: - feature_indices[sae_name] = torch.arange( - end=model.saes[sae_name.replace(".", "-")].n_dict_components - ) - # Otherwise make sure that feature_indices is a dict of Int[Tensor] - elif not isinstance(feature_indices_in, dict): - feature_indices = { - sae_name: Tensor(feature_indices_in).to("cpu").to(torch.int) - for sae_name in raw_sae_positions - } - else: - feature_indices: dict[str, Tensor] = { - sae_name: Tensor(feature_indices_in[sae_name]).to("cpu").to(torch.int) - for sae_name in raw_sae_positions - } - return feature_indices - - -@torch.inference_mode() -def get_dashboards_data( - model: SAETransformer, - dataset_config: DatasetConfig | None = None, - tokens: Int[Tensor, "batch pos"] | None = None, - sae_positions: list[str] | None = None, - feature_indices: FeatureIndicesType | list[int] | None = None, - n_samples: PositiveInt | None = None, - batch_size: PositiveInt | None = None, - minibatch_size_features: PositiveInt | None = None, - fvp: FeatureVisParams | None = None, - vocab_dict: dict[int, str] | None = None, -) -> dict[str, MultiFeatureData]: - """Gets data that needed to create the sequences in the feature-centric HTML visualisation - - Adapted from sae_vis.data_fetching_fns._get_feature_data() - - Args: - model: - The model (with SAEs) we'll be using to get the feature activations. - dataset_config: [Only used if tokens is None] - The DatasetConfig which will be used to get the data loader. 
If None, then tokens must - be supplied. - tokens: - The tokens we'll be using to get the feature activations. If None, then we'll use the - distribution from the dataset_config. - sae_positions: - The names of the SAEs we want to calculate feature dashboards for, - eg. ['blocks.0.hook_resid_pre']. If none, then we'll do all of them. - feature_indices: - The features we're actually computing for each SAE. These might just be a subset of - each SAE's full features. If None, then we'll do all of them. - n_samples: [Only used if tokens is None] - The number of batches of data to use for calculating the feature dashboard data when - using dataset_config. - batch_size: [Only used if tokens is None] - The number of batches of data to use for calculating the feature dashboard data when - using dataset_config - minibatch_size_features: - Num features in each batch of calculations (break up the features to avoid OOM errors). - fvp: - Feature visualization parameters, containing a bunch of other stuff. See the - FeatureVisParams docstring in sae_vis for more information. - vocab_dict: - vocab dict suitable for dashboards with all the special tokens replaced with their - HTML representations. If None then it will be created using create_vocab_dict(tokenizer) - - Returns: - A dict of [sae_position_name: MultiFeatureData]. Each MultiFeatureData contains data for - creating each feature's visualization, as well as data for rank-ordering the feature - visualizations when it comes time to make the prompt-centric view - (the `feature_act_quantiles` attribute). - Use dashboards_data[sae_name][feature_idx].get_html() to generate the HTML - dashboard for a particular feature (returns a string of HTML) - """ - # Get the vocab dict, which we'll use at the end - if vocab_dict is None: - assert ( - model.tlens_model.tokenizer is not None - ), "If voacab_dict is not supplied, the model must have a tokenizer" - vocab_dict = create_vocab_dict(model.tlens_model.tokenizer) - - if fvp is None: - fvp = FeatureVisParams(include_left_tables=False) - - if sae_positions is None: - raw_sae_positions: list[str] = model.raw_sae_positions - else: - raw_sae_positions: list[str] = filter_names( - list(model.tlens_model.hook_dict.keys()), sae_positions - ) - # If we haven't supplied any feature indicies, assume that we want all of them - feature_indices_tensors = feature_indices_to_tensordict( - feature_indices_in=feature_indices, - raw_sae_positions=raw_sae_positions, - model=model, - ) - for sae_name in raw_sae_positions: - assert ( - feature_indices_tensors[sae_name].max().item() - < model.saes[sae_name.replace(".", "-")].n_dict_components - ), "Error: Some feature indices are greater than the number of SAE features" - - device = model.saes[raw_sae_positions[0].replace(".", "-")].device - # Get the SAE feature activations (as well as their resudual stream inputs and outputs) - if tokens is None: - assert dataset_config is not None, "If no tokens are supplied, then config must be supplied" - assert ( - batch_size is not None - ), "If no tokens are supplied, then a batch_size must be supplied" - feature_acts, final_resid_acts, tokens = compute_feature_acts_on_distribution( - model=model, - dataset_config=dataset_config, - batch_size=batch_size, - raw_sae_positions=raw_sae_positions, - feature_indices=feature_indices_tensors, - n_samples=n_samples, - ) - else: - tokens.to(device) - feature_acts, final_resid_acts = compute_feature_acts( - model=model, - tokens=tokens, - raw_sae_positions=raw_sae_positions, - 
feature_indices=feature_indices_tensors, - ) - - # Filter out the never active features: - for sae_name in raw_sae_positions: - acts_sum = einsum(feature_acts[sae_name], "... some_feats -> some_feats").to("cpu") - feature_acts[sae_name] = feature_acts[sae_name][..., acts_sum > 0] - feature_indices_tensors[sae_name] = feature_indices_tensors[sae_name][acts_sum > 0] - del acts_sum - - dashboards_data: dict[str, MultiFeatureData] = { - name: MultiFeatureData() for name in raw_sae_positions - } - - for sae_name in raw_sae_positions: - sae = model.saes[sae_name.replace(".", "-")] - W_dec: Float[Tensor, "feats dim"] = sae.decoder.weight.T - feature_resid_dirs: Float[Tensor, "some_feats dim"] = W_dec[ - feature_indices_tensors[sae_name] - ] - W_U = model.tlens_model.W_U - - # Break up the features into batches - if minibatch_size_features is None: - feature_acts_batches = [feature_acts[sae_name]] - feature_batches = [feature_indices_tensors[sae_name].tolist()] - feature_resid_dir_batches = [feature_resid_dirs] - else: - feature_acts_batches = feature_acts[sae_name].split(minibatch_size_features, dim=-1) - feature_batches = [ - x.tolist() for x in feature_indices_tensors[sae_name].split(minibatch_size_features) - ] - feature_resid_dir_batches = feature_resid_dirs.split(minibatch_size_features) - for i in tqdm(iterable=range(len(feature_batches)), desc="Parsing activation data"): - new_feature_data = parse_activation_data( - tokens=tokens, - feature_acts=feature_acts_batches[i].to_dense().to(device), - final_resid_acts=final_resid_acts, - feature_resid_dirs=feature_resid_dir_batches[i], - feature_indices_list=feature_batches[i], - W_U=W_U, - vocab_dict=vocab_dict, - fvp=fvp, - ) - dashboards_data[sae_name].update(new_feature_data) - - return dashboards_data - - -@torch.inference_mode() -def parse_prompt_data( - tokens: Int[Tensor, "batch pos"], - str_tokens: list[str], - features_data: MultiFeatureData, - feature_acts: Float[Tensor, "seq some_feats"], - final_resid_acts: Float[Tensor, "seq d_resid"], - feature_resid_dirs: Float[Tensor, "some_feats dim"], - feature_indices_list: list[int], - W_U: Float[Tensor, "dim d_vocab"], - num_top_features: int = 10, -) -> MultiPromptData: - """Gets data needed to create the sequences in the prompt-centric HTML visualisation. - - This visualization displays dashboards for the most relevant features on a prompt. - Adapted from sae_vis.data_fetching_fns.get_prompt_data(). - - Args: - tokens: The input prompt to the model as tokens - str_tokens: The input prompt to the model as a list of strings (one string per token) - features_data: A MultiFeatureData containing information required to plot the features. - feature_acts: The activations values of the features - final_resid_acts: The activations of the final layer of the model - feature_resid_dirs: The directions that each feature writes to the logit output - feature_indices_list: The indices of the features we're interested in - W_U: The unembed weights for the logit lens - num_top_features: The number of top features to display in this view, for any given metric. - Returns: - A MultiPromptData object containing data for visualizing the most relevant features - given the prompt. - - Similar to parse_feature_data, except it just gets the data relevant for a particular - sequence (i.e. a custom one that the user inputs on their own). 
- - The ordering metric for relevant features is set by the str_score parameter in the - MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect" - """ - torch.cuda.empty_cache() - device = W_U.device - n_feats = len(feature_indices_list) - batch, seq_len = tokens.shape - feats_contribution_to_loss = torch.empty(size=(n_feats, seq_len - 1), device=device) - - # Some logit computations which we only need to do once - correct_token_unembeddings = W_U[:, tokens[0, 1:]] # [d_model seq] - orig_logits = ( - final_resid_acts / final_resid_acts.std(dim=-1, keepdim=True) - ) @ W_U # [seq d_vocab] - - sequence_data_dict: dict[int, SequenceData] = {} - - for i, feat in enumerate(feature_indices_list): - # Calculate all data for the sequences - # (this is the only truly 'new' bit of calculation we need to do) - - # Get this feature's output vector, using an outer product over feature acts for all tokens - final_resid_acts_feature_effect = einsum( - feature_acts[..., i].to_dense().to(device), - feature_resid_dirs[i], - "seq, d_model -> seq d_model", - ) - - # Ablate the output vector from the residual stream, and get logits post-ablation - new_final_resid_acts = final_resid_acts - final_resid_acts_feature_effect - new_logits = (new_final_resid_acts / new_final_resid_acts.std(dim=-1, keepdim=True)) @ W_U - - # Get the top5 & bottom5 changes in logits - contribution_to_logprobs = orig_logits.log_softmax(dim=-1) - new_logits.log_softmax(dim=-1) - top5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5) - bottom5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5, largest=False) - - # Get the change in loss (which is negative of change of logprobs for correct token) - contribution_to_loss = eindex(-contribution_to_logprobs[:-1], tokens[0, 1:], "seq [seq]") - feats_contribution_to_loss[i, :] = contribution_to_loss - - # Store the sequence data - sequence_data_dict[feat] = SequenceData( - token_ids=tokens.squeeze(0).tolist(), - feat_acts=feature_acts[..., i].tolist(), - contribution_to_loss=[0.0] + contribution_to_loss.tolist(), - top5_token_ids=top5_contribution_to_logits.indices.tolist(), - top5_logit_contributions=top5_contribution_to_logits.values.tolist(), - bottom5_token_ids=bottom5_contribution_to_logits.indices.tolist(), - bottom5_logit_contributions=bottom5_contribution_to_logits.values.tolist(), - ) - - # Get the logits for the correct tokens - logits_for_correct_tokens = einsum( - feature_resid_dirs[i], correct_token_unembeddings, "d_model, d_model seq -> seq" - ) - - # Add the annotations data (feature activations and logit effect) to the histograms - freq_line_posn = feature_acts[..., i].tolist() - freq_line_text = [ - f"\\'{str_tok}\\'
{act:.3f}" - for str_tok, act in zip(str_tokens[1:], freq_line_posn, strict=False) - ] - middle_plots_data = features_data[feat].middle_plots_data - assert middle_plots_data is not None - middle_plots_data.freq_histogram_data.line_posn = freq_line_posn - middle_plots_data.freq_histogram_data.line_text = freq_line_text # type: ignore (due to typing bug in sae_vis) - logits_line_posn = logits_for_correct_tokens.tolist() - logits_line_text = [ - f"\\'{str_tok}\\'
{logits:.3f}" - for str_tok, logits in zip(str_tokens[1:], logits_line_posn, strict=False) - ] - middle_plots_data.logits_histogram_data.line_posn = logits_line_posn - middle_plots_data.logits_histogram_data.line_text = logits_line_text # type: ignore (due to typing bug in sae_vis) - - # Lastly, use the criteria (act size, act quantile, loss effect) to find top-scoring features - - # Construct a scores dict, which maps from things like ("act_quantile", seq_pos) - # to a list of the top-scoring features - scores_dict: dict[tuple[str, str], tuple[TopK, list[str]]] = {} - - for seq_pos in range(len(str_tokens)): - # Filter the feature activations, since we only need the ones that are non-zero - feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos] > 0) - feat_acts_nonzero_locations = np.nonzero(feat_acts_nonzero_filter)[0].tolist() - _feature_acts = ( - feature_acts[seq_pos, feat_acts_nonzero_filter].to_dense().to(device) - ) # [feats_filtered,] - _feature_indices_list = np.array(feature_indices_list)[feat_acts_nonzero_filter] - - if feat_acts_nonzero_filter.sum() > 0: - k = min(num_top_features, _feature_acts.numel()) - - # Get the "act_size" scores (we return it as a TopK object) - act_size_topk = TopK(_feature_acts, k=k, largest=True) - # Replace the indices with feature indices (these are different when - # feature_indices_list argument is not [0, 1, 2, ...]) - act_size_topk.indices[:] = _feature_indices_list[act_size_topk.indices] - scores_dict[("act_size", seq_pos)] = (act_size_topk, ".3f") # type: ignore (due to typing bug in sae_vis) - - # Get the "act_quantile" scores, which is just the fraction of cached feat acts that it - # is larger than - act_quantile, act_precision = features_data.feature_act_quantiles.get_quantile( - _feature_acts, feat_acts_nonzero_locations - ) - act_quantile_topk = TopK(act_quantile, k=k, largest=True) - act_formatting_topk = [f".{act_precision[i]-2}%" for i in act_quantile_topk.indices] - # Replace the indices with feature indices (these are different when - # feature_indices_list argument is not [0, 1, 2, ...]) - act_quantile_topk.indices[:] = _feature_indices_list[act_quantile_topk.indices] - scores_dict[("act_quantile", seq_pos)] = (act_quantile_topk, act_formatting_topk) # type: ignore (due to typing bug in sae_vis) - - # We don't measure loss effect on the first token - if seq_pos == 0: - continue - - # Filter the loss effects, since we only need the ones which have non-zero feature acts on - # the tokens before them - prev_feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos - 1] > 0) - _contribution_to_loss = feats_contribution_to_loss[ - prev_feat_acts_nonzero_filter, seq_pos - 1 - ] # [feats_filtered,] - _feature_indices_list_prev = np.array(feature_indices_list)[prev_feat_acts_nonzero_filter] - - if prev_feat_acts_nonzero_filter.sum() > 0: - k = min(num_top_features, _contribution_to_loss.numel()) - - # Get the "loss_effect" scores, which are just the min of features' contributions to - # loss (min because we're looking for helpful features, not harmful ones) - contribution_to_loss_topk = TopK(_contribution_to_loss, k=k, largest=False) - # Replace the indices with feature indices (these are different when - # feature_indices_list argument is not [0, 1, 2, ...]) - contribution_to_loss_topk.indices[:] = _feature_indices_list_prev[ - contribution_to_loss_topk.indices - ] - scores_dict[("loss_effect", seq_pos)] = (contribution_to_loss_topk, ".3f") # type: ignore (due to typing bug in sae_vis) - - # Get all the features which are required 
(i.e. all the sequence position indices) - feature_indices_list_required = set() - for score_topk, _ in scores_dict.values(): - feature_indices_list_required.update(set(score_topk.indices.tolist())) - prompt_data_dict = {} - for feat in feature_indices_list_required: - middle_plots_data = features_data[feat].middle_plots_data - assert middle_plots_data is not None - prompt_data_dict[feat] = PromptData( - prompt_data=sequence_data_dict[feat], - sequence_data=features_data[feat].sequence_data[0], - middle_plots_data=middle_plots_data, - ) - - return MultiPromptData( - prompt_str_toks=str_tokens, - prompt_data_dict=prompt_data_dict, - scores_dict=scores_dict, - ) - - -@torch.inference_mode() -def get_prompt_data( - model: SAETransformer, - tokens: Int[Tensor, "batch pos"], - str_tokens: list[str], - dashboards_data: dict[str, MultiFeatureData], - sae_positions: list[str] | None = None, - num_top_features: PositiveInt = 10, -) -> dict[str, MultiPromptData]: - """Gets data needed to create the sequences in the prompt-centric HTML visualisation. - - This visualization displays dashboards for the most relevant features on a prompt. - Adapted from sae_vis.data_fetching_fns.get_prompt_data() - - Args: - model: - The model (with SAEs) we'll be using to get the feature activations. - tokens: - The input prompt to the model as tokens - str_tokens: - The input prompt to the model as a list of strings (one string per token) - dashboards_data: - For each SAE, a MultiFeatureData containing information required to plot its features. - sae_positions: - The names of the SAEs we want to find relevant features in. - eg. ['blocks.0.hook_resid_pre']. If none, then we'll do all of them. - num_top_features: int - The number of top features to display in this view, for any given metric. - - Returns: - A dict of [sae_position_name: MultiPromptData]. Each MultiPromptData contains data for - visualizing the most relevant features in that SAE given the prompt. - Similar to get_feature_data, except it just gets the data relevant for a particular - sequence (i.e. a custom one that the user inputs on their own). 
- - The ordering metric for relevant features is set by the str_score parameter in the - MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect" - """ - assert tokens.shape[-1] == len( - str_tokens - ), "Error: the number of tokens does not equal the number of str_tokens" - if sae_positions is None: - raw_sae_positions: list[str] = model.raw_sae_positions - else: - raw_sae_positions: list[str] = filter_names( - list(model.tlens_model.hook_dict.keys()), sae_positions - ) - feature_indices: dict[str, list[int]] = {} - for sae_name in raw_sae_positions: - feature_indices[sae_name] = list(dashboards_data[sae_name].feature_data_dict.keys()) - - feature_acts, final_resid_acts = compute_feature_acts( - model=model, - tokens=tokens, - raw_sae_positions=raw_sae_positions, - feature_indices=feature_indices, - ) - final_resid_acts = final_resid_acts.squeeze(dim=0) - - prompt_data: dict[str, MultiPromptData] = {} - - for sae_name in raw_sae_positions: - sae = model.saes[sae_name.replace(".", "-")] - feature_act_dir: Float[Tensor, "dim some_feats"] = sae.encoder[0].weight.T[ - :, feature_indices[sae_name] - ] # [d_in feats] - feature_resid_dirs: Float[Tensor, "some_feats dim"] = sae.decoder.weight.T[ - feature_indices[sae_name] - ] # [feats d_in] - assert ( - feature_act_dir.T.shape - == feature_resid_dirs.shape - == (len(feature_indices[sae_name]), sae.input_size) - ) - - prompt_data[sae_name] = parse_prompt_data( - tokens=tokens, - str_tokens=str_tokens, - features_data=dashboards_data[sae_name], - feature_acts=feature_acts[sae_name].squeeze(dim=0), - final_resid_acts=final_resid_acts, - feature_resid_dirs=feature_resid_dirs, - feature_indices_list=feature_indices[sae_name], - W_U=model.tlens_model.W_U, - num_top_features=num_top_features, - ) - return prompt_data - - -@torch.inference_mode() -def generate_feature_dashboard_html_files( - dashboards_data: dict[str, MultiFeatureData], - feature_indices: FeatureIndicesType | dict[str, set[int]] | None, - save_dir: str | Path = "", -): - """Generates viewable HTML dashboards for every feature in every SAE in dashboards_data""" - if feature_indices is None: - feature_indices = {name: dashboards_data[name].keys() for name in dashboards_data} - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - for sae_name in feature_indices: - logger.info(f"Saving HTML feature dashboards for the SAE at {sae_name}:") - folder = save_dir / Path(f"dashboards_{sae_name}") - folder.mkdir(parents=True, exist_ok=True) - for feature_idx in tqdm(feature_indices[sae_name], desc="Dashboard HTML files"): - feature_idx = ( - int(feature_idx.item()) if isinstance(feature_idx, Tensor) else feature_idx - ) - if feature_idx in dashboards_data[sae_name].keys(): - html_str = dashboards_data[sae_name][feature_idx].get_html() - filepath = folder / Path(f"feature-{feature_idx}.html") - with open(filepath, "w") as f: - f.write(html_str) - logger.info(f"Saved HTML feature dashboards in {folder}") - - -@torch.inference_mode() -def generate_prompt_dashboard_html_files( - model: SAETransformer, - tokens: Int[Tensor, "batch pos"], - str_tokens: list[str], - dashboards_data: dict[str, MultiFeatureData], - seq_pos: int | list[int] | None = None, - vocab_dict: dict[int, str] | None = None, - str_score: StrScoreType = "loss_effect", - save_dir: str | Path = "", -) -> dict[str, set[int]]: - """Generates viewable HTML dashboards for the most relevant features (measured by str_score) for - every SAE in dashboards_data. 
- - Returns the set of feature indices which were active""" - assert tokens.shape[-1] == len( - str_tokens - ), "Error: the number of tokens does not equal the number of str_tokens" - str_tokens = [s.replace("Ġ", " ") for s in str_tokens] - if isinstance(seq_pos, int): - seq_pos = [seq_pos] - if seq_pos is None: # Generate a dashboard for every position if none is specified - seq_pos = list(range(2, len(str_tokens) - 2)) - if vocab_dict is None: - assert ( - model.tlens_model.tokenizer is not None - ), "If voacab_dict is not supplied, the model must have a tokenizer" - vocab_dict = create_vocab_dict(model.tlens_model.tokenizer) - prompt_data = get_prompt_data( - model=model, tokens=tokens, str_tokens=str_tokens, dashboards_data=dashboards_data - ) - prompt = "".join(str_tokens) - # Use the beginning of the prompt for the filename, but make sure that it's safe for a filename - str_tokens_safe_for_filenames = [ - "".join(c for c in token if c.isalpha() or c.isdigit() or c == " ") - .rstrip() - .replace(" ", "-") - for token in str_tokens - ] - filename_from_prompt = "".join(str_tokens_safe_for_filenames) - filename_from_prompt = filename_from_prompt[:50] - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data} - for sae_name in dashboards_data: - seq_pos_with_scores: set[int] = { - int(x[1]) - for x in prompt_data["blocks.1.hook_resid_post"].scores_dict - if x[0] == str_score - } - for seq_pos_i in seq_pos_with_scores.intersection(seq_pos): - # Find the most relevant features (by {str_score}) for the token - # '{str_tokens[seq_pos_i]}' in the prompt '{prompt}: - html_str = prompt_data[sae_name].get_html(seq_pos_i, str_score, vocab_dict) - # Insert a title - title: str = ( - f"

  The most relevant features from {sae_name},
  " - f"measured by {str_score} on the '{str_tokens[seq_pos_i].replace('Ġ',' ')}' " - f"token (token number {seq_pos_i}) in the prompt '{prompt}':

" - ) - substr = "
" - html_str = html_str.replace( - substr, "
" + title + "
\n" + substr - ) - filepath = save_dir / Path( - f"prompt-{filename_from_prompt}_token-{seq_pos_i}-" - f"{str_tokens_safe_for_filenames[seq_pos_i]}_-{str_score.replace('_','-')}_" - f"sae-{sae_name}.html" - ) - with open(filepath, "w") as f: - f.write(html_str) - scores = prompt_data[sae_name].scores_dict[(str_score, seq_pos_i)][0] # type: ignore - used_features[sae_name] = used_features[sae_name].union(scores.indices.tolist()) - return used_features - - -@torch.inference_mode() -def generate_random_prompt_dashboards( - model: SAETransformer, - dashboards_data: dict[str, MultiFeatureData], - dashboards_config: DashboardsConfig, - use_model_tokenizer: bool = False, - save_dir: RootPath | None = None, -) -> dict[str, set[int]]: - """Generates prompt-centric HTML dashboards for prompts from the training distribution. - - A data_loader is created using the dashboards_config.prompt_centric.data if it exists, - otherwise using the dashboards_config.data config. - For each random prompt, dashboards are generated for three consecutive sequence positions.""" - np.random.seed(dashboards_config.seed) - if save_dir is None: - save_dir = dashboards_config.save_dir - assert save_dir is not None, ( - "generate_random_prompt_dashboards() saves HTML files, but no save_dir was specified in" - + " the dashboards_config or given as input" - ) - assert dashboards_config.prompt_centric is not None, ( - "generate_random_prompt_dashboards() makes prompt-centric dashboards: " - + "the dashboards_config.prompt_centric config must exist" - ) - dataset_config = ( - dashboards_config.prompt_centric.data - if dashboards_config.prompt_centric.data - else dashboards_config.data - ) - data_loader, _ = create_data_loader(dataset_config=dataset_config, batch_size=1, buffer_size=1) - assert model.tlens_model.tokenizer is not None, "The model must have a tokenizer" - if use_model_tokenizer: - tokenizer = model.tlens_model.tokenizer - assert isinstance(tokenizer, PreTrainedTokenizer | PreTrainedTokenizerFast) - else: - tokenizer = AutoTokenizer.from_pretrained(dashboards_config.data.tokenizer_name) - vocab_dict = create_vocab_dict(tokenizer) - if dashboards_config.sae_positions is None: - raw_sae_positions: list[str] = model.raw_sae_positions - else: - raw_sae_positions: list[str] = filter_names( - list(model.tlens_model.hook_dict.keys()), dashboards_config.sae_positions - ) - - used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data} - device = model.saes[raw_sae_positions[0].replace(".", "-")].device - n_prompts = (dashboards_config.prompt_centric.n_random_prompt_dashboards + 2) // 3 - for prompt_idx, batch in tqdm( - enumerate(data_loader), - total=n_prompts, - desc="Random prompt dashboards", - ): - batch_tokens: Int[Tensor, "1 pos"] = batch[dashboards_config.data.column_name].to( - device=device - ) - assert len(batch_tokens.shape) == 2 and batch_tokens.shape[0] == 1 - # Use the tokens from the first <|endoftext|> token to the next - bos_inds = torch.argwhere(batch_tokens == tokenizer.bos_token_id)[:, 1] - if len(bos_inds) > 1: - batch_tokens = batch_tokens[:, bos_inds[0] : bos_inds[1]] - str_tokens = tokenizer.convert_ids_to_tokens(batch_tokens.squeeze(dim=0).tolist()) - assert isinstance(str_tokens, list) - seq_len: int = batch_tokens.shape[1] - # Generate dashboards for three consecutive positions in the prompt, chosen randomly - if seq_len > 4: # Ensure the prompt is long enough for three positions + next token effect - seq_pos_c = np.random.randint(1, seq_len - 3) - seq_pos = [seq_pos_c 
- 1, seq_pos_c, seq_pos_c + 1] - used_features_now = generate_prompt_dashboard_html_files( - model=model, - tokens=batch_tokens, - str_tokens=str_tokens, - dashboards_data=dashboards_data, - seq_pos=seq_pos, - vocab_dict=vocab_dict, - str_score=dashboards_config.prompt_centric.str_score, - save_dir=save_dir, - ) - for sae_name in used_features: - used_features[sae_name] = used_features[sae_name].union(used_features_now[sae_name]) - - if prompt_idx > n_prompts: - break - return used_features - - -@torch.inference_mode() -def generate_dashboards( - model: SAETransformer, dashboards_config: DashboardsConfig, save_dir: RootPath | None = None -) -> None: - """Generate HTML feature dashboards for an SAETransformer and save them. - - First the data for the dashboards are crated using dashboards_data = get_dashboards_data(), - then prompt-centric HTML dashboards are created (if dashboards_config.prompt_centric exists), - then feature-centric HTML dashboards are created for any features in - dashboards_config.feature_indices (all features if this is None), or any features which - appeared in prompt-centric dashboards. - Dashboards are saved in dashboards_config.save_dir - """ - if save_dir is None: - save_dir = dashboards_config.save_dir - assert save_dir is not None, ( - "generate_dashboards() saves HTML files, but no save_dir was specified in the" - + " dashboards_config or given as input" - ) - # Deal with the possible input typles of sae_positions - if dashboards_config.sae_positions is None: - raw_sae_positions = model.raw_sae_positions - else: - raw_sae_positions = filter_names( - list(model.tlens_model.hook_dict.keys()), dashboards_config.sae_positions - ) - # Deal with the possible input typles of feature_indices - feature_indices = feature_indices_to_tensordict( - dashboards_config.feature_indices, raw_sae_positions, model - ) - - # Get the data used in the dashboards - dashboards_data: dict[str, MultiFeatureData] = get_dashboards_data( - model=model, - dataset_config=dashboards_config.data, - sae_positions=raw_sae_positions, - # We need data for every feature if we're generating prompt-centric dashboards: - feature_indices=None if dashboards_config.prompt_centric else feature_indices, - n_samples=dashboards_config.n_samples, - batch_size=dashboards_config.batch_size, - minibatch_size_features=dashboards_config.minibatch_size_features, - ) - - # Generate the prompt-centric dashboards and record which features were active on them - used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data} - if dashboards_config.prompt_centric: - prompt_dashboard_saving_folder = save_dir / Path("prompt-dashboards") - prompt_dashboard_saving_folder.mkdir(parents=True, exist_ok=True) - # Generate random prompt-centric dashboards - if dashboards_config.prompt_centric.n_random_prompt_dashboards > 0: - used_features_now = generate_random_prompt_dashboards( - model=model, - dashboards_data=dashboards_data, - dashboards_config=dashboards_config, - save_dir=prompt_dashboard_saving_folder, - ) - for sae_name in used_features: - used_features[sae_name] = used_features[sae_name].union(used_features_now[sae_name]) - - # Generate dashboards for specific prompts - if dashboards_config.prompt_centric.prompts is not None: - tokenizer = AutoTokenizer.from_pretrained(dashboards_config.data.tokenizer_name) - vocab_dict = create_vocab_dict(tokenizer) - for prompt in dashboards_config.prompt_centric.prompts: - tokens = tokenizer(prompt)["input_ids"] - list_tokens = tokens.tolist() if 
isinstance(tokens, Tensor) else tokens - assert isinstance(list_tokens, list) - str_tokens = tokenizer.convert_ids_to_tokens(list_tokens) - assert isinstance(str_tokens, list) - used_features_now = generate_prompt_dashboard_html_files( - model=model, - tokens=torch.Tensor(tokens).to(dtype=torch.int).unsqueeze(dim=0), - str_tokens=str_tokens, - dashboards_data=dashboards_data, - str_score=dashboards_config.prompt_centric.str_score, - vocab_dict=vocab_dict, - save_dir=prompt_dashboard_saving_folder, - ) - for sae_name in used_features: - used_features[sae_name] = used_features[sae_name].union( - used_features_now[sae_name] - ) - - for sae_name in raw_sae_positions: - used_features[sae_name] = used_features[sae_name].union( - set(feature_indices[sae_name].tolist()) - ) - - # Generate the viewable HTML feature dashboard files - dashboard_html_saving_folder = save_dir / Path("feature-dashboards") - dashboard_html_saving_folder.mkdir(parents=True, exist_ok=True) - generate_feature_dashboard_html_files( - dashboards_data=dashboards_data, - feature_indices=used_features if dashboards_config.prompt_centric else feature_indices, - save_dir=dashboard_html_saving_folder, - ) - - -# Load the saved SAEs and the corresponding model -def load_SAETransformer_from_saes_paths( - pretrained_sae_paths: list[RootPath] | list[str] | None, - config_path: RootPath | str | None = None, - sae_positions: list[str] | None = None, -) -> tuple[SAETransformer, Config, list[str]]: - if pretrained_sae_paths is not None: - pretrained_sae_paths = [Path(p) for p in pretrained_sae_paths] - for path in pretrained_sae_paths: - assert path.exists(), f"pretrained_sae_path: {path} does not exist" - assert path.is_file() and ( - path.suffix == ".pt" or path.suffix == ".pth" - ), f"pretrained_sae_path: {path} is not a .pt or .pth file" - - if config_path is None: - assert ( - pretrained_sae_paths is not None - ), "Either config_path or pretrained_sae_paths must be provided" - config_path = pretrained_sae_paths[0].parent / "config.yaml" - config_path = Path(config_path) - assert config_path.exists(), f"config_path: {config_path} does not exist" - assert ( - config_path.is_file() and config_path.suffix == ".yaml" - ), f"config_path: {config_path} does not exist" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - config = load_config(config_path, config_model=Config) - if pretrained_sae_paths is None: - pretrained_sae_paths = config.saes.pretrained_sae_paths - assert pretrained_sae_paths is not None, "pretrained_sae_paths must be given or in config" - logger.info(config) - - tlens_model = load_tlens_model( - tlens_model_name=config.tlens_model_name, tlens_model_path=config.tlens_model_path - ) - assert tlens_model is not None - - if sae_positions is None: - sae_positions = config.saes.sae_positions - - raw_sae_positions = filter_names(list(tlens_model.hook_dict.keys()), sae_positions) - model = SAETransformer( - tlens_model=tlens_model, - raw_sae_positions=raw_sae_positions, - dict_size_to_input_ratio=config.saes.dict_size_to_input_ratio, - init_decoder_orthogonal=False, - ).to(device=device) - - all_param_names = [name for name, _ in model.saes.named_parameters()] - trainable_param_names = load_pretrained_saes( - saes=model.saes, - pretrained_sae_paths=pretrained_sae_paths, - all_param_names=all_param_names, - retrain_saes=config.saes.retrain_saes, - ) - return model, config, trainable_param_names - - -def main( - config_path_or_obj: Path | str | DashboardsConfig, - pretrained_sae_paths: Path | str | list[Path] 
| list[str] | None, -) -> None: - dashboards_config = load_config(config_path_or_obj, config_model=DashboardsConfig) - logger.info(dashboards_config) - - if pretrained_sae_paths is None: - assert ( - dashboards_config.pretrained_sae_paths is not None - ), "pretrained_sae_paths must be provided, either in the dashboards config or as an input" - pretrained_sae_paths = dashboards_config.pretrained_sae_paths - else: - pretrained_sae_paths = ( - pretrained_sae_paths - if isinstance(pretrained_sae_paths, list) - else [Path(pretrained_sae_paths)] - ) - - logger.info("Loading the model and SAEs") - model, _, _ = load_SAETransformer_from_saes_paths( - pretrained_sae_paths, dashboards_config.sae_config_path - ) - logger.info("done") - - save_dir = dashboards_config.save_dir or Path(pretrained_sae_paths[0]).parent - logger.info(f"The HTML dashboards will be saved in {save_dir}") - generate_dashboards(model, dashboards_config, save_dir=save_dir) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/pyproject.toml b/pyproject.toml index 6e79c78..25c94fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "wandb~=0.16.2", "fire~=0.5.0", "tqdm~=4.66.1", - "pytest~=8.0.0", + "pytest~=8.1.2", "ipykernel~=6.29.0", "transformer-lens~=1.14.0", "jaxtyping~=0.2.25", @@ -20,7 +20,9 @@ dependencies = [ "zstandard~=0.22.0", "matplotlib>=3.5.3", "eindex-callum@git+https://github.com/callummcdougall/eindex", - "sae_vis@git+https://github.com/callummcdougall/sae_vis.git@b28a0f7c7e936f4bea05528d952dfcd438533cce" + # "sae_vis@git+https://github.com/callummcdougall/sae_vis.git@b28a0f7c7e936f4bea05528d952dfcd438533cce", + "sae_lens@git+https://github.com/jbloomAus/SAELens.git@7c43c4caa84aea421ac81ae0e326d9c62bb17bec", + "neuron_explainer@git+https://github.com/hijohnnylin/automated-interpretability.git" ] [project.urls] @@ -29,7 +31,7 @@ repository = "https://github.com/ApolloResearch/e2e_sae" [project.optional-dependencies] dev = [ "ruff~=0.1.14", - "pyright~=1.1.357", + "pyright~=1.1.360", "pre-commit~=3.6.0", ] diff --git a/tests/test_dashboards.py b/tests/test_dashboards.py deleted file mode 100644 index f72beb4..0000000 --- a/tests/test_dashboards.py +++ /dev/null @@ -1,124 +0,0 @@ -from pathlib import Path - -import pytest -import torch -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast - -from e2e_sae.data import DatasetConfig -from e2e_sae.loader import load_tlens_model -from e2e_sae.models.transformers import SAETransformer -from e2e_sae.scripts.generate_dashboards import ( - DashboardsConfig, - compute_feature_acts, - create_vocab_dict, - generate_dashboards, -) -from e2e_sae.utils import set_seed -from tests.utils import get_tinystories_config - -Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast - - -@pytest.fixture(scope="module") -def tinystories_model() -> SAETransformer: - tlens_model = load_tlens_model( - tlens_model_name="roneneldan/TinyStories-1M", tlens_model_path=None - ) - sae_position = "blocks.2.hook_resid_post" - config = get_tinystories_config({"saes": {"sae_positions": sae_position}}) - model = SAETransformer( - tlens_model=tlens_model, - raw_sae_positions=[sae_position], - dict_size_to_input_ratio=config.saes.dict_size_to_input_ratio, - ) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - return model - - -def test_compute_feature_acts(tinystories_model: SAETransformer): - set_seed(0) - prompt = "Once upon" - tokenizer = tinystories_model.tlens_model.tokenizer - assert tokenizer is not 
None
-    tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]
-    assert isinstance(tokens, torch.Tensor)
-    feature_indices = {name: list(range(7)) for name in tinystories_model.raw_sae_positions}
-    feature_acts, final_resid_acts = compute_feature_acts(
-        tinystories_model, tokens, feature_indices=feature_indices
-    )
-    for sae_name, acts in feature_acts.items():
-        assert acts.shape[0] == 1  # batch size
-        assert acts.shape[2] == 7  # feature_indices
-
-
-def test_create_vocab_dict(tinystories_model: SAETransformer):
-    tokenizer = tinystories_model.tlens_model.tokenizer
-    assert tokenizer is not None
-    vocab_dict = create_vocab_dict(tokenizer)
-    assert isinstance(tokenizer, PreTrainedTokenizerFast)
-    assert len(vocab_dict) == len(tokenizer.vocab)
-    for token_id, token_str in vocab_dict.items():
-        assert isinstance(token_id, int)
-        assert isinstance(token_str, str)
-
-
-def check_valid_feature_dashboard_htmls(folder: Path):
-    assert folder.exists()
-    for html_file in folder.iterdir():
-        assert html_file.name.endswith(".html")
-        assert html_file.exists()
-        with open(html_file) as f:
-            html_content = f.read()
-        assert isinstance(html_content, str)
-        assert len(html_content) > 100
-        assert "Plotly.newPlot('histogram-acts'" in html_content
-        assert '
'" in html_content - assert "Feature #" in html_content - assert '
Date: Mon, 29 Apr 2024 18:35:48 +0000 Subject: [PATCH 03/12] bar plot --- e2e_sae/scripts/autointerp.py | 86 +++++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 35 deletions(-) diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py index 32785d9..60f7f89 100644 --- a/e2e_sae/scripts/autointerp.py +++ b/e2e_sae/scripts/autointerp.py @@ -36,6 +36,7 @@ LogprobFreeExplanationTokenSimulator, NeuronSimulator, ) +from scipy.stats import bootstrap from tenacity import retry, stop_after_attempt, wait_random_exponential from tqdm import tqdm @@ -443,16 +444,7 @@ def run_autointerp( print(f"Your feature is at: {feature_url}") -def compare_autointerp_results(out_dir: Path): - """ - Compare autointerp results across SAEs. - - Args: - output_dir: Directory containing the autointerp output files. - - Returns: - None - """ +def get_autointerp_results_df(out_dir: Path): autointerp_files = glob.glob(f"{out_dir}/*_feature-*_score-*") stats = { "layer": [], @@ -476,7 +468,21 @@ def compare_autointerp_results(out_dir: Path): stats["explanationModel"].append("gpt-4-1106-preview") stats["explanation"].append(json_data["explanation"]) df_stats = pd.DataFrame(stats) + assert not df_stats.empty + return df_stats + +def compare_autointerp_results(out_dir: Path): + """ + Compare autointerp results across SAEs. + + Args: + output_dir: Directory containing the autointerp output files. + + Returns: + None + """ + df_stats = get_autointerp_results_df(out_dir) sns.set_theme(style="whitegrid", palette="pastel") g = sns.catplot( data=df_stats, @@ -529,30 +535,7 @@ def compare_autointerp_results(out_dir: Path): def compare_across_saes(out_dir: Path): # Get data about the quality of the autointerp from the .jsonl files - autointerp_files = glob.glob(f"{out_dir}/*_feature-*_score-*") - stats = { - "layer": [], - "sae": [], - "feature": [], - "explanationScore": [], - "explanationModel": [], - "autointerpModel": [], - "explanation": [], - } - for autointerp_file in tqdm(autointerp_files, "constructing stats from json autointerp files"): - with open(autointerp_file) as f: - json_data = json.loads(json.load(f)) - if "feature" in json_data: - json_data = json_data["feature"] - stats["layer"].append(int(json_data["layer"].split("-")[0])) - stats["sae"].append("-".join(json_data["layer"].split("-")[1:])) - stats["feature"].append(json_data["index"]) - stats["explanationScore"].append(json_data["explanationScore"]) - stats["autointerpModel"].append(json_data["autointerpModel"]) - stats["explanationModel"].append("gpt-4-1106-preview") - stats["explanation"].append(json_data["explanation"]) - df_stats = pd.DataFrame(stats) - + df_stats = get_autointerp_results_df(out_dir) # Plot the relative performance of the SAEs sns.set_theme(style="whitegrid", palette="pastel") g = sns.catplot( @@ -607,8 +590,40 @@ def compare_across_saes(out_dir: Path): print(bar_file) +def bootstrapped_bar(out_dir: Path): + results_df = get_autointerp_results_df(out_dir) + fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharey=True) + + sae_names = { + "res_scefr-ajt": "Downstream", + "res_sll-ajt": "CE Local", + "res_scl-ajt": "L0 Local", + } + # make matplotlib histogram with ci as error bars + for layer, ax in zip([6, 10], axs, strict=True): + layer_data = results_df.loc[results_df.layer == layer] + + means, yerrs = [], [[], []] + for sae_type in sae_names: + sae_data = layer_data.loc[layer_data.sae == sae_type] + scores = sae_data.explanationScore.to_numpy() + ci = bootstrap((scores,), statistic=np.mean).confidence_interval 
+ means.append(scores.mean()) + yerrs[0].append(scores.mean() - ci.low) + yerrs[1].append(ci.high - scores.mean()) + + ax.bar(range(3), means, yerr=yerrs, capsize=5) + ax.set_title(f"Layer {layer}") + ax.set_xticks(range(3), sae_names.values()) + + axs[0].set_ylabel("Mean Explanation Score") + + plt.tight_layout() + plt.show() + + if __name__ == "__main__": - out_dir = Path("neuronpedia_outputs/autointerp") + out_dir = Path(__file__).parent / "out/autointerp/jordan_results_4_29" ## Running autointerp # Compare similar CE e2e+Downstream with similar CE local and similar L0 local sae_sets = [ @@ -628,3 +643,4 @@ def compare_across_saes(out_dir: Path): ## Analysis of autointerp results compare_autointerp_results(out_dir) compare_across_saes(out_dir) + bootstrapped_bar(out_dir) From 4ee4cc37fe34b0776c1edf45f57cf3e61a6d2446 Mon Sep 17 00:00:00 2001 From: Nix Goldowsky-Dill Date: Mon, 29 Apr 2024 21:20:11 +0000 Subject: [PATCH 04/12] only compute df once, pvals --- e2e_sae/scripts/autointerp.py | 59 +++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py index 60f7f89..62a5f37 100644 --- a/e2e_sae/scripts/autointerp.py +++ b/e2e_sae/scripts/autointerp.py @@ -36,6 +36,7 @@ LogprobFreeExplanationTokenSimulator, NeuronSimulator, ) +from numpy.typing import ArrayLike from scipy.stats import bootstrap from tenacity import retry, stop_after_attempt, wait_random_exponential from tqdm import tqdm @@ -445,7 +446,7 @@ def run_autointerp( def get_autointerp_results_df(out_dir: Path): - autointerp_files = glob.glob(f"{out_dir}/*_feature-*_score-*") + autointerp_files = glob.glob(f"{out_dir}/**/*_feature-*_score-*", recursive=True) stats = { "layer": [], "sae": [], @@ -469,10 +470,11 @@ def get_autointerp_results_df(out_dir: Path): stats["explanation"].append(json_data["explanation"]) df_stats = pd.DataFrame(stats) assert not df_stats.empty + df_stats.dropna(inplace=True) return df_stats -def compare_autointerp_results(out_dir: Path): +def compare_autointerp_results(df_stats: pd.DataFrame): """ Compare autointerp results across SAEs. 
@@ -482,7 +484,6 @@ def compare_autointerp_results(out_dir: Path): Returns: None """ - df_stats = get_autointerp_results_df(out_dir) sns.set_theme(style="whitegrid", palette="pastel") g = sns.catplot( data=df_stats, @@ -533,9 +534,7 @@ def compare_autointerp_results(out_dir: Path): print(f"Saved to {bar_file}") -def compare_across_saes(out_dir: Path): - # Get data about the quality of the autointerp from the .jsonl files - df_stats = get_autointerp_results_df(out_dir) +def compare_across_saes(df_stats: pd.DataFrame): # Plot the relative performance of the SAEs sns.set_theme(style="whitegrid", palette="pastel") g = sns.catplot( @@ -590,8 +589,7 @@ def compare_across_saes(out_dir: Path): print(bar_file) -def bootstrapped_bar(out_dir: Path): - results_df = get_autointerp_results_df(out_dir) +def bootstrapped_bar(df_stats: pd.DataFrame): fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharey=True) sae_names = { @@ -601,7 +599,7 @@ def bootstrapped_bar(out_dir: Path): } # make matplotlib histogram with ci as error bars for layer, ax in zip([6, 10], axs, strict=True): - layer_data = results_df.loc[results_df.layer == layer] + layer_data = df_stats.loc[df_stats.layer == layer] means, yerrs = [], [[], []] for sae_type in sae_names: @@ -620,10 +618,43 @@ def bootstrapped_bar(out_dir: Path): plt.tight_layout() plt.show() + plt.savefig(out_dir / "bootstrapped_bar.png") + print(f"Saved to {out_dir / 'bootstrapped_bar.png'}") + plt.close() + + +def bootstrap_p_value(sample_a: ArrayLike, sample_b: ArrayLike) -> float: + """ + Computes 2 sided p-value, for null hypothesis that means are the same + """ + sample_a, sample_b = np.asarray(sample_a), np.asarray(sample_b) + mean_diff = np.mean(sample_a) - np.mean(sample_b) + n_bootstraps = 100_000 + bootstrapped_diffs = [] + combined = np.concatenate([sample_a, sample_b]) + for _ in range(n_bootstraps): + boot_a = np.random.choice(combined, len(sample_a), replace=True) + boot_b = np.random.choice(combined, len(sample_b), replace=True) + boot_diff = np.mean(boot_a) - np.mean(boot_b) + bootstrapped_diffs.append(boot_diff) + bootstrapped_diffs = np.array(bootstrapped_diffs) + p_value = np.mean(np.abs(bootstrapped_diffs) > np.abs(mean_diff)) + return p_value + + +def compute_p_values(df: pd.DataFrame): + ref_sae = "res_scefr-ajt" + for layer in [6, 10]: + for name, sae in [("CE Local", "res_sll-ajt"), ("L0 Local", "res_scl-ajt")]: + pval = bootstrap_p_value( + df.loc[(df.layer == layer) & (df.sae == sae)].explanationScore.to_numpy(), + df.loc[(df.layer == layer) & (df.sae == ref_sae)].explanationScore.to_numpy(), + ) + print(f"L{layer}, Downstream vs {name}: p={pval}") if __name__ == "__main__": - out_dir = Path(__file__).parent / "out/autointerp/jordan_results_4_29" + out_dir = Path(__file__).parent / "out/autointerp" ## Running autointerp # Compare similar CE e2e+Downstream with similar CE local and similar L0 local sae_sets = [ @@ -640,7 +671,9 @@ def bootstrapped_bar(out_dir: Path): autointerp_scorer_model_name="gpt-3.5-turbo", ) + df = get_autointerp_results_df(out_dir) ## Analysis of autointerp results - compare_autointerp_results(out_dir) - compare_across_saes(out_dir) - bootstrapped_bar(out_dir) + compare_autointerp_results(df) + compare_across_saes(df) + bootstrapped_bar(df) + compute_p_values(df) From 6f89c1f6a40d57b2f43e272800df646f79e59cbd Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 30 Apr 2024 15:43:19 +0100 Subject: [PATCH 05/12] Update autointerp config --- e2e_sae/scripts/autointerp.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 
deletions(-) diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py index 62a5f37..8d36058 100644 --- a/e2e_sae/scripts/autointerp.py +++ b/e2e_sae/scripts/autointerp.py @@ -69,7 +69,10 @@ def get_neuronpedia_feature( """Fetch a feature from Neuronpedia API.""" url = f"{NEURONPEDIA_DOMAIN}/api/feature/{model}/{layer}-{dataset}/{feature}" result = requests.get(url).json() - result["index"] = int(result["index"]) + if "index" in result: + result["index"] = int(result["index"]) + else: + raise Exception(f"Feature {model}@{layer}-{dataset}:{feature} does not exist.") return result @@ -656,11 +659,12 @@ def compute_p_values(df: pd.DataFrame): if __name__ == "__main__": out_dir = Path(__file__).parent / "out/autointerp" ## Running autointerp - # Compare similar CE e2e+Downstream with similar CE local and similar L0 local + # Get runs for similar CE and similar L0 for e2e+Downstream and local + # Note that "10-res_slefr-ajt" does not exist, we use 10-res_scefr-ajt for similar l0 too sae_sets = [ - ["6-res_scefr-ajt", "6-res_sll-ajt", "6-res_scl-ajt"], + ["6-res_scefr-ajt", "6-res_slefr-ajt", "6-res_sll-ajt", "6-res_scl-ajt"], ["10-res_scefr-ajt", "10-res_sll-ajt", "10-res_scl-ajt"], - ["2-res_scefr-ajt", "2-res_sll-ajt", "2-res_scl-ajt"], + ["2-res_scefr-ajt", "2-res_slefr-ajt", "2-res_sll-ajt", "2-res_scl-ajt"], ] run_autointerp( sae_sets=sae_sets, From 265e140b47aa6edab501977ba4763700152737ed Mon Sep 17 00:00:00 2001 From: Nix Goldowsky-Dill Date: Tue, 30 Apr 2024 15:40:27 +0000 Subject: [PATCH 06/12] violin plot --- e2e_sae/scripts/autointerp.py | 125 ++++++++++++++++++++++++++++------ 1 file changed, 104 insertions(+), 21 deletions(-) diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py index 8d36058..491bd5d 100644 --- a/e2e_sae/scripts/autointerp.py +++ b/e2e_sae/scripts/autointerp.py @@ -20,6 +20,7 @@ import pandas as pd import requests import seaborn as sns +import statsmodels.stats.api as sms from dotenv import load_dotenv from neuron_explainer.activations.activation_records import calculate_max_activation from neuron_explainer.activations.activations import ActivationRecord @@ -458,6 +459,7 @@ def get_autointerp_results_df(out_dir: Path): "explanationModel": [], "autointerpModel": [], "explanation": [], + "sae_type": [], } for autointerp_file in tqdm(autointerp_files, "constructing stats from json autointerp files"): with open(autointerp_file) as f: @@ -471,6 +473,9 @@ def get_autointerp_results_df(out_dir: Path): stats["autointerpModel"].append(json_data["autointerpModel"]) stats["explanationModel"].append("gpt-4-1106-preview") stats["explanation"].append(json_data["explanation"]) + stats["sae_type"].append( + {"e": "e2e", "l": "local", "r": "downstream"}[json_data["layer"].split("-")[1][-1]] + ) df_stats = pd.DataFrame(stats) assert not df_stats.empty df_stats.dropna(inplace=True) @@ -552,7 +557,7 @@ def compare_across_saes(df_stats: pd.DataFrame): "marker": "o", "markerfacecolor": "white", "markeredgecolor": "black", - "markersize": "10", + "markersize": "5", }, ) sns.swarmplot( @@ -626,6 +631,60 @@ def bootstrapped_bar(df_stats: pd.DataFrame): plt.close() +def pair_violin_plot(df_stats: pd.DataFrame, pairs: dict[int, dict[str, str]]): + fig, axs = plt.subplots(1, 3, sharey=True, figsize=(7, 3.5)) + + for ax, layer in zip(axs, pairs.keys(), strict=True): + sae_ids = [pairs[layer]["local"], pairs[layer]["downstream"]] + layer_data = df_stats[df_stats["layer"] == layer] + grouped_scores = layer_data.groupby("sae")["explanationScore"] + 
mean_explanationScores = [grouped_scores.get_group(sae_id).mean() for sae_id in sae_ids] + confidence_intervals = [ + sms.DescrStatsW(grouped_scores.get_group(sae_id)).tconfint_mean() for sae_id in sae_ids + ] + + data = layer_data.loc[layer_data.sae.isin(sae_ids)] + colors = { + "local": "#f0a70a", + "e2e": "#518c31", + "downstream": plt.get_cmap("tab20b").colors[2], # type: ignore[reportAttributeAccessIssue] + } + l = sns.violinplot( + data=data, + hue="sae_type", + x=[1 if sae_type == "downstream" else 0 for sae_type in data.sae_type], + y="explanationScore", + palette=colors, + native_scale=True, + orient="v", + ax=ax, + inner=None, + cut=0, + bw_adjust=0.7, + hue_order=["local", "downstream"], + alpha=0.8, + legend=False, + ) + ax.set_xticks([0, 1], ["Local", "Downstream"]) + ax.errorbar( + x=(0, 1), + y=mean_explanationScores, + yerr=[(c[1] - c[0]) / 2 for c in confidence_intervals], + fmt="o", + color="black", + capsize=5, + ) + ax.set_ylim((None, 1)) + ax.set_title(label=f"Layer {layer}") + + axs[0].set_ylabel("Auto-intepretability score") + + plt.tight_layout() + out_file = Path(__file__).parent / "out/autointerp" / "autointerp_same_l0.png" + plt.savefig(out_file, bbox_inches="tight") + print(f"Saved to {out_file}") + + def bootstrap_p_value(sample_a: ArrayLike, sample_b: ArrayLike) -> float: """ Computes 2 sided p-value, for null hypothesis that means are the same @@ -645,19 +704,34 @@ def bootstrap_p_value(sample_a: ArrayLike, sample_b: ArrayLike) -> float: return p_value -def compute_p_values(df: pd.DataFrame): - ref_sae = "res_scefr-ajt" - for layer in [6, 10]: - for name, sae in [("CE Local", "res_sll-ajt"), ("L0 Local", "res_scl-ajt")]: - pval = bootstrap_p_value( - df.loc[(df.layer == layer) & (df.sae == sae)].explanationScore.to_numpy(), - df.loc[(df.layer == layer) & (df.sae == ref_sae)].explanationScore.to_numpy(), - ) - print(f"L{layer}, Downstream vs {name}: p={pval}") +def bootstrap_mean_diff(sample_a: ArrayLike, sample_b: ArrayLike): + def mean_diff(resample_a, resample_b): # type: ignore + return np.mean(resample_a) - np.mean(resample_b) + + result = bootstrap((sample_a, sample_b), n_resamples=100_000, statistic=mean_diff) + diff = np.mean(sample_a) - np.mean(sample_b) # type: ignore + low, high = result.confidence_interval + print(f"Diff {diff:.2f}, [{low:.2f},{high:.2f}]") + + +def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): + score_groups = df.groupby(["layer", "sae"]).explanationScore + for layer, sae_dict in pairs.items(): + sae_local, sae_downstream = sae_dict["local"], sae_dict["downstream"] + pval = bootstrap_p_value( + score_groups.get_group((layer, sae_downstream)).to_numpy(), + score_groups.get_group((layer, sae_local)).to_numpy(), + ) + print(f"L{layer}, {sae_downstream} vs {sae_local}: p={pval}") + bootstrap_mean_diff( + score_groups.get_group((layer, sae_downstream)).to_numpy(), + score_groups.get_group((layer, sae_local)).to_numpy(), + ) if __name__ == "__main__": - out_dir = Path(__file__).parent / "out/autointerp" + out_dir = Path(__file__).parent / "out/autointerp/" + input_dir = out_dir # Path("/data/apollo/autointerp") ## Running autointerp # Get runs for similar CE and similar L0 for e2e+Downstream and local # Note that "10-res_slefr-ajt" does not exist, we use 10-res_scefr-ajt for similar l0 too @@ -666,18 +740,27 @@ def compute_p_values(df: pd.DataFrame): ["10-res_scefr-ajt", "10-res_sll-ajt", "10-res_scl-ajt"], ["2-res_scefr-ajt", "2-res_slefr-ajt", "2-res_sll-ajt", "2-res_scl-ajt"], ] - run_autointerp( - 
sae_sets=sae_sets, - n_random_features=50, - dict_size=768 * 60, - feature_model_id="gpt2-small", - autointerp_explainer_model_name="gpt-4-turbo-2024-04-09", - autointerp_scorer_model_name="gpt-3.5-turbo", - ) + # run_autointerp( + # sae_sets=sae_sets, + # n_random_features=50, + # dict_size=768 * 60, + # feature_model_id="gpt2-small", + # autointerp_explainer_model_name="gpt-4-turbo-2024-04-09", + # autointerp_scorer_model_name="gpt-3.5-turbo", + # ) + + # TO UPDATE + const_l0_pairs = { + 2: {"local": "res_sll-ajt", "downstream": "res_slefr-ajt"}, + 6: {"local": "res_sll-ajt", "downstream": "res_scefr-ajt"}, + 10: {"local": "res_sll-ajt", "downstream": "res_scefr-ajt"}, + } - df = get_autointerp_results_df(out_dir) + df = get_autointerp_results_df(input_dir) + print(df.autointerpModel.unique(), df.explanationModel.unique()) ## Analysis of autointerp results compare_autointerp_results(df) compare_across_saes(df) bootstrapped_bar(df) - compute_p_values(df) + pair_violin_plot(df, const_l0_pairs) + compute_p_values(df, const_l0_pairs) From a8b4b57b118cbac3dc9ef36851a585fbb85c1c10 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 1 May 2024 18:42:33 +0000 Subject: [PATCH 07/12] Clean autointerp --- e2e_sae/scripts/__init__.py | 0 e2e_sae/scripts/autointerp.py | 124 +++++++++++++++++++--------------- 2 files changed, 68 insertions(+), 56 deletions(-) create mode 100644 e2e_sae/scripts/__init__.py diff --git a/e2e_sae/scripts/__init__.py b/e2e_sae/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py index 491bd5d..f24df76 100644 --- a/e2e_sae/scripts/autointerp.py +++ b/e2e_sae/scripts/autointerp.py @@ -5,7 +5,6 @@ This has been updated to work with gpt4-turbo-2024-04-09 and fixes an OPENAI_API_KEY issue. """ - import asyncio import glob import json @@ -385,7 +384,7 @@ async def autointerp_neuronpedia_features( def run_autointerp( - sae_sets: list[list[str]], + saes: list[str], n_random_features: int, dict_size: int, feature_model_id: str, @@ -395,6 +394,7 @@ def run_autointerp( autointerp_scorer_model_name: Literal[ "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-turbo-2024-04-09" ], + out_dir: str | Path = "neuronpedia_outputs/autointerp", ): """Explain and score random features across SAEs and upload to Neuronpedia. @@ -406,47 +406,49 @@ def run_autointerp( autointerp_explainer_model_name: Model name for autointerp explainer. autointerp_scorer_model_name: Model name for autointerp scorer (much more expensive). 
""" - for sae_set in tqdm(sae_sets, desc="sae sets"): - for i in tqdm(range(n_random_features), desc="random features"): - for sae in tqdm(sae_set, desc="sae"): - layer = int(sae.split("-")[0]) - dataset = "-".join(sae.split("-")[1:]) - feature_exists = False - while not feature_exists: - feature = random.randint(0, dict_size) - feature_data = get_neuronpedia_feature( - feature=feature, - layer=layer, - model=feature_model_id, - dataset=dataset, - ) - if "activations" in feature_data and len(feature_data["activations"]) >= 20: - feature_exists = True - print(f"sae: {sae}, feature: {feature}") - - features = [ - NeuronpediaFeature( - modelId=feature_model_id, - layer=layer, - dataset=dataset, - feature=feature, - ), - ] - - asyncio.run( - autointerp_neuronpedia_features( - features=features, - openai_api_key=os.getenv("OPENAI_API_KEY", ""), - neuronpedia_api_key=os.getenv("NEURONPEDIA_API_KEY", ""), - autointerp_explainer_model_name=autointerp_explainer_model_name, - autointerp_scorer_model_name=autointerp_scorer_model_name, - num_activations_to_use=20, - max_explanation_activation_records=5, - ) + for i in tqdm(range(n_random_features), desc="random features", total=n_random_features): + for sae in tqdm(saes, desc="sae", total=len(saes)): + layer = int(sae.split("-")[0]) + dataset = "-".join(sae.split("-")[1:]) + feature_exists = False + while not feature_exists: + feature = random.randint(0, dict_size) + feature_data = get_neuronpedia_feature( + feature=feature, + layer=layer, + model=feature_model_id, + dataset=dataset, + ) + if "activations" in feature_data and len(feature_data["activations"]) >= 20: + feature_exists = True + print(f"sae: {sae}, feature: {feature}") + + features = [ + NeuronpediaFeature( + modelId=feature_model_id, + layer=layer, + dataset=dataset, + feature=feature, + ), + ] + + asyncio.run( + autointerp_neuronpedia_features( + features=features, + openai_api_key=os.getenv("OPENAI_API_KEY", ""), + neuronpedia_api_key=os.getenv("NEURONPEDIA_API_KEY", ""), + autointerp_explainer_model_name=autointerp_explainer_model_name, + autointerp_scorer_model_name=autointerp_scorer_model_name, + num_activations_to_use=20, + max_explanation_activation_records=5, + output_dir=str(out_dir), ) + ) - feature_url = f"https://neuronpedia.org/{feature_model_id}/{layer}-{dataset}/{feature}" - print(f"Your feature is at: {feature_url}") + feature_url = ( + f"https://neuronpedia.org/{feature_model_id}/{layer}-{dataset}/{feature}" + ) + print(f"Your feature is at: {feature_url}") def get_autointerp_results_df(out_dir: Path): @@ -730,24 +732,34 @@ def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): if __name__ == "__main__": - out_dir = Path(__file__).parent / "out/autointerp/" - input_dir = out_dir # Path("/data/apollo/autointerp") + # out_dir = Path(__file__).parent / "out/autointerp/" + # input_dir = out_dir # Path("/data/apollo/autointerp") + out_dir = Path("/data/apollo/autointerp/") ## Running autointerp # Get runs for similar CE and similar L0 for e2e+Downstream and local # Note that "10-res_slefr-ajt" does not exist, we use 10-res_scefr-ajt for similar l0 too - sae_sets = [ - ["6-res_scefr-ajt", "6-res_slefr-ajt", "6-res_sll-ajt", "6-res_scl-ajt"], - ["10-res_scefr-ajt", "10-res_sll-ajt", "10-res_scl-ajt"], - ["2-res_scefr-ajt", "2-res_slefr-ajt", "2-res_sll-ajt", "2-res_scl-ajt"], + saes = [ + "2-res_scefr-ajt", + "2-res_slefr-ajt", + "2-res_sll-ajt", + "2-res_scl-ajt", + "6-res_scefr-ajt", + "6-res_slefr-ajt", + "6-res_sll-ajt", + "6-res_scl-ajt", + 
"10-res_scefr-ajt", + "10-res_sll-ajt", + "10-res_scl-ajt", ] - # run_autointerp( - # sae_sets=sae_sets, - # n_random_features=50, - # dict_size=768 * 60, - # feature_model_id="gpt2-small", - # autointerp_explainer_model_name="gpt-4-turbo-2024-04-09", - # autointerp_scorer_model_name="gpt-3.5-turbo", - # ) + run_autointerp( + saes=saes, + n_random_features=150, + dict_size=768 * 60, + feature_model_id="gpt2-small", + autointerp_explainer_model_name="gpt-4-turbo-2024-04-09", + autointerp_scorer_model_name="gpt-3.5-turbo", + out_dir=out_dir, + ) # TO UPDATE const_l0_pairs = { @@ -756,7 +768,7 @@ def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): 10: {"local": "res_sll-ajt", "downstream": "res_scefr-ajt"}, } - df = get_autointerp_results_df(input_dir) + df = get_autointerp_results_df(out_dir) print(df.autointerpModel.unique(), df.explanationModel.unique()) ## Analysis of autointerp results compare_autointerp_results(df) From 95b1a5286d839241244e97982bf8243f128dc9ad Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 1 May 2024 18:44:15 +0000 Subject: [PATCH 08/12] Update pyproject --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 53373b2..ec6b351 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,11 +18,12 @@ dependencies = [ "jaxtyping~=0.2.25", "python-dotenv~=1.0.1", "zstandard~=0.22.0", - "matplotlib>=3.5.3", + "matplotlib~=3.5.3", + "seaborn~=0.13.2", "eindex-callum@git+https://github.com/callummcdougall/eindex", # "sae_vis@git+https://github.com/callummcdougall/sae_vis.git@b28a0f7c7e936f4bea05528d952dfcd438533cce", "sae_lens@git+https://github.com/jbloomAus/SAELens.git@7c43c4caa84aea421ac81ae0e326d9c62bb17bec", - # "neuron_explainer@git+https://github.com/hijohnnylin/automated-interpretability.git" + "neuron_explainer@git+https://github.com/ApolloResearch/automated-interpretability.git" ] [project.urls] From d268e7517e9f3440f3326f509e98b652d9fbc771 Mon Sep 17 00:00:00 2001 From: Nix Goldowsky-Dill Date: Thu, 2 May 2024 10:13:57 +0000 Subject: [PATCH 09/12] pair plots --- e2e_sae/scripts/autointerp.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py index f24df76..e0e0552 100644 --- a/e2e_sae/scripts/autointerp.py +++ b/e2e_sae/scripts/autointerp.py @@ -633,7 +633,7 @@ def bootstrapped_bar(df_stats: pd.DataFrame): plt.close() -def pair_violin_plot(df_stats: pd.DataFrame, pairs: dict[int, dict[str, str]]): +def pair_violin_plot(df_stats: pd.DataFrame, pairs: dict[int, dict[str, str]], out_file: Path): fig, axs = plt.subplots(1, 3, sharey=True, figsize=(7, 3.5)) for ax, layer in zip(axs, pairs.keys(), strict=True): @@ -682,7 +682,6 @@ def pair_violin_plot(df_stats: pd.DataFrame, pairs: dict[int, dict[str, str]]): axs[0].set_ylabel("Auto-intepretability score") plt.tight_layout() - out_file = Path(__file__).parent / "out/autointerp" / "autointerp_same_l0.png" plt.savefig(out_file, bbox_inches="tight") print(f"Saved to {out_file}") @@ -761,18 +760,28 @@ def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): out_dir=out_dir, ) - # TO UPDATE + df = get_autointerp_results_df(out_dir) + + ## Analysis of autointerp results + const_l0_pairs = { 2: {"local": "res_sll-ajt", "downstream": "res_slefr-ajt"}, - 6: {"local": "res_sll-ajt", "downstream": "res_scefr-ajt"}, + 6: {"local": "res_sll-ajt", "downstream": "res_slefr-ajt"}, 10: {"local": "res_sll-ajt", 
"downstream": "res_scefr-ajt"}, } - df = get_autointerp_results_df(out_dir) - print(df.autointerpModel.unique(), df.explanationModel.unique()) - ## Analysis of autointerp results - compare_autointerp_results(df) - compare_across_saes(df) + const_ce_pairs = { + 2: {"local": "res_scl-ajt", "downstream": "res_scefr-ajt"}, + 6: {"local": "res_scl-ajt", "downstream": "res_scefr-ajt"}, + 10: {"local": "res_scl-ajt", "downstream": "res_scefr-ajt"}, + } + + # compare_autointerp_results(df) + # compare_across_saes(df) bootstrapped_bar(df) - pair_violin_plot(df, const_l0_pairs) + pair_violin_plot(df, const_l0_pairs, out_dir / "l0_violin.png") + pair_violin_plot(df, const_ce_pairs, out_dir / "ce_violin.png") + print("SAME L0") compute_p_values(df, const_l0_pairs) + print("\nSAME CE") + compute_p_values(df, const_ce_pairs) From fcc62cc26651b44db43d5ae6a0a9cfcb7aa7b29b Mon Sep 17 00:00:00 2001 From: Nix Goldowsky-Dill Date: Thu, 9 May 2024 16:26:17 +0000 Subject: [PATCH 10/12] cleanup autointerp script --- e2e_sae/scripts/autointerp.py | 175 +++------------------------------- 1 file changed, 12 insertions(+), 163 deletions(-) diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py index e0e0552..91e9514 100644 --- a/e2e_sae/scripts/autointerp.py +++ b/e2e_sae/scripts/autointerp.py @@ -484,155 +484,6 @@ def get_autointerp_results_df(out_dir: Path): return df_stats -def compare_autointerp_results(df_stats: pd.DataFrame): - """ - Compare autointerp results across SAEs. - - Args: - output_dir: Directory containing the autointerp output files. - - Returns: - None - """ - sns.set_theme(style="whitegrid", palette="pastel") - g = sns.catplot( - data=df_stats, - x="sae", - y="explanationScore", - kind="box", - palette="pastel", - hue="layer", - showmeans=True, - meanprops={ - "marker": "o", - "markerfacecolor": "white", - "markeredgecolor": "black", - "markersize": "10", - }, - ) - sns.swarmplot( - data=df_stats, x="sae", y="explanationScore", color="k", size=4, ax=g.ax, legend=False - ) - plt.title("Quality of auto-interpretability explanations across SAEs") - plt.ylabel("Auto-interpretability score") - plt.xlabel("SAE") - results_file = out_dir / "Auto-interpretability_results.png" - plt.savefig(results_file, bbox_inches="tight") - # plt.savefig("Auto-interpretability_results.pdf", bbox_inches="tight") - # plt.show() - print(f"Saved to {results_file}") - - sns.set_theme(rc={"figure.figsize": (6, 6)}) - b = sns.barplot( - df_stats, - x="sae", - y="explanationScore", - palette="pastel", - hue="layer", - capsize=0.3, - legend=False, - ) - s = sns.swarmplot(data=df_stats, x="sae", y="explanationScore", color="k", alpha=0.25, size=6) - plt.title("Quality of auto-interpretability explanations across SAEs") - plt.ylabel("Auto-interpretability score") - plt.xlabel("SAE") - plt.yticks(np.arange(0, 1, 0.1)) - bar_file = out_dir / "Auto-interpretability_results_bar.png" - plt.savefig(bar_file, bbox_inches="tight") - # plt.savefig("Auto-interpretability_results_bar.pdf", bbox_inches="tight") - # plt.show() - print(f"Saved to {bar_file}") - - -def compare_across_saes(df_stats: pd.DataFrame): - # Plot the relative performance of the SAEs - sns.set_theme(style="whitegrid", palette="pastel") - g = sns.catplot( - data=df_stats, - x="sae", - y="explanationScore", - kind="box", - palette="pastel", - hue="layer", - showmeans=True, - meanprops={ - "marker": "o", - "markerfacecolor": "white", - "markeredgecolor": "black", - "markersize": "5", - }, - ) - sns.swarmplot( - data=df_stats, x="sae", 
y="explanationScore", color="k", size=4, ax=g.ax, legend=False - ) - # ax.set_title("Quality of auto-interpretability explanations across SAEs") - # ax.set_ylabel("Auto-interpretability score") - # ax.set_xlabel("SAE") - result_file = out_dir / "Auto-interpretability_results.png" - plt.savefig(result_file, bbox_inches="tight") - plt.close() - # plt.savefig("Auto-interpretability_results.pdf", bbox_inches="tight") - # plt.show() - print(result_file) - - # Plot the relative performance of the SAEs - sns.set_theme(rc={"figure.figsize": (6, 6)}) - b = sns.barplot( - df_stats, - x="sae", - y="explanationScore", - palette="pastel", - hue="layer", - capsize=0.3, - legend=False, - ) - s = sns.swarmplot(data=df_stats, x="sae", y="explanationScore", color="k", alpha=0.25, size=6) - plt.title("Quality of auto-interpretability explanations across SAEs") - plt.ylabel("Auto-interpretability score") - plt.xlabel("SAE") - plt.yticks(np.arange(0, 1, 0.1)) - bar_file = out_dir / "Auto-interpretability_results_bar.png" - plt.savefig(bar_file, bbox_inches="tight") - plt.close() - # plt.savefig("Auto-interpretability_results_bar.pdf", bbox_inches="tight") - # plt.show() - print(bar_file) - - -def bootstrapped_bar(df_stats: pd.DataFrame): - fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharey=True) - - sae_names = { - "res_scefr-ajt": "Downstream", - "res_sll-ajt": "CE Local", - "res_scl-ajt": "L0 Local", - } - # make matplotlib histogram with ci as error bars - for layer, ax in zip([6, 10], axs, strict=True): - layer_data = df_stats.loc[df_stats.layer == layer] - - means, yerrs = [], [[], []] - for sae_type in sae_names: - sae_data = layer_data.loc[layer_data.sae == sae_type] - scores = sae_data.explanationScore.to_numpy() - ci = bootstrap((scores,), statistic=np.mean).confidence_interval - means.append(scores.mean()) - yerrs[0].append(scores.mean() - ci.low) - yerrs[1].append(ci.high - scores.mean()) - - ax.bar(range(3), means, yerr=yerrs, capsize=5) - ax.set_title(f"Layer {layer}") - ax.set_xticks(range(3), sae_names.values()) - - axs[0].set_ylabel("Mean Explanation Score") - - plt.tight_layout() - plt.show() - plt.savefig(out_dir / "bootstrapped_bar.png") - print(f"Saved to {out_dir / 'bootstrapped_bar.png'}") - plt.close() - - def pair_violin_plot(df_stats: pd.DataFrame, pairs: dict[int, dict[str, str]], out_file: Path): fig, axs = plt.subplots(1, 3, sharey=True, figsize=(7, 3.5)) @@ -705,14 +556,13 @@ def bootstrap_p_value(sample_a: ArrayLike, sample_b: ArrayLike) -> float: return p_value -def bootstrap_mean_diff(sample_a: ArrayLike, sample_b: ArrayLike): +def bootstrap_mean_diff(sample_a: ArrayLike, sample_b: ArrayLike) -> tuple[float, Any]: def mean_diff(resample_a, resample_b): # type: ignore return np.mean(resample_a) - np.mean(resample_b) result = bootstrap((sample_a, sample_b), n_resamples=100_000, statistic=mean_diff) - diff = np.mean(sample_a) - np.mean(sample_b) # type: ignore - low, high = result.confidence_interval - print(f"Diff {diff:.2f}, [{low:.2f},{high:.2f}]") + diff: float = np.mean(sample_a) - np.mean(sample_b) # type: ignore + return diff, result.confidence_interval def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): @@ -723,17 +573,17 @@ def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): score_groups.get_group((layer, sae_downstream)).to_numpy(), score_groups.get_group((layer, sae_local)).to_numpy(), ) - print(f"L{layer}, {sae_downstream} vs {sae_local}: p={pval}") - bootstrap_mean_diff( + # print(f"L{layer}, {sae_downstream} vs 
{sae_local}: p={pval}") + diff, ci = bootstrap_mean_diff( score_groups.get_group((layer, sae_downstream)).to_numpy(), score_groups.get_group((layer, sae_local)).to_numpy(), ) + print(f"{layer}&${diff:.2f}\\ [{ci.low:.2f},{ci.high:.2f}]$&{pval:.2g}") if __name__ == "__main__": - # out_dir = Path(__file__).parent / "out/autointerp/" - # input_dir = out_dir # Path("/data/apollo/autointerp") - out_dir = Path("/data/apollo/autointerp/") + plot_out_dir = Path(__file__).parent / "out/autointerp/" + score_out_dir = Path("/data/apollo/autointerp/") ## Running autointerp # Get runs for similar CE and similar L0 for e2e+Downstream and local # Note that "10-res_slefr-ajt" does not exist, we use 10-res_scefr-ajt for similar l0 too @@ -757,10 +607,10 @@ def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): feature_model_id="gpt2-small", autointerp_explainer_model_name="gpt-4-turbo-2024-04-09", autointerp_scorer_model_name="gpt-3.5-turbo", - out_dir=out_dir, + out_dir=score_out_dir, ) - df = get_autointerp_results_df(out_dir) + df = get_autointerp_results_df(score_out_dir) ## Analysis of autointerp results @@ -778,9 +628,8 @@ def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): # compare_autointerp_results(df) # compare_across_saes(df) - bootstrapped_bar(df) - pair_violin_plot(df, const_l0_pairs, out_dir / "l0_violin.png") - pair_violin_plot(df, const_ce_pairs, out_dir / "ce_violin.png") + pair_violin_plot(df, const_l0_pairs, plot_out_dir / "l0_violin.png") + pair_violin_plot(df, const_ce_pairs, plot_out_dir / "ce_violin.png") print("SAME L0") compute_p_values(df, const_l0_pairs) print("\nSAME CE") From 1dc38c7b10bbe65bd357cb2957be97b9b66d1433 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 10 May 2024 17:28:47 +0100 Subject: [PATCH 11/12] Update env vars and required packages --- .env.example | 4 +++- e2e_sae/scripts/autointerp.py | 12 +++++++++--- pyproject.toml | 2 -- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.env.example b/.env.example index 30865d0..aa1dc70 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,4 @@ WANDB_API_KEY=your_api_key -WANDB_ENTITY=your_entity_name \ No newline at end of file +WANDB_ENTITY=your_entity_name +OPENAI_API_KEY=your_api_key # For autointerp +NEURONPEDIA_API_KEY=your_api_key # For neuronpdeia \ No newline at end of file diff --git a/e2e_sae/scripts/autointerp.py b/e2e_sae/scripts/autointerp.py index 91e9514..46be270 100644 --- a/e2e_sae/scripts/autointerp.py +++ b/e2e_sae/scripts/autointerp.py @@ -1,9 +1,14 @@ """Run and analyze autointerp for SAEs. NOTE: Running this script currently requires installing ApolloResearch's fork of neuron-explainer. -https://github.com/ApolloResearch/automated-interpretability +https://github.com/ApolloResearch/automated-interpretability. This has been updated to work with +gpt4-turbo-2024-04-09 and fixes an OPENAI_API_KEY issue. + +This script requires the following environment variables: +- OPENAI_API_KEY: OpenAI API key. +- NEURONPEDIA_API_KEY: Neuronpedia API key. +These can be set in .env in the root of the repository (see .env.example). -This has been updated to work with gpt4-turbo-2024-04-09 and fixes an OPENAI_API_KEY issue. 
""" import asyncio import glob @@ -583,7 +588,8 @@ def compute_p_values(df: pd.DataFrame, pairs: dict[int, dict[str, str]]): if __name__ == "__main__": plot_out_dir = Path(__file__).parent / "out/autointerp/" - score_out_dir = Path("/data/apollo/autointerp/") + score_out_dir = Path(__file__).parent / "out/autointerp/" + score_out_dir.mkdir(parents=True, exist_ok=True) ## Running autointerp # Get runs for similar CE and similar L0 for e2e+Downstream and local # Note that "10-res_slefr-ajt" does not exist, we use 10-res_scefr-ajt for similar l0 too diff --git a/pyproject.toml b/pyproject.toml index ec6b351..a90f024 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,8 +21,6 @@ dependencies = [ "matplotlib~=3.5.3", "seaborn~=0.13.2", "eindex-callum@git+https://github.com/callummcdougall/eindex", - # "sae_vis@git+https://github.com/callummcdougall/sae_vis.git@b28a0f7c7e936f4bea05528d952dfcd438533cce", - "sae_lens@git+https://github.com/jbloomAus/SAELens.git@7c43c4caa84aea421ac81ae0e326d9c62bb17bec", "neuron_explainer@git+https://github.com/ApolloResearch/automated-interpretability.git" ] From 31f90514ce2b3355d24327de7d7ba6916fb6ecc7 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 10 May 2024 17:38:11 +0100 Subject: [PATCH 12/12] Update required packages --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a90f024..f5df53d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,8 @@ dependencies = [ "zstandard~=0.22.0", "matplotlib~=3.5.3", "seaborn~=0.13.2", + "tenacity~=8.2.3", + "statsmodels~=0.14.2", "eindex-callum@git+https://github.com/callummcdougall/eindex", "neuron_explainer@git+https://github.com/ApolloResearch/automated-interpretability.git" ]