From f7e76b614359127b221f9b124bd526d9bdb41988 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:05:04 -0500 Subject: [PATCH 01/82] Create hubert_block.py --- transformer_lens/components/hubert_block.py | 52 +++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 transformer_lens/components/hubert_block.py diff --git a/transformer_lens/components/hubert_block.py b/transformer_lens/components/hubert_block.py new file mode 100644 index 000000000..143ac2ccd --- /dev/null +++ b/transformer_lens/components/hubert_block.py @@ -0,0 +1,52 @@ +class HubertBlock(nn.Module): + """ + HuBERT-style Transformer Block (Pre-LayerNorm). + Structurally similar to BERTBlock, but with LayerNorm applied before each sublayer. + """ + + def __init__(self, cfg: HookedTransformerConfig): + super().__init__() + self.cfg = cfg + + self.attn = Attention(cfg) + self.ln1 = LayerNorm(cfg) + self.mlp = MLPFactory.create_mlp(self.cfg) + self.ln2 = LayerNorm(cfg) + + self.hook_q_input = HookPoint() + self.hook_k_input = HookPoint() + self.hook_v_input = HookPoint() + self.hook_attn_out = HookPoint() + self.hook_mlp_in = HookPoint() + self.hook_mlp_out = HookPoint() + self.hook_resid_pre = HookPoint() + self.hook_resid_mid = HookPoint() + self.hook_resid_post = HookPoint() + self.hook_normalized_resid_post = HookPoint() + + def forward( + self, + resid_pre: Float[torch.Tensor, "batch pos d_model"], + additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, + ) -> Float[torch.Tensor, "batch pos d_model"]: + resid_pre = self.hook_resid_pre(resid_pre) + + # --- Attention sublayer --- + normed = self.ln1(resid_pre) + attn_out = self.hook_attn_out( + self.attn( + self.hook_q_input(repeat_along_head_dimension(normed, self.cfg.n_heads)), + self.hook_k_input(repeat_along_head_dimension(normed, self.cfg.n_heads)), + self.hook_v_input(repeat_along_head_dimension(normed, self.cfg.n_heads)), + additive_attention_mask=additive_attention_mask, + ) + ) + resid_mid = self.hook_resid_mid(resid_pre + attn_out) + + # --- Feedforward sublayer --- + normed_mid = self.hook_normalized_resid_post(self.ln2(resid_mid)) + mlp_in = self.hook_mlp_in(normed_mid.clone()) if self.cfg.use_hook_mlp_in else normed_mid + mlp_out = self.hook_mlp_out(self.mlp(mlp_in)) + resid_post = self.hook_resid_post(resid_mid + mlp_out) + + return resid_post From 926c1c4640d1f1ef15f113548d0feaedeec36f38 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:12:49 -0500 Subject: [PATCH 02/82] Delete transformer_lens/components/hubert_block.py --- transformer_lens/components/hubert_block.py | 52 --------------------- 1 file changed, 52 deletions(-) delete mode 100644 transformer_lens/components/hubert_block.py diff --git a/transformer_lens/components/hubert_block.py b/transformer_lens/components/hubert_block.py deleted file mode 100644 index 143ac2ccd..000000000 --- a/transformer_lens/components/hubert_block.py +++ /dev/null @@ -1,52 +0,0 @@ -class HubertBlock(nn.Module): - """ - HuBERT-style Transformer Block (Pre-LayerNorm). - Structurally similar to BERTBlock, but with LayerNorm applied before each sublayer. 
- """ - - def __init__(self, cfg: HookedTransformerConfig): - super().__init__() - self.cfg = cfg - - self.attn = Attention(cfg) - self.ln1 = LayerNorm(cfg) - self.mlp = MLPFactory.create_mlp(self.cfg) - self.ln2 = LayerNorm(cfg) - - self.hook_q_input = HookPoint() - self.hook_k_input = HookPoint() - self.hook_v_input = HookPoint() - self.hook_attn_out = HookPoint() - self.hook_mlp_in = HookPoint() - self.hook_mlp_out = HookPoint() - self.hook_resid_pre = HookPoint() - self.hook_resid_mid = HookPoint() - self.hook_resid_post = HookPoint() - self.hook_normalized_resid_post = HookPoint() - - def forward( - self, - resid_pre: Float[torch.Tensor, "batch pos d_model"], - additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, - ) -> Float[torch.Tensor, "batch pos d_model"]: - resid_pre = self.hook_resid_pre(resid_pre) - - # --- Attention sublayer --- - normed = self.ln1(resid_pre) - attn_out = self.hook_attn_out( - self.attn( - self.hook_q_input(repeat_along_head_dimension(normed, self.cfg.n_heads)), - self.hook_k_input(repeat_along_head_dimension(normed, self.cfg.n_heads)), - self.hook_v_input(repeat_along_head_dimension(normed, self.cfg.n_heads)), - additive_attention_mask=additive_attention_mask, - ) - ) - resid_mid = self.hook_resid_mid(resid_pre + attn_out) - - # --- Feedforward sublayer --- - normed_mid = self.hook_normalized_resid_post(self.ln2(resid_mid)) - mlp_in = self.hook_mlp_in(normed_mid.clone()) if self.cfg.use_hook_mlp_in else normed_mid - mlp_out = self.hook_mlp_out(self.mlp(mlp_in)) - resid_post = self.hook_resid_post(resid_mid + mlp_out) - - return resid_post From 95d50bc40c194aa620867fb79ce314a174ec4db1 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 13:09:52 -0500 Subject: [PATCH 03/82] Create HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 447 +++++++++++++++++++++++++ 1 file changed, 447 insertions(+) create mode 100644 transformer_lens/HookedAudioEncoder.py diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py new file mode 100644 index 000000000..7ebdbe5a3 --- /dev/null +++ b/transformer_lens/HookedAudioEncoder.py @@ -0,0 +1,447 @@ +"""Hooked Encoder. + +Contains a BERT style model. This is separate from :class:`transformer_lens.HookedTransformer` +because it has a significantly different architecture to e.g. GPT style transformers. 
+""" + +from __future__ import annotations + +import logging +import os +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload + +import torch +import torch.nn as nn +from einops import repeat +from jaxtyping import Float, Int +from transformers.models.auto.tokenization_auto import AutoTokenizer +from typing_extensions import Literal + +import transformer_lens.loading_from_pretrained as loading +from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.components import ( + MLP, + Attention, + BertBlock, + BertEmbed, + BertMLMHead, + BertNSPHead, + BertPooler, + Unembed, +) +from transformer_lens.FactoredMatrix import FactoredMatrix +from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.utilities import devices + +T = TypeVar("T", bound="HookedEncoder") + + +class HookedEncoder(HookedRootModule): + """ + This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. + + Limitations: + - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. + + Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: + - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model + """ + + def __init__( + self, + cfg: Union[HookedTransformerConfig, Dict], + tokenizer: Optional[Any] = None, + move_to_device: bool = True, + **kwargs: Any, + ): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig(**cfg) + elif isinstance(cfg, str): + raise ValueError( + "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead." + ) + self.cfg = cfg + + assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" + if tokenizer is not None: + self.tokenizer = tokenizer + elif self.cfg.tokenizer_name is not None: + huggingface_token = os.environ.get("HF_TOKEN", "") + self.tokenizer = AutoTokenizer.from_pretrained( + self.cfg.tokenizer_name, + token=huggingface_token if len(huggingface_token) > 0 else None, + ) + else: + self.tokenizer = None + + if self.cfg.d_vocab == -1: + # If we have a tokenizer, vocab size can be inferred from it. 
+ assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided" + self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 + if self.cfg.d_vocab_out == -1: + self.cfg.d_vocab_out = self.cfg.d_vocab + + self.embed = BertEmbed(self.cfg) + self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) + self.mlm_head = BertMLMHead(self.cfg) + self.unembed = Unembed(self.cfg) + self.nsp_head = BertNSPHead(self.cfg) + self.pooler = BertPooler(self.cfg) + + self.hook_full_embed = HookPoint() + + if move_to_device: + if self.cfg.device is None: + raise ValueError("Cannot move to device when device is None") + self.to(self.cfg.device) + + self.setup() + + def encoder_output( + self, + frames: torch.Tensor, # (batch, frames, d_model) <-- precomputed conv features + one_zero_attention_mask: Optional[torch.Tensor] = None, # (batch, frames) + ): + # Ensure device + if frames.device.type != self.cfg.device: + frames = frames.to(self.cfg.device) + if one_zero_attention_mask is not None: + one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) + + # directly use frames as "embed output" (skip to_tokens/embed) + resid = self.hook_full_embed(frames) + + large_negative_number = -torch.inf + mask = ( + repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") + if one_zero_attention_mask is not None + else None + ) + additive_attention_mask = ( + torch.where(mask == 1, large_negative_number, 0) if mask is not None else None + ) + + for block in self.blocks: + resid = block(resid, additive_attention_mask) + + return resid + + def forward( + self, + input, # either: Tensor[batch, samples] (raw wave) OR Tensor[batch, frames, feat_dim] (precomputed conv features) + return_type: Optional[str] = "logits", # "logits" or None or "hidden" + lengths: Optional[torch.Tensor] = None, # optional lengths in frames (for padding), shape [batch] + masked_positions: Optional[torch.BoolTensor] = None, # optional mask of positions to replace with masked_spec_embed [batch, frames] + preprocess_already: bool = False, # if True, input is precomputed frames + ): + """ + HuBERT-like forward. If preprocess_already=False, expects raw audio waveforms (batch, samples) + and runs feature_extractor -> feature_projection. If preprocess_already=True, expects + (batch, frames, feat_dim) already projected to model hidden dim (or if feat_dim != hidden, we project). + """ + + device = self.cfg.device + + # 1) Build feature frames + if preprocess_already: + # assume input is frames, possibly already in d_model + features = input.to(device) + # if feature dim != model hidden, optionally project (defensive) + if features.shape[-1] != self.cfg.d_model: + raise ValueError(f"features shape is incorrect. 
Model is expecting {self.cfg.d_model}, but get {features.shape[-1]}") + else: + # raw waveform path + # feature_extractor returns something like (batch, feat_len, feat_dim) or (batch, feat_len) + # Hugging Face: feature_extractor expects float waveform batched + wave = input.to(device) + features = self.feature_extractor(wave) # conv layers -> torch.float + features = self.feature_projection(features) # linear + layernorm -> (batch, frames, d_model) + + # 2) Optionally apply masked_spec_embed for masked_positions + # masked_positions: bool tensor [batch, frames] where True indicates masked frames + if masked_positions is not None: + # masked_spec_embed is shape (d_model,) + mask = masked_positions.to(device) + masked_vec = self.masked_spec_embed.view(1, 1, -1) # (1,1,d_model) + features = torch.where(mask.unsqueeze(-1), masked_vec, features) + + # 3) Build attention mask from lengths if provided; else assume all ones + if lengths is not None: + # lengths in frames; create one_zero_attention_mask with 1 for valid / 0 for padding + max_frames = features.shape[1] + rng = torch.arange(max_frames, device=device).unsqueeze(0) # (1, frames) + one_zero_attention_mask = (rng < lengths.unsqueeze(1)).long() # (batch, frames) + else: + one_zero_attention_mask = torch.ones(features.shape[:2], dtype=torch.long, device=device) + + # 4) Pass through (possibly identical) encoder routine + # For the HookedTransformer code you had: resid = self.hook_full_embed(self.embed(tokens, ...)) + # For HuBERT we treat 'features' as the residual input. + resid = self.encoder_output(features, one_zero_attention_mask) + + # 5) Prediction head: project hidden states to logits/predictions over discrete units + if return_type == "hidden": + return resid # (batch, frames, d_model) + + # project_hid -> predictions (frame-wise) + pred = self.project_hid(resid) # shape (batch, frames, target_dim) or (batch, frames, n_classes) + # If your project_hid produces vectors and you want logits over cluster ids, there may be an extra linear/unembed + # e.g., logits = self.unembed(pred) or pred itself already logits. + + if return_type == "logits" or return_type is None: + return pred + + return None + + + @overload + def run_with_cache( + self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any + ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: + ... + + @overload + def run_with_cache( + self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any + ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: + ... + + def run_with_cache( + self, + *model_args: Any, + return_cache_object: bool = True, + remove_batch_dim: bool = False, + **kwargs: Any, + ) -> Tuple[ + Float[torch.Tensor, "batch pos d_vocab"], + Union[ActivationCache, Dict[str, torch.Tensor]], + ]: + """ + Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. 
+ """ + out, cache_dict = super().run_with_cache( + *model_args, remove_batch_dim=remove_batch_dim, **kwargs + ) + if return_cache_object: + cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) + return out, cache + else: + return out, cache_dict + + def to( # type: ignore + self, + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details: bool = True, + ): + return devices.move_to_and_update_config(self, device_or_dtype, print_details) + + def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T: + if isinstance(device, int): + return self.to(f"cuda:{device}") + elif device is None: + return self.to("cuda") + else: + return self.to(device) + + def cpu(self: T) -> T: + return self.to("cpu") + + def mps(self: T) -> T: + return self.to(torch.device("mps")) + + @classmethod + def from_pretrained( + cls, + model_name: str, + checkpoint_index: Optional[int] = None, + checkpoint_value: Optional[int] = None, + hf_model: Optional[Any] = None, + device: Optional[str] = None, + tokenizer: Optional[Any] = None, + move_to_device: bool = True, + dtype: torch.dtype = torch.float32, + **from_pretrained_kwargs: Any, + ) -> HookedEncoder: + """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" + logging.warning( + "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature " + "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " + "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " + "implementation." + "\n" + "If using BERT for interpretability research, keep in mind that BERT has some significant architectural " + "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " + "that the last LayerNorm in a block cannot be folded." 
+ ) + + assert not ( + from_pretrained_kwargs.get("load_in_8bit", False) + or from_pretrained_kwargs.get("load_in_4bit", False) + ), "Quantization not supported" + + if "torch_dtype" in from_pretrained_kwargs: + dtype = from_pretrained_kwargs["torch_dtype"] + + official_model_name = loading.get_official_model_name(model_name) + + cfg = loading.get_pretrained_model_config( + official_model_name, + checkpoint_index=checkpoint_index, + checkpoint_value=checkpoint_value, + fold_ln=False, + device=device, + n_devices=1, + dtype=dtype, + **from_pretrained_kwargs, + ) + + state_dict = loading.get_pretrained_state_dict( + official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs + ) + + model = cls(cfg, tokenizer, move_to_device=False) + + model.load_state_dict(state_dict, strict=False) + + if move_to_device: + model.to(cfg.device) + + print(f"Loaded pretrained model {model_name} into HookedEncoder") + + return model + + @property + def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: + """ + Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits) + """ + return self.unembed.W_U + + @property + def b_U(self) -> Float[torch.Tensor, "d_vocab"]: + """ + Convenience to get the unembedding bias + """ + return self.unembed.b_U + + @property + def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: + """ + Convenience to get the embedding matrix + """ + return self.embed.embed.W_E + + @property + def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: + """ + Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! + """ + return self.embed.pos_embed.W_pos + + @property + def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: + """ + Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits. 
+ """ + return torch.cat([self.W_E, self.W_pos], dim=0) + + @property + def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the key weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_K for block in self.blocks], dim=0) + + @property + def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the query weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) + + @property + def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the value weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_V for block in self.blocks], dim=0) + + @property + def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: + """Stacks the attn output weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_O for block in self.blocks], dim=0) + + @property + def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: + """Stacks the MLP input weights across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) + + @property + def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: + """Stacks the MLP output weights across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) + + @property + def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the key biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_K for block in self.blocks], dim=0) + + @property + def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the query biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) + + @property + def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the value biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_V for block in self.blocks], dim=0) + + @property + def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the attn output biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_O for block in self.blocks], dim=0) + + @property + def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: + """Stacks the MLP input biases across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) + + @property + def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the MLP output biases across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) + + @property + def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. 
+ Useful for visualizing attention patterns.""" + return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) + + @property + def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" + return FactoredMatrix(self.W_V, self.W_O) + + def all_head_labels(self) -> List[str]: + """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" + return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] From 9a2929532e98b7ae1300da80a95500f97a06168b Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:56:35 -0500 Subject: [PATCH 04/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 116 +++++++++++++++++++++++-- 1 file changed, 110 insertions(+), 6 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 7ebdbe5a3..c8a976dba 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -9,13 +9,15 @@ import logging import os from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload +from typing_extensions import Literal import torch import torch.nn as nn +import numpy as np from einops import repeat from jaxtyping import Float, Int +from transformers import AutoProcessor, HubertModel from transformers.models.auto.tokenization_auto import AutoTokenizer -from typing_extensions import Literal import transformer_lens.loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache @@ -53,6 +55,7 @@ def __init__( cfg: Union[HookedTransformerConfig, Dict], tokenizer: Optional[Any] = None, move_to_device: bool = True, + model_name: str = "facebook/hubert-base-ls960", **kwargs: Any, ): super().__init__() @@ -85,10 +88,15 @@ def __init__( self.embed = BertEmbed(self.cfg) self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) - self.mlm_head = BertMLMHead(self.cfg) - self.unembed = Unembed(self.cfg) - self.nsp_head = BertNSPHead(self.cfg) - self.pooler = BertPooler(self.cfg) + processor = AutoProcessor.from_pretrained(model_name) # builds input_values + attention_mask + model = HubertModel.from_pretrained(model_name) + if move_to_device: + if self.cfg.device is None: + raise ValueError("Cannot move to device when device is None") + model.to(self.cfg.device) + model.eval() + self.processor = processor + self.model = model self.hook_full_embed = HookPoint() @@ -98,7 +106,103 @@ def __init__( self.to(self.cfg.device) self.setup() - + + def _ensure_tensor(wave): + """Convert numpy array or python list to 1D torch.float tensor.""" + if isinstance(wave, np.ndarray): + return torch.from_numpy(wave).float() + if isinstance(wave, list): + return torch.tensor(wave, dtype=torch.float) + if isinstance(wave, torch.Tensor): + return wave.float() + raise TypeError("wave must be torch.Tensor, np.ndarray or list of floats") + + def to_frames( + raw_inputs: Union[torch.Tensor, List[torch.Tensor], List[np.ndarray]], + sampling_rate: int = 16000, + move_to_device: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert raw audio batch -> (projected frames, frame_attention_mask) + + Args: + raw_inputs: one of: + - a 1D torch.Tensor or numpy array (single waveform) + - a list of 1D torch.Tensors / numpy arrays (batch) + self.processor: HF AutoProcessor 
(creates input_values + sample-level attention_mask) + self.model: pretrained HubertModel (provides feature_extractor and feature_projection) + sampling_rate: sample rate of the audio (default 16k) + move_to_device: move outputs to model.device + + Returns: + frames: torch.Tensor of shape (batch, frames, hidden_size) <- after feature_projection + frame_attention_mask: torch.LongTensor of shape (batch, frames) with 1 for real frames, 0 for padding + """ + # raw_inputs are arrays/tensors + if isinstance(raw_inputs, (torch.Tensor, np.ndarray)): + waves = [_ensure_tensor(raw_inputs)] + elif isinstance(raw_inputs, list): + waves = [_ensure_tensor(w) for w in raw_inputs] + else: + raise TypeError("Unsupported raw_inputs type") + + # Use HF processor to create input_values (padded) + sample-level attention_mask + # Processor will do padding so we can pass a variable-length batch + proc_out = self.processor(waves, sampling_rate=sampling_rate, return_tensors="pt", padding=True) + + input_values = proc_out["input_values"] # (batch, samples), float + sample_attention_mask = proc_out.get("attention_mask") # (batch, samples), 1 for valid, 0 for padding; may be None + + # move to device + device = self.cfg.device + if move_to_device: + input_values = input_values.to(device) + if sample_attention_mask is not None: + sample_attention_mask = sample_attention_mask.to(device) + + # 1) convolutional frontend -> (batch, conv_dim, conv_time) + with torch.no_grad(): + conv_feats = self.model.feature_extractor(input_values) # (B, C, T_conv) + + # 2) transpose to (batch, T_conv, C) + extract_features = conv_feats.transpose(1, 2) + + # 3) compute reduced frame-level attention mask (if sample mask provided) + frame_attention_mask = None + if sample_attention_mask is not None: + # model should provide helper _get_feature_vector_attention_mask + try: + frame_attention_mask = self.model._get_feature_vector_attention_mask(extract_features.shape[1], sample_attention_mask) + except AttributeError: + # fallback: compute output lengths and create mask similarly to HF implementation + # compute output lengths (downsampled lengths) from sample attention mask (sums per example) + input_lengths = sample_attention_mask.sum(dim=-1) # (batch,) + # compute output lengths through conv layers using model._get_feat_extract_output_lengths if exists + if hasattr(model, "_get_feat_extract_output_lengths"): + output_lengths = self.model._get_feat_extract_output_lengths(input_lengths).to(torch.long) + else: + # fallback to naive downsample ratio: output_frames = extract_features.shape[1] + output_lengths = torch.full((sample_attention_mask.shape[0],), extract_features.shape[1], device=device, dtype=torch.long) + + batch_size = sample_attention_mask.shape[0] + feat_len = extract_features.shape[1] + frame_attention_mask = torch.zeros((batch_size, feat_len), dtype=sample_attention_mask.dtype, device=device) + # mark the last valid index for each example and then cumsum trick to fill ones before it + idx = (torch.arange(batch_size, device=device), (output_lengths - 1).clamp(min=0)) + frame_attention_mask[idx] = 1 + frame_attention_mask = frame_attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool().long() + + # 4) feature projection -> (batch, frames, hidden_size) + with torch.no_grad(): + hidden_states = self.model.feature_projection(extract_features) # typically returns (B, T, hidden) + # In HF's hubert, feature_projection is a module that returns a tensor (not tuple). If it returns tuple, adjust. 
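+            # A hedged, defensive sketch of that adjustment (assumption: some projection modules,
+            # e.g. Wav2Vec2-style ones, return a tuple rather than a plain tensor):
+            if isinstance(hidden_states, tuple):
+                hidden_states = hidden_states[0]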
+ + # convert bool mask to long (1/0) if needed + if frame_attention_mask is not None: + frame_attention_mask = frame_attention_mask.to(dtype=torch.long) + + return hidden_states, frame_attention_mask + def encoder_output( self, frames: torch.Tensor, # (batch, frames, d_model) <-- precomputed conv features From 48c6efe743496f2089f5db2e62cf2b89b4ecee8d Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:54:32 -0500 Subject: [PATCH 05/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 200 ++++++++++--------------- 1 file changed, 76 insertions(+), 124 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index c8a976dba..3631c840a 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -53,7 +53,6 @@ class HookedEncoder(HookedRootModule): def __init__( self, cfg: Union[HookedTransformerConfig, Dict], - tokenizer: Optional[Any] = None, move_to_device: bool = True, model_name: str = "facebook/hubert-base-ls960", **kwargs: Any, @@ -68,25 +67,7 @@ def __init__( self.cfg = cfg assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" - if tokenizer is not None: - self.tokenizer = tokenizer - elif self.cfg.tokenizer_name is not None: - huggingface_token = os.environ.get("HF_TOKEN", "") - self.tokenizer = AutoTokenizer.from_pretrained( - self.cfg.tokenizer_name, - token=huggingface_token if len(huggingface_token) > 0 else None, - ) - else: - self.tokenizer = None - if self.cfg.d_vocab == -1: - # If we have a tokenizer, vocab size can be inferred from it. - assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided" - self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 - if self.cfg.d_vocab_out == -1: - self.cfg.d_vocab_out = self.cfg.d_vocab - - self.embed = BertEmbed(self.cfg) self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) processor = AutoProcessor.from_pretrained(model_name) # builds input_values + attention_mask model = HubertModel.from_pretrained(model_name) @@ -98,8 +79,6 @@ def __init__( self.processor = processor self.model = model - self.hook_full_embed = HookPoint() - if move_to_device: if self.cfg.device is None: raise ValueError("Cannot move to device when device is None") @@ -214,8 +193,9 @@ def encoder_output( if one_zero_attention_mask is not None: one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) - # directly use frames as "embed output" (skip to_tokens/embed) - resid = self.hook_full_embed(frames) + position_embeddings = self.model.encoder.pos_conv_embed(frames) + resid = resid + position_embeddings + resid = self.model.encoder.layer_norm(resid) large_negative_number = -torch.inf mask = ( @@ -226,79 +206,86 @@ def encoder_output( additive_attention_mask = ( torch.where(mask == 1, large_negative_number, 0) if mask is not None else None ) - for block in self.blocks: resid = block(resid, additive_attention_mask) return resid def forward( - self, - input, # either: Tensor[batch, samples] (raw wave) OR Tensor[batch, frames, feat_dim] (precomputed conv features) - return_type: Optional[str] = "logits", # "logits" or None or "hidden" - lengths: Optional[torch.Tensor] = None, # optional lengths in frames (for padding), shape [batch] - masked_positions: Optional[torch.BoolTensor] = None, # optional mask of positions to replace with masked_spec_embed [batch, frames] - preprocess_already: 
bool = False, # if True, input is precomputed frames - ): - """ - HuBERT-like forward. If preprocess_already=False, expects raw audio waveforms (batch, samples) - and runs feature_extractor -> feature_projection. If preprocess_already=True, expects - (batch, frames, feat_dim) already projected to model hidden dim (or if feat_dim != hidden, we project). - """ - - device = self.cfg.device - - # 1) Build feature frames - if preprocess_already: - # assume input is frames, possibly already in d_model - features = input.to(device) - # if feature dim != model hidden, optionally project (defensive) - if features.shape[-1] != self.cfg.d_model: - raise ValueError(f"features shape is incorrect. Model is expecting {self.cfg.d_model}, but get {features.shape[-1]}") - else: - # raw waveform path - # feature_extractor returns something like (batch, feat_len, feat_dim) or (batch, feat_len) - # Hugging Face: feature_extractor expects float waveform batched - wave = input.to(device) - features = self.feature_extractor(wave) # conv layers -> torch.float - features = self.feature_projection(features) # linear + layernorm -> (batch, frames, d_model) - - # 2) Optionally apply masked_spec_embed for masked_positions - # masked_positions: bool tensor [batch, frames] where True indicates masked frames - if masked_positions is not None: - # masked_spec_embed is shape (d_model,) - mask = masked_positions.to(device) - masked_vec = self.masked_spec_embed.view(1, 1, -1) # (1,1,d_model) - features = torch.where(mask.unsqueeze(-1), masked_vec, features) - - # 3) Build attention mask from lengths if provided; else assume all ones - if lengths is not None: - # lengths in frames; create one_zero_attention_mask with 1 for valid / 0 for padding - max_frames = features.shape[1] - rng = torch.arange(max_frames, device=device).unsqueeze(0) # (1, frames) - one_zero_attention_mask = (rng < lengths.unsqueeze(1)).long() # (batch, frames) - else: - one_zero_attention_mask = torch.ones(features.shape[:2], dtype=torch.long, device=device) - - # 4) Pass through (possibly identical) encoder routine - # For the HookedTransformer code you had: resid = self.hook_full_embed(self.embed(tokens, ...)) - # For HuBERT we treat 'features' as the residual input. - resid = self.encoder_output(features, one_zero_attention_mask) - - # 5) Prediction head: project hidden states to logits/predictions over discrete units - if return_type == "hidden": - return resid # (batch, frames, d_model) - - # project_hid -> predictions (frame-wise) - pred = self.project_hid(resid) # shape (batch, frames, target_dim) or (batch, frames, n_classes) - # If your project_hid produces vectors and you want logits over cluster ids, there may be an extra linear/unembed - # e.g., logits = self.unembed(pred) or pred itself already logits. - - if return_type == "logits" or return_type is None: - return pred - - return None + self, + input: Union[ + torch.Tensor, # waveform (1D) OR precomputed frames (3D) + List[Union[torch.Tensor, np.ndarray]], # list of waveforms + Tuple[torch.Tensor, torch.Tensor], # (frames, frame_mask) + ], + return_type: Optional[Literal["hidden", "logits"]] = "logits", + sampling_rate: int = 16000, + move_to_device: bool = True, +) -> Optional[torch.Tensor]: + """ + HuBERT-like forward (Transformer-Lens style). + + Args: + input: one of: + - 1D torch.Tensor or numpy array (single waveform) OR list of 1D waveforms -> will call self.to_frames(...) 
+ - 3D torch.Tensor shaped (batch, frames, d_model) -> treated as precomputed frames (skip to_frames) + - tuple (frames, frame_mask) -> use directly + return_type: "hidden" to return encoder hidden states (B, T, D), "logits" to return project_hid output if present. + sampling_rate: sampling rate for to_frames when converting raw audio. + move_to_device: move tensors to self.cfg.device (to match your other code). + + Returns: + Depending on return_type: + - "hidden": (batch, frames, d_model) final encoder hidden states + - "logits": output of project_hid(resid) if available (e.g., (batch, frames, n_targets)) + """ + # ---------- 1) Normalize input: get (frames, frame_mask) ---------- + frames = None + frame_mask = None # one_zero_attention_mask: 1 = valid, 0 = padding + + # If user passed (frames, mask) tuple + if isinstance(input, tuple) and len(input) == 2 and isinstance(input[0], torch.Tensor): + frames, frame_mask = input + + # If user passed a 3D tensor -> assume (B, T, D) frames (pre-projected) + elif isinstance(input, torch.Tensor) and input.ndim == 3: + frames = input + # frame_mask stays whatever was passed as separate argument (None here) + + # Else treat as raw waveform(s) -> call to_frames + else: + # allow single 1D tensor or numpy array or list of tensors/arrays + frames, frame_mask = self.to_frames(input, sampling_rate=sampling_rate, move_to_device=move_to_device) + # to_frames should already place tensors on device if move_to_device=True + + # ---------- 2) Ensure device & dtype consistency ---------- + device = self.cfg.device + if frames.device.type != device: + frames = frames.to(device) + if frame_mask is not None: + frame_mask = frame_mask.to(device) + + # ---------- 3) Run encoder (respects pos_conv_embed / layer_norm / dropout inside encoder_output) ---------- + resid = self.encoder_output(frames, frame_mask) # (B, T, d_model) + + # ---------- 4) Return according to return_type ---------- + if return_type == "hidden": + return resid + + if return_type == "logits": + if hasattr(self, "project_hid"): + logits = self.project_hid(resid) + return logits + # try model-level project head (HuggingFace uses project_hid/project_q) + if hasattr(self.model, "project_hid"): + logits = self.model.project_hid(resid) + return logits + + # no head available — return hidden states as fallback + return resid + # unknown return_type -> return hidden states + return resid @overload def run_with_cache( @@ -416,41 +403,6 @@ def from_pretrained( return model - @property - def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: - """ - Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits) - """ - return self.unembed.W_U - - @property - def b_U(self) -> Float[torch.Tensor, "d_vocab"]: - """ - Convenience to get the unembedding bias - """ - return self.unembed.b_U - - @property - def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: - """ - Convenience to get the embedding matrix - """ - return self.embed.embed.W_E - - @property - def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: - """ - Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! - """ - return self.embed.pos_embed.W_pos - - @property - def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: - """ - Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits. 
- """ - return torch.cat([self.W_E, self.W_pos], dim=0) - @property def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stacks the key weights across all layers""" From 1b7559c7bcd36a021898e134fc9fb6855c388dfe Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:00:00 -0500 Subject: [PATCH 06/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 124 +++++++++++-------------- 1 file changed, 52 insertions(+), 72 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 3631c840a..fc6ad7b6c 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -212,81 +212,61 @@ def encoder_output( return resid def forward( - self, - input: Union[ - torch.Tensor, # waveform (1D) OR precomputed frames (3D) - List[Union[torch.Tensor, np.ndarray]], # list of waveforms - Tuple[torch.Tensor, torch.Tensor], # (frames, frame_mask) - ], - return_type: Optional[Literal["hidden", "logits"]] = "logits", - sampling_rate: int = 16000, - move_to_device: bool = True, -) -> Optional[torch.Tensor]: - """ - HuBERT-like forward (Transformer-Lens style). - - Args: - input: one of: - - 1D torch.Tensor or numpy array (single waveform) OR list of 1D waveforms -> will call self.to_frames(...) - - 3D torch.Tensor shaped (batch, frames, d_model) -> treated as precomputed frames (skip to_frames) - - tuple (frames, frame_mask) -> use directly - return_type: "hidden" to return encoder hidden states (B, T, D), "logits" to return project_hid output if present. - sampling_rate: sampling rate for to_frames when converting raw audio. - move_to_device: move tensors to self.cfg.device (to match your other code). 
- - Returns: - Depending on return_type: - - "hidden": (batch, frames, d_model) final encoder hidden states - - "logits": output of project_hid(resid) if available (e.g., (batch, frames, n_targets)) - """ - # ---------- 1) Normalize input: get (frames, frame_mask) ---------- - frames = None - frame_mask = None # one_zero_attention_mask: 1 = valid, 0 = padding - - # If user passed (frames, mask) tuple - if isinstance(input, tuple) and len(input) == 2 and isinstance(input[0], torch.Tensor): - frames, frame_mask = input - - # If user passed a 3D tensor -> assume (B, T, D) frames (pre-projected) - elif isinstance(input, torch.Tensor) and input.ndim == 3: - frames = input - # frame_mask stays whatever was passed as separate argument (None here) - - # Else treat as raw waveform(s) -> call to_frames - else: - # allow single 1D tensor or numpy array or list of tensors/arrays - frames, frame_mask = self.to_frames(input, sampling_rate=sampling_rate, move_to_device=move_to_device) - # to_frames should already place tensors on device if move_to_device=True - - # ---------- 2) Ensure device & dtype consistency ---------- - device = self.cfg.device - if frames.device.type != device: - frames = frames.to(device) - if frame_mask is not None: - frame_mask = frame_mask.to(device) - - # ---------- 3) Run encoder (respects pos_conv_embed / layer_norm / dropout inside encoder_output) ---------- - resid = self.encoder_output(frames, frame_mask) # (B, T, d_model) - - # ---------- 4) Return according to return_type ---------- - if return_type == "hidden": - return resid + self, + input: Union[ + torch.Tensor, # waveform (1D) OR precomputed frames (3D) + List[Union[torch.Tensor, np.ndarray]], # list of waveforms + Tuple[torch.Tensor, torch.Tensor], # (frames, frame_mask) + ], + sampling_rate: int = 16000, + move_to_device: bool = True, + ) -> Optional[torch.Tensor]: + """ + HuBERT-like forward (Transformer-Lens style). + + Args: + input: one of: + - 1D torch.Tensor or numpy array (single waveform) OR list of 1D waveforms -> will call self.to_frames(...) + - 3D torch.Tensor shaped (batch, frames, d_model) -> treated as precomputed frames (skip to_frames) + - tuple (frames, frame_mask) -> use directly + sampling_rate: sampling rate for to_frames when converting raw audio. + move_to_device: move tensors to self.cfg.device (to match your other code). 
+ + Returns: + Depending on return_type: + - "hidden": (batch, frames, d_model) final encoder hidden states + """ + # ---------- 1) Normalize input: get (frames, frame_mask) ---------- + frames = None + frame_mask = None # one_zero_attention_mask: 1 = valid, 0 = padding + + # If user passed (frames, mask) tuple + if isinstance(input, tuple) and len(input) == 2 and isinstance(input[0], torch.Tensor): + frames, frame_mask = input + + # If user passed a 3D tensor -> assume (B, T, D) frames (pre-projected) + elif isinstance(input, torch.Tensor) and input.ndim == 3: + frames = input + # frame_mask stays whatever was passed as separate argument (None here) + + # Else treat as raw waveform(s) -> call to_frames + else: + # allow single 1D tensor or numpy array or list of tensors/arrays + frames, frame_mask = self.to_frames(input, sampling_rate=sampling_rate, move_to_device=move_to_device) + # to_frames should already place tensors on device if move_to_device=True + + # ---------- 2) Ensure device & dtype consistency ---------- + device = self.cfg.device + if frames.device.type != device: + frames = frames.to(device) + if frame_mask is not None: + frame_mask = frame_mask.to(device) + + # ---------- 3) Run encoder (respects pos_conv_embed / layer_norm / dropout inside encoder_output) ---------- + resid = self.encoder_output(frames, frame_mask) # (B, T, d_model) - if return_type == "logits": - if hasattr(self, "project_hid"): - logits = self.project_hid(resid) - return logits - # try model-level project head (HuggingFace uses project_hid/project_q) - if hasattr(self.model, "project_hid"): - logits = self.model.project_hid(resid) - return logits - - # no head available — return hidden states as fallback return resid - # unknown return_type -> return hidden states - return resid - @overload def run_with_cache( self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any From 94fa33ebc17881d6ec504d41c27f40e8a364d328 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:09:13 -0500 Subject: [PATCH 07/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index fc6ad7b6c..224393737 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -330,19 +330,18 @@ def from_pretrained( checkpoint_value: Optional[int] = None, hf_model: Optional[Any] = None, device: Optional[str] = None, - tokenizer: Optional[Any] = None, move_to_device: bool = True, dtype: torch.dtype = torch.float32, **from_pretrained_kwargs: Any, ) -> HookedEncoder: """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" logging.warning( - "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature " + "Support for HuBERT in TransformerLens is currently experimental, until such a time when it has feature " "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " "implementation." 
"\n" - "If using BERT for interpretability research, keep in mind that BERT has some significant architectural " + "If using HuBERT for interpretability research, keep in mind that HuBERT has some significant architectural " "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " "that the last LayerNorm in a block cannot be folded." ) @@ -372,7 +371,7 @@ def from_pretrained( official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs ) - model = cls(cfg, tokenizer, move_to_device=False) + model = cls(cfg, move_to_device=False) model.load_state_dict(state_dict, strict=False) From 6e93a5bdd5b8551eb4b7c302191664579508c46d Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:12:44 -0500 Subject: [PATCH 08/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8bfb6315d..5051c5662 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -59,6 +59,7 @@ "facebook/opt-13b", "facebook/opt-30b", "facebook/opt-66b", + "facebook/hubert-base-ls960" "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", @@ -610,6 +611,10 @@ "google-bert/bert-base-uncased": ["bert-base-uncased"], "google-bert/bert-large-cased": ["bert-large-cased"], "google-bert/bert-large-uncased": ["bert-large-uncased"], + "facebook-hubert/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"], + "facebook-hubert/hubert-large-ls960": ["facebook/hubert-large-ls960", "hubert-large-ls960"], + "facebook-hubert/hubert-xlarge-ls960": ["facebook/hubert-xlarge-ls960", "hubert-xlarge-ls960"], + "facebook-hubert/hubert-large-ls960-ft": ["facebook/hubert-large-ls960-ft", "hubert-large-ls960-ft"], "roneneldan/TinyStories-1M": ["tiny-stories-1M"], "roneneldan/TinyStories-3M": ["tiny-stories-3M"], "roneneldan/TinyStories-8M": ["tiny-stories-8M"], From 4edde8d7bea7603478f657c345fcc9fc4e2c11dd Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:13:36 -0500 Subject: [PATCH 09/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 224393737..82e6518d0 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -371,7 +371,7 @@ def from_pretrained( official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs ) - model = cls(cfg, move_to_device=False) + model = cls(cfg, move_to_device=False, model_name=official_model_name) model.load_state_dict(state_dict, strict=False) From a5ef321571623b880dcd2901de55bfd480c9c607 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:31:37 -0500 Subject: [PATCH 10/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 5051c5662..4e47ec84c 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1181,6 +1181,23 @@ def 
convert_hf_model_config(model_name: str, **kwargs: Any): } rotary_pct = hf_config.rotary_pct cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) + elif any(x in architecture for x in ( + "HubertModel", "HubertForCTC", "HubertForPreTraining", "HubertForSequenceClassification" + )) or "hubert" in official_model_name.lower(): + # Basic transformer configuration + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + # HuBERT operates on audio frames, not tokens — n_ctx is flexible + "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + "eps": hf_config.layer_norm_eps, + "act_fn": "gelu", + "attention_dir": "bidirectional", + "d_vocab": -1, # no text vocabulary + } elif architecture == "BertForMaskedLM": # All supported Bert architectures have the same config, # so we can use the BertForMaskedLM config for all of them From cd930f318d55109aabbaa0a531ff47750352e9c0 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:40:35 -0500 Subject: [PATCH 11/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 4e47ec84c..25d46983d 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -19,6 +19,7 @@ AutoModelForCausalLM, BertForPreTraining, T5ForConditionalGeneration, + HubertModel, ) import transformer_lens.utils as utils @@ -1943,6 +1944,13 @@ def get_pretrained_state_dict( huggingface_token = os.environ.get("HF_TOKEN", "") if official_model_name in NON_HF_HOSTED_MODEL_NAMES: raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") + elif "hubert" in official_model_name: + hf_model = HubertModel.from_pretrained( + official_model_name, + torch_dtype=dtype, + token=huggingface_token if len(huggingface_token) > 0 else None, + **kwargs, + ) elif "bert" in official_model_name: hf_model = BertForPreTraining.from_pretrained( official_model_name, From 4621730d7c09665da84fdf0e2d43592e1a31e4f5 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:43:52 -0500 Subject: [PATCH 12/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 25d46983d..a4bb9c7ae 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -45,6 +45,7 @@ convert_qwen3_weights, convert_qwen_weights, convert_t5_weights, + convert_hubert_weights, ) OFFICIAL_MODEL_NAMES = [ @@ -60,7 +61,10 @@ "facebook/opt-13b", "facebook/opt-30b", "facebook/opt-66b", - "facebook/hubert-base-ls960" + "facebook/hubert-base-ls960", + "facebook-hubert/hubert-large-ls960", + "facebook-hubert/hubert-xlarge-ls960", + "facebook-hubert/hubert-large-ls960-ft", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", @@ -1182,9 +1186,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): } rotary_pct = hf_config.rotary_pct cfg_dict["rotary_dim"] = round(rotary_pct * 
cfg_dict["d_head"]) - elif any(x in architecture for x in ( - "HubertModel", "HubertForCTC", "HubertForPreTraining", "HubertForSequenceClassification" - )) or "hubert" in official_model_name.lower(): + elif architecture == "HubertModel": # Basic transformer configuration cfg_dict = { "d_model": hf_config.hidden_size, @@ -1990,6 +1992,8 @@ def get_pretrained_state_dict( state_dict = convert_neox_weights(hf_model, cfg) elif cfg.original_architecture == "LlamaForCausalLM": state_dict = convert_llama_weights(hf_model, cfg) + elif cfg.original_architecture == "HubertModel": + state_dict = convert_hubert_weights(hf_model, cfg) elif cfg.original_architecture == "BertForMaskedLM": state_dict = convert_bert_weights(hf_model, cfg) elif cfg.original_architecture == "T5ForConditionalGeneration": From 548e693b1e8456c89c9e7240a68cb02e8b5ae021 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:52:43 -0500 Subject: [PATCH 13/82] Create hubert.py --- .../pretrained/weight_conversions/hubert.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 transformer_lens/pretrained/weight_conversions/hubert.py diff --git a/transformer_lens/pretrained/weight_conversions/hubert.py b/transformer_lens/pretrained/weight_conversions/hubert.py new file mode 100644 index 000000000..d8de2dc0a --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/hubert.py @@ -0,0 +1,70 @@ +import einops +import torch +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def convert_hubert_weights(hf_model, cfg: HookedTransformerConfig): + """ + Convert a Hugging Face HuBERT model's transformer encoder weights + into a TransformerLens-compatible state_dict. + + This ignores HuBERT's convolutional feature extractor and feature projection, + since we assume they are handled externally (e.g., via hf_model.feature_extractor + and hf_model.feature_projection). 
+ + Args: + hf_model: A pretrained HuggingFace HuBERT model (e.g., HubertModel.from_pretrained(...)) + cfg: TransformerLens HookedTransformerConfig + + Returns: + state_dict: a dict mapping TransformerLens parameter names to torch tensors + suitable for model.load_state_dict(state_dict, strict=False) + """ + state_dict = {} + + # Shortcut to encoder layers + encoder_layers = hf_model.encoder.layers + + for l, layer in enumerate(encoder_layers): + # --- Self-attention projections --- + q_proj = layer.self_attn.q_proj.weight + k_proj = layer.self_attn.k_proj.weight + v_proj = layer.self_attn.v_proj.weight + out_proj = layer.self_attn.out_proj.weight + + # Reshape Q, K, V into [n_heads, d_model, d_head] + d_model = cfg.d_model + n_heads = cfg.n_heads + d_head = d_model // n_heads + + state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange( + q_proj, "(n h) m -> n m h", n=n_heads + ) + state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange( + k_proj, "(n h) m -> n m h", n=n_heads + ) + state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange( + v_proj, "(n h) m -> n m h", n=n_heads + ) + state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange( + out_proj, "m (n h) -> n h m", n=n_heads + ) + + # --- LayerNorms --- + state_dict[f"blocks.{l}.ln1.w"] = layer.layer_norm.weight + state_dict[f"blocks.{l}.ln1.b"] = layer.layer_norm.bias + state_dict[f"blocks.{l}.ln2.w"] = layer.final_layer_norm.weight + state_dict[f"blocks.{l}.ln2.b"] = layer.final_layer_norm.bias + + # --- Feed-forward (MLP) --- + fc1 = layer.fc1.weight + fc2 = layer.fc2.weight + fc1_bias = layer.fc1.bias + fc2_bias = layer.fc2.bias + + state_dict[f"blocks.{l}.mlp.W_in"] = fc1.T # shape [d_model, d_mlp] + state_dict[f"blocks.{l}.mlp.b_in"] = fc1_bias + state_dict[f"blocks.{l}.mlp.W_out"] = fc2.T # shape [d_mlp, d_model] + state_dict[f"blocks.{l}.mlp.b_out"] = fc2_bias + + return state_dict From 5dc88a1a95cce927ee018dc872d29fb1f7c5e626 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 5 Nov 2025 17:13:02 -0500 Subject: [PATCH 14/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 48 +++++++++++++++++--------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 82e6518d0..04be2eb6a 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -16,7 +16,7 @@ import numpy as np from einops import repeat from jaxtyping import Float, Int -from transformers import AutoProcessor, HubertModel +from transformers import AutoProcessor, HubertModel, HubertForCTC from transformers.models.auto.tokenization_auto import AutoTokenizer import transformer_lens.loading_from_pretrained as loading @@ -25,11 +25,6 @@ MLP, Attention, BertBlock, - BertEmbed, - BertMLMHead, - BertNSPHead, - BertPooler, - Unembed, ) from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule, HookPoint @@ -55,6 +50,7 @@ def __init__( cfg: Union[HookedTransformerConfig, Dict], move_to_device: bool = True, model_name: str = "facebook/hubert-base-ls960", + use_ctc: bool = True, **kwargs: Any, ): super().__init__() @@ -70,14 +66,22 @@ def __init__( self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) processor = AutoProcessor.from_pretrained(model_name) # builds input_values + attention_mask - model = HubertModel.from_pretrained(model_name) + if use_ctc: + hubert_model = 
HubertForCTC.from_pretrained(model_name) + else: + hubert_model = HubertModel.from_pretrained(model_name) if move_to_device: if self.cfg.device is None: raise ValueError("Cannot move to device when device is None") - model.to(self.cfg.device) - model.eval() + hubert_.to(self.cfg.device) + hubert_.eval() self.processor = processor - self.model = model + if use_ctc: + self.hubert_model = hubert_model.hubert + self.lm_head = hubert_model.lm_head + else: + self.hubert_model = hubert_model + self.lm_head = None if move_to_device: if self.cfg.device is None: @@ -141,7 +145,7 @@ def to_frames( # 1) convolutional frontend -> (batch, conv_dim, conv_time) with torch.no_grad(): - conv_feats = self.model.feature_extractor(input_values) # (B, C, T_conv) + conv_feats = self.hubert_model.feature_extractor(input_values) # (B, C, T_conv) # 2) transpose to (batch, T_conv, C) extract_features = conv_feats.transpose(1, 2) @@ -151,14 +155,14 @@ def to_frames( if sample_attention_mask is not None: # model should provide helper _get_feature_vector_attention_mask try: - frame_attention_mask = self.model._get_feature_vector_attention_mask(extract_features.shape[1], sample_attention_mask) + frame_attention_mask = self.hubert_model._get_feature_vector_attention_mask(extract_features.shape[1], sample_attention_mask) except AttributeError: # fallback: compute output lengths and create mask similarly to HF implementation # compute output lengths (downsampled lengths) from sample attention mask (sums per example) input_lengths = sample_attention_mask.sum(dim=-1) # (batch,) # compute output lengths through conv layers using model._get_feat_extract_output_lengths if exists if hasattr(model, "_get_feat_extract_output_lengths"): - output_lengths = self.model._get_feat_extract_output_lengths(input_lengths).to(torch.long) + output_lengths = self.hubert_model._get_feat_extract_output_lengths(input_lengths).to(torch.long) else: # fallback to naive downsample ratio: output_frames = extract_features.shape[1] output_lengths = torch.full((sample_attention_mask.shape[0],), extract_features.shape[1], device=device, dtype=torch.long) @@ -173,7 +177,7 @@ def to_frames( # 4) feature projection -> (batch, frames, hidden_size) with torch.no_grad(): - hidden_states = self.model.feature_projection(extract_features) # typically returns (B, T, hidden) + hidden_states = self.hubert_model.feature_projection(extract_features) # typically returns (B, T, hidden) # In HF's hubert, feature_projection is a module that returns a tensor (not tuple). If it returns tuple, adjust. # convert bool mask to long (1/0) if needed @@ -193,9 +197,9 @@ def encoder_output( if one_zero_attention_mask is not None: one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) - position_embeddings = self.model.encoder.pos_conv_embed(frames) + position_embeddings = self.hubert_model.encoder.pos_conv_embed(frames) resid = resid + position_embeddings - resid = self.model.encoder.layer_norm(resid) + resid = self.hubert_model.encoder.layer_norm(resid) large_negative_number = -torch.inf mask = ( @@ -219,6 +223,7 @@ def forward( Tuple[torch.Tensor, torch.Tensor], # (frames, frame_mask) ], sampling_rate: int = 16000, + use_proj: bool = False, move_to_device: bool = True, ) -> Optional[torch.Tensor]: """ @@ -230,6 +235,7 @@ def forward( - 3D torch.Tensor shaped (batch, frames, d_model) -> treated as precomputed frames (skip to_frames) - tuple (frames, frame_mask) -> use directly sampling_rate: sampling rate for to_frames when converting raw audio. 
+ use_proj: Whether to use the final head of HubertCTC move_to_device: move tensors to self.cfg.device (to match your other code). Returns: @@ -265,6 +271,13 @@ def forward( # ---------- 3) Run encoder (respects pos_conv_embed / layer_norm / dropout inside encoder_output) ---------- resid = self.encoder_output(frames, frame_mask) # (B, T, d_model) + if use_proj: + if self.lm_head is None: + logging.warning("HubertForCTC not enabled") + return resid + hidden_states = resid[0] # (B, T, d_model) + resid = self.lm_head(hidden_states) # (B, T, vocab_size) + return resid @overload @@ -332,6 +345,7 @@ def from_pretrained( device: Optional[str] = None, move_to_device: bool = True, dtype: torch.dtype = torch.float32, + use_ctc: bool = True, **from_pretrained_kwargs: Any, ) -> HookedEncoder: """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" @@ -371,7 +385,7 @@ def from_pretrained( official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs ) - model = cls(cfg, move_to_device=False, model_name=official_model_name) + model = cls(cfg, move_to_device=False, model_name=official_model_name, use_ctc=use_ctc) model.load_state_dict(state_dict, strict=False) From 8282805663e568d6f2f5d15da109757e6c0831ff Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 13:21:25 -0500 Subject: [PATCH 15/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 04be2eb6a..95e98c45c 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -73,8 +73,8 @@ def __init__( if move_to_device: if self.cfg.device is None: raise ValueError("Cannot move to device when device is None") - hubert_.to(self.cfg.device) - hubert_.eval() + hubert_model.to(self.cfg.device) + hubert_model.eval() self.processor = processor if use_ctc: self.hubert_model = hubert_model.hubert @@ -276,7 +276,8 @@ def forward( logging.warning("HubertForCTC not enabled") return resid hidden_states = resid[0] # (B, T, d_model) - resid = self.lm_head(hidden_states) # (B, T, vocab_size) + with torch.no_grad(): + resid = self.lm_head(hidden_states) # (B, T, vocab_size) return resid From 7f0c37397d8284ec761bf9b975e741fd5d7128ac Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 13:23:03 -0500 Subject: [PATCH 16/82] Create hubert_test.py --- hubert_test.py | 147 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 hubert_test.py diff --git a/hubert_test.py b/hubert_test.py new file mode 100644 index 000000000..df85c0087 --- /dev/null +++ b/hubert_test.py @@ -0,0 +1,147 @@ +# test_hubert_hooked.py +import torch +import numpy as np +import math + +# Replace this with the actual import for your implementation: +from transformer_lens import HookedAudioEncoder +# For illustration I assume the same API as HookedEncoder/HookedAudioEncoder: +# - HookedAudioEncoder.from_pretrained(...) OR HookedAudioEncoder(...) to instantiate +# - model(waveform, return_type=...) or model(waveform) returns a tensor +# +# If your class is named differently, change the import and instantiation below. 
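# A rough shape expectation worth checking in these sanity tests, assuming the
# standard facebook/hubert-base-ls960 convolutional frontend (kernel sizes
# 10,3,3,3,3,2,2 with strides 5,2,2,2,2,2,2, i.e. a ~20 ms hop at 16 kHz):
# a 1 s waveform should come out as roughly 49 frames of width d_model. The
# helper below is only an illustrative sketch; `expected_num_frames` is not
# part of the library or of this script.
def expected_num_frames(num_samples,
                        kernels=(10, 3, 3, 3, 3, 2, 2),
                        strides=(5, 2, 2, 2, 2, 2, 2)):
    length = num_samples
    for k, s in zip(kernels, strides):
        # output length of each unpadded 1D conv layer in the feature extractor
        length = (length - k) // s + 1
    return length

# Example: expected_num_frames(16000) == 49, so a (1, 16000) input is expected to
# yield hidden states of shape (1, 49, d_model) (or (1, 49, vocab) CTC logits).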
+ +# ---------- CONFIG ---------- +SAMPLE_RATE = 16000 +DURATION_S = 1.0 +BATCH_SIZE = 1 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +# Name of HF checkpoint to use if you want to compare outputs (optional) +HF_CHECKPOINT = "facebook/hubert-base-ls960" # optional +# ---------------------------- + +def make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S, amplitude=0.1): + t = np.linspace(0, duration, int(sr*duration), endpoint=False, dtype=np.float32) + wav = amplitude * np.sin(2 * math.pi * frequency * t) + return wav + +def run_basic_sanity_tests(model, waveform_np): + """Run quick checks: forward pass, shape, finite, deterministic, grad flow.""" + model.to(DEVICE) + + # Prepare tensor: shape (batch, time) + x = torch.from_numpy(waveform_np).unsqueeze(0).to(DEVICE) # (1, T) + + # 1) Eval forward: no grad + model.eval() + with torch.no_grad(): + out1 = model(x) # adapt if your API uses return_type="predictions" or similar + print("Forward (eval) output type:", type(out1)) + try: + out_tensor = out1 if isinstance(out1, torch.Tensor) else out1["predictions"] + except Exception: + out_tensor = out1 # fallback + + print("Output shape:", tuple(out_tensor.shape)) + print("Output stats: min=%.6g max=%.6g mean=%.6g" % (out_tensor.min().item(), out_tensor.max().item(), out_tensor.mean().item())) + assert torch.isfinite(out_tensor).all(), "Found NaNs or Infs in forward output!" + + # 2) Determinism in eval + with torch.no_grad(): + out2 = model(x) + # if model returns dict-like, extract tensor again + out2_tensor = out2 if isinstance(out2, torch.Tensor) else out2["predictions"] + if not torch.allclose(out_tensor, out2_tensor, atol=1e-6): + print("Warning: outputs differ between two eval runs (non-deterministic?), max diff:", (out_tensor - out2_tensor).abs().max().item()) + else: + print("Determinism test passed (eval mode).") + + # 3) Gradient flow test in train mode + model.train() + # zero grads + for p in model.parameters(): + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + out_train = model(x) + out_train_tensor = out_train if isinstance(out_train, torch.Tensor) else out_train["predictions"] + + # small scalar loss + loss = out_train_tensor.mean() + loss.backward() + # check some parameters got gradients + grads_found = any((p.grad is not None and torch.isfinite(p.grad).all()) for p in model.parameters() if p.requires_grad) + assert grads_found, "No finite gradients found on any parameter after backward()" + print("Gradient check passed: some parameters have finite gradients.") + +def optional_compare_to_hf(your_model, waveform_np, sr=SAMPLE_RATE): + """ + OPTIONAL: compare your_model outputs to Hugging Face's HubertModel outputs. + This requires transformers to be installed and internet access to download the checkpoint. + Important: to get a meaningful comparison you must match *exact preprocessing* (resampling, + normalization, padding/truncation) that the HF model expects and that your model used. 
+ """ + try: + from transformers import HubertModel, Wav2Vec2FeatureExtractor + except Exception as e: + print("Transformers or feature extractor not available:", e) + return + + print("Loading Hugging Face HubertModel for optional comparison (may take a while)...") + hf_feat = Wav2Vec2FeatureExtractor(sampling_rate=sr, do_normalize=True) + hf_model = HubertModel.from_pretrained(HF_CHECKPOINT).to(DEVICE).eval() + + # Prepare input for HF model + input_values = hf_feat(waveform_np, sampling_rate=sr, return_tensors="pt").get("input_values") # (1, T) + input_values = input_values.to(DEVICE) + + with torch.no_grad(): + hf_outputs = hf_model(input_values).last_hidden_state # (1, L, D) + # Pool HF tokens to a single vector (simple mean pooling) + hf_embedding = hf_outputs.mean(dim=1) # (1, D) + + # Get your model's representation and pool similarly + your_model.eval() + with torch.no_grad(): + your_out = your_model(torch.from_numpy(waveform_np).unsqueeze(0).to(DEVICE)) + your_tensor = your_out if isinstance(your_out, torch.Tensor) else your_out["predictions"] # shape depends on your model + # If your output has time dimension, mean-pool across time + if your_tensor.ndim == 3: + your_emb = your_tensor.mean(dim=1) + else: + your_emb = your_tensor # assume (1, D) or similar + + # Resize / project if dims differ (simple check) + if hf_embedding.shape[1] != your_emb.shape[1]: + print(f"Dimension mismatch (HF {hf_embedding.shape[1]} vs your {your_emb.shape[1]}). " + "You can compare after projecting to a common dim (not shown).") + return + + # Cosine similarity + cos = torch.nn.functional.cosine_similarity(hf_embedding, your_emb, dim=1) + print("Cosine similarity between HF pooled embedding and your model:", cos.cpu().numpy()) + +if __name__ == "__main__": + # Create sample waveform + wav = make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S) + + # ----------------------- + # Instantiate your model + # ----------------------- + # Example 1: from_pretrained API (if you implemented it) + try: + # If your class supports from_pretrained like HookedEncoder + from my_hubert_module import HookedAudioEncoder # <-- CHANGE to your module path + try: + model = HookedAudioEncoder.from_pretrained("your/checkpoint/name").to(DEVICE) + except AttributeError: + # fallback to direct constructor + model = HookedAudioEncoder().to(DEVICE) + except ImportError: + raise SystemExit("Please change the import 'from my_hubert_module import HookedAudioEncoder' to your actual module path.") + + # Run tests + run_basic_sanity_tests(model, wav) + + # Optionally compare to HF (network required) + # optional_compare_to_hf(model, wav, sr=SAMPLE_RATE) From e8bbf84634484445e9639ba3240863f64f25cc92 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 13:36:54 -0500 Subject: [PATCH 17/82] Update hubert_test.py --- hubert_test.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/hubert_test.py b/hubert_test.py index df85c0087..29cfdfd45 100644 --- a/hubert_test.py +++ b/hubert_test.py @@ -129,19 +129,9 @@ def optional_compare_to_hf(your_model, waveform_np, sr=SAMPLE_RATE): # Instantiate your model # ----------------------- # Example 1: from_pretrained API (if you implemented it) - try: - # If your class supports from_pretrained like HookedEncoder - from my_hubert_module import HookedAudioEncoder # <-- CHANGE to your module path - try: - model = HookedAudioEncoder.from_pretrained("your/checkpoint/name").to(DEVICE) - except 
AttributeError: - # fallback to direct constructor - model = HookedAudioEncoder().to(DEVICE) - except ImportError: - raise SystemExit("Please change the import 'from my_hubert_module import HookedAudioEncoder' to your actual module path.") - + model = HookedAudioEncoder.from_pretrained("your/checkpoint/name").to(DEVICE) # Run tests run_basic_sanity_tests(model, wav) - + # Optionally compare to HF (network required) - # optional_compare_to_hf(model, wav, sr=SAMPLE_RATE) + optional_compare_to_hf(model, wav, sr=SAMPLE_RATE) From 86ac1d94a49c121fa32faebae287044564476ba6 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 13:53:31 -0500 Subject: [PATCH 18/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 95e98c45c..1f876e8cf 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -34,7 +34,7 @@ T = TypeVar("T", bound="HookedEncoder") -class HookedEncoder(HookedRootModule): +class HookedAudioEncoder(HookedRootModule): """ This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. From 8f1b8896dc7d01c7edecc70feefd3d62b4f88ae6 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 13:55:28 -0500 Subject: [PATCH 19/82] Create hubert_ctc_test.py --- hubert_ctc_test.py | 231 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 hubert_ctc_test.py diff --git a/hubert_ctc_test.py b/hubert_ctc_test.py new file mode 100644 index 000000000..a37b833d9 --- /dev/null +++ b/hubert_ctc_test.py @@ -0,0 +1,231 @@ +# test_hubert_ctc_lmhead.py +""" +Test script to verify HookedAudioEncoder.forward(..., use_ctc=True) +loads/uses an lm_head and produces CTC logits. + +Usage: + python test_hubert_ctc_lmhead.py +Change the import to point at your HookedAudioEncoder implementation. +""" + +import torch +import numpy as np +import math +import sys +from transformer_lens import HookedAudioEncoder + +# ----- CONFIG ----- +SAMPLE_RATE = 16000 +DURATION_S = 1.0 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +BATCH_SIZE = 1 +# If you want to attempt optional decoding with a HF tokenizer, +# set TOKENIZER_NAME to a valid tokenizer (e.g. "facebook/wav2vec2-base-960h") +# or set to None to skip tokenizer decoding. 
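# The "naive collapse" decoding performed later in this script amounts to
# standard greedy CTC decoding: take the argmax token per frame, merge
# consecutive repeats, then drop the blank token. A minimal sketch of that rule
# (illustrative only; `greedy_ctc_decode` is not defined elsewhere in this
# script, and blank_id is assumed to be the tokenizer's pad/blank id):
#
#     def greedy_ctc_decode(frame_ids, blank_id=0):
#         out, prev = [], None
#         for t in frame_ids:
#             if t != prev and t != blank_id:
#                 out.append(t)
#             prev = t
#         return out
#
# TOKENIZER_NAME below controls whether any such decoding is attempted at all.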
+TOKENIZER_NAME = "facebook/wav2vec2-base-960h" +# ------------------ + +def make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S, amplitude=0.1): + t = np.linspace(0, duration, int(sr*duration), endpoint=False, dtype=np.float32) + return amplitude * np.sin(2 * math.pi * frequency * t) + +def has_lm_head(model): + return any(name.endswith("lm_head") or name == "lm_head" for name, _ in model.named_children()) or hasattr(model, "lm_head") + +def try_get_lm_head(model): + if hasattr(model, "lm_head"): + return model.lm_head + # try common nested names + for name, module in model.named_modules(): + if name.endswith("lm_head") or name == "lm_head": + return module + return None + +def print_param_info(module, prefix=""): + if module is None: + print(prefix + "None") + return + params = list(module.parameters()) + print(prefix + f"module type: {type(module)}, #params: {sum(p.numel() for p in params)}") + # print weight shape if available + if hasattr(module, "weight"): + try: + print(prefix + f" weight.shape = {tuple(module.weight.shape)}") + except Exception: + pass + +if __name__ == "__main__": + model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960") + + model.to(DEVICE) + + # sample waveform + wav = make_sine(frequency=440.0) + x = torch.from_numpy(wav).unsqueeze(0).to(DEVICE) # shape (1, T) + + print("=== lm_head presence BEFORE forward() ===") + print("has_lm_head():", has_lm_head(model)) + print("try_get_lm_head():") + print_param_info(try_get_lm_head(model), prefix=" ") + + # Forward pass with use_ctc=True (some model APIs accept it directly, some do not). + print("\nCalling forward(..., use_ctc=True) -- if that fails, will set attribute and call without arg") + logits = None + forward_exc = None + try: + # try direct call with argument + out = model(x, use_ctc=True) + except TypeError as e: + # forward signature may not accept use_ctc param; try setting attribute on model and call + forward_exc = e + print("Direct forward(..., use_ctc=True) failed with TypeError - will try setting model.use_ctc = True and calling forward(x).") + try: + if hasattr(model, "use_ctc"): + model.use_ctc = True + else: + # set attribute anyway + setattr(model, "use_ctc", True) + out = model(x) + except Exception as e2: + print("Forward still failed after setting model.use_ctc =", e2) + raise + + # Normalize out to logits tensor if possible + def extract_logits(out): + if out is None: + return None + if isinstance(out, torch.Tensor): + return out # assume logits + # dict-like outputs: look for common keys + if isinstance(out, dict): + for key in ("logits", "ctc_logits", "predictions", "hidden_states"): + if key in out: + t = out[key] + # if hidden_states is (batch, seq, dim) that's also fine to inspect + if isinstance(t, torch.Tensor): + return t + # if no known keys found, try to pick first tensor value + for v in out.values(): + if isinstance(v, torch.Tensor): + return v + # fallback: try to convert + return None + + logits = extract_logits(out) + print("\n=== Post-forward lm_head presence ===") + print("has_lm_head():", has_lm_head(model)) + lm = try_get_lm_head(model) + print("try_get_lm_head():") + print_param_info(lm, prefix=" ") + + if logits is None: + print("\nCould not automatically extract logits from the model output. 
The model returned:", type(out)) + # if out is tensor-like but not torch tensor, attempt conversion + if hasattr(out, "numpy"): + try: + logits = torch.from_numpy(out.numpy()).to(DEVICE) + except Exception: + pass + + if logits is not None: + print("\n=== Logits / CTC output info ===") + print("logits type:", type(logits)) + print("logits shape:", tuple(logits.shape)) + # typical CTC logits shape: (batch, time, vocab_size) or (batch, seq_len, vocab) + try: + print("stats: min=%.6g max=%.6g mean=%.6g" % (logits.min().item(), logits.max().item(), logits.mean().item())) + except Exception: + pass + assert torch.isfinite(logits).all(), "Found NaNs/Infs in logits!" + + # simple decode: argmax over last dim -> token ids + if logits.ndim >= 2: + token_dim = -1 + token_ids = logits.argmax(dim=token_dim) # shape: (batch, time) + token_ids_cpu = token_ids.detach().cpu().numpy() + print("Sample argmax token ids (first batch, up to first 40 frames):") + print(token_ids_cpu[0][:40].tolist()) + + # Optional: try to decode token ids to text if a tokenizer is available + if TOKENIZER_NAME is not None: + try: + from transformers import AutoTokenizer + tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME) + # For many CTC tokenizers, you need to collapse repeats and remove blank token id (often id=0 or tok.pad_token_id) + # Here we do a naive collapse+remove assuming blank token is tokenizer.pad_token_id or tokenizer.pad_token_id==tok.pad_token_id + blank_id = getattr(tok, "pad_token_id", None) + seq = token_ids_cpu[0].tolist() + # collapse repeats and remove blanks + collapsed = [] + prev = None + for t in seq: + if t == prev: + prev = t + continue + prev = t + if blank_id is not None and t == blank_id: + continue + collapsed.append(t) + decoded = tok.decode(collapsed, skip_special_tokens=True) + print("Decoded (naive collapse) text:", decoded) + except Exception as e: + print("Optional decoding failed:", e) + + else: + print("No logits found — cannot run CTC-specific checks.") + + # Gradient test specifically for transformer encoder (since lm_head is frozen) + print("\nRunning gradient propagation test through transformer encoder...") + + model.train() + for p in model.parameters(): + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + try: + out2 = model(x, use_ctc=True) + except TypeError: + if hasattr(model, "use_ctc"): + model.use_ctc = True + out2 = model(x) + + logits2 = extract_logits(out2) + if logits2 is None: + print("Could not extract logits for gradient test; aborting gradient check.") + else: + loss = logits2.mean() + loss.backward() + + # --- Check that lm_head is frozen --- + lm = try_get_lm_head(model) + if lm is not None: + lm_params = list(lm.parameters()) + grads = [p.grad for p in lm_params if p.grad is not None] + if len(grads) > 0: + print("Warning: lm_head has gradients, but it should be frozen (eval mode).") + else: + print("✅ lm_head correctly frozen (no gradients).") + + # --- Check that transformer block parameters have gradients --- + has_transformer_grad = False + for name, p in model.named_parameters(): + if "transformer" in name or "encoder" in name or "block" in name: + print(name) + if p.grad is not None and torch.isfinite(p.grad).all(): + has_transformer_grad = True + break + + if has_transformer_grad: + print("✅ Gradient test PASSED: transformer block parameters have finite gradients.") + else: + print("❌ Gradient test FAILED: no gradients found in transformer blocks.") + + + print("\n=== DONE ===") + print("Interpretation notes:") + print(" - If lm_head appears 
AFTER calling forward(use_ctc=True) and logits shape looks like (B, T, V),") + print(" then your forward-path is constructing/attaching an lm_head and producing CTC logits.") + print(" - If lm_head parameters have finite gradients after loss.backward(), the head is hooked into the graph.") + print(" - If you want a numeric golden-check, instantiate a HF Hubert/Wav2Vec2 CTC model and compare pooled logits/ids (optional).") + print(model.named_parameters()) From afc2a35d63281a8fe8ce4dd014865fd295119c30 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 13:59:20 -0500 Subject: [PATCH 20/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 1f876e8cf..e1a10fced 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -1,6 +1,6 @@ -"""Hooked Encoder. +"""Hooked Audio Encoder. -Contains a BERT style model. This is separate from :class:`transformer_lens.HookedTransformer` +Contains a HuBERT style model. This is separate from :class:`transformer_lens.HookedTransformer` because it has a significantly different architecture to e.g. GPT style transformers. """ From f94fa407ee604665377da41af5e98723c51b8739 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 14:08:21 -0500 Subject: [PATCH 21/82] Create hubert_hook_test.py --- hubert_hook_test.py | 201 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 hubert_hook_test.py diff --git a/hubert_hook_test.py b/hubert_hook_test.py new file mode 100644 index 000000000..457f19f5f --- /dev/null +++ b/hubert_hook_test.py @@ -0,0 +1,201 @@ +# test_hubert_hooks.py +import torch +import numpy as np +import math +import circuitsvis as cv + +# transformer-lens utils used in your LLaMA example: +try: + from transformer_lens import utils +except Exception: + # if you put utils somewhere else, import accordingly + from transformer_lens import utils + +# ---- Replace these imports with your implementations ---- +# from my_hubert_module import HookedAudioEncoder +# from my_hubert_module import YourWrapperClass # whichever class exposes run_with_cache/run_with_hooks and to_frames +# --------------------------------------------------------- + +# ---- Simple sine audio generator ---- +SAMPLE_RATE = 16000 +DURATION_S = 1.0 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + +def make_sine(sr=SAMPLE_RATE, duration=DURATION_S, freq=440.0, amp=0.1): + t = np.linspace(0, duration, int(sr*duration), endpoint=False, dtype=np.float32) + return amp * np.sin(2 * math.pi * freq * t) + +# ---- Adapt these to your instantiated model/wrapper ---- +# instantiate your HookedAudioEncoder (example) +# model = HookedAudioEncoder.from_pretrained("...").to(DEVICE) +# or if you have a wrapper with to_frames(), instantiate wrapper +# audio_wrapper = YourWrapperClass(...) +# audio_wrapper.hubert_model etc. + +# For this template I'll assume: +# - `audio_model` is the object that implements .to_frames(raw_inputs...) -> (frames, frame_mask) +# - `audio_model` exposes: run_with_cache(frames, attention_mask=frame_mask, remove_batch_dim=True) +# - `audio_model` exposes: run_with_hooks(frames, fwd_hooks=[(act_name, hook_fn)], attention_mask=frame_mask, return_type=...) 
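# As a concrete illustration of that assumed API (a sketch, not a guarantee
# about this particular wrapper): TransformerLens activation names are built
# with utils.get_act_name, e.g. utils.get_act_name("v", 0) -> "blocks.0.attn.hook_v"
# and utils.get_act_name("pattern", 0) -> "blocks.0.attn.hook_pattern", and a
# forward hook is just a callable taking (activation, hook), for example:
#
#     def scale_head_0_v(value, hook):
#         # value: (batch, pos, n_heads, d_head) when batch dim is kept
#         value[:, :, 0, :] *= 0.5
#         return value
#
#     out = audio_model.run_with_hooks(
#         frames, fwd_hooks=[(utils.get_act_name("v", 0), scale_head_0_v)]
#     )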
+# If your API names differ, substitute accordingly. + +def main(): + # --- Build a 1s test waveform --- + wav = make_sine() + # If to_frames expects numpy or torch, both are accepted by your implementation + raw_batch = [wav] # batch of one + + # --- Convert to frames using your helper (you provided to_frames) --- + # IMPORTANT: use the same sampling_rate you used during training/FT (16k typical) + try: + frames, frame_mask = audio_model.to_frames(raw_batch, sampling_rate=SAMPLE_RATE, move_to_device=True) + except NameError: + raise RuntimeError("Replace `audio_model` with your model/wrapper instance that implements to_frames().") + + # frames shape expected: (batch, frames, hidden) ; frame_mask: (batch, frames) (1/0) + print("frames.shape:", tuple(frames.shape)) + if frame_mask is not None: + print("frame_mask.shape:", tuple(frame_mask.shape)) + + # --- Run with cache to inspect attention pattern --- + # remove_batch_dim=True makes cached activations shaped like (pos, ...) for easier visualization (like LLaMA example) + cache = audio_model.run_with_cache(frames, attention_mask=frame_mask, remove_batch_dim=True) + + # Picking a layer and head for visualization + layer_to_visualize = 0 + # act name for attention pattern — this is the same helper you used earlier + pattern_name = utils.get_act_name("pattern", layer_to_visualize) # e.g. "pattern_0" depending on utils + # some implementations store pattern as (layer, "attn") tuple; utils.get_act_name helps avoid mistakes + + # Extract attention pattern. Adapt this extraction if your cache key structure differs: + try: + attention_pattern = cache[pattern_name] # expected shape: (pos, pos, n_heads) or (pos, n_heads, pos) depending on implementation + except Exception: + # fallback: try tuple-key style + try: + attention_pattern = cache["pattern", layer_to_visualize, "attn"] + except Exception as exc: + raise RuntimeError(f"Couldn't find attention pattern in cache. Keys: {list(cache.keys())}") from exc + + # Build human-friendly "tokens" for frames (e.g. frame indices as strings) + n_frames = attention_pattern.shape[0] + frame_tokens = [f"f{i}" for i in range(n_frames)] + + print("Layer", layer_to_visualize, "attention pattern shape:", tuple(attention_pattern.shape)) + print("Displaying attention patterns (layer", layer_to_visualize, ")") + display(cv.attention.attention_patterns(tokens=frame_tokens, attention=attention_pattern)) + + # --- Define a head ablation hook (zero out a given head's v output) --- + head_index_to_ablate = 0 + layer_to_ablate = 0 + + # Hook target: v (value output) or "pattern" depending on what you'd like to ablate. + # Using the 'v' activation is a common choice, same form as your LLaMA example. + v_act_name = utils.get_act_name("v", layer_to_ablate) + + def head_ablation_hook(value, hook): + """ + value expected shape: [batch pos head d_head] OR [pos head d_head] when remove_batch_dim=True + We'll allow both shapes. + """ + # convert to mutable clone (some frameworks give non-writable tensors) + v = value.clone() + if v.ndim == 4: + # (B, pos, heads, d) + v[:, :, head_index_to_ablate, :] = 0.0 + elif v.ndim == 3: + # (pos, heads, d) + v[:, head_index_to_ablate, :] = 0.0 + else: + raise RuntimeError(f"Unexpected v tensor ndim={v.ndim}") + return v + + # --- Compute a downstream quantity without ablation --- + # Choose a metric you care about. 
Good choices: + # - CTC logits (if using use_ctc=True) -> argmax tokens or loss + # - Pooled encoder representation (mean of final resid_post) -> cosine similarity + # We'll implement both: try to extract CTC logits from model output; if not found, use pooled resid_post. + + def run_and_get_repr(frames, frame_mask, hooks=None): + # hooks: list of (act_name, hook_fn) tuples for run_with_hooks + if hooks is None: + # run_with_cache to gather activations + cache = audio_model.run_with_cache(frames, attention_mask=frame_mask, remove_batch_dim=True) + out = audio_model.run_with_hooks(frames, return_type=None, fwd_hooks=[]) + # NOTE: if your API returns outputs directly from run_with_cache, adapt as needed. + else: + # run with hooks and also capture cache + # run_with_hooks typically returns output (or logits) and optionally a cache depending on your implementation + out = audio_model.run_with_hooks(frames, fwd_hooks=hooks, attention_mask=frame_mask, return_type="both") + # If return_type="both" isn't supported, you can run run_with_cache and run_with_hooks separately. + # Try to extract CTC logits from `out` first + logits = None + if isinstance(out, dict): + for k in ("logits", "ctc_logits", "logits_ctc", "predictions"): + if k in out and isinstance(out[k], torch.Tensor): + logits = out[k] + break + elif isinstance(out, torch.Tensor): + # ambiguous: could be embeddings or logits + logits = out + + # if logits exist -> pooled logits (mean over time) as representation + if logits is not None: + # ensure shape (batch, time, vocab) -> pool over time axis (1) + if logits.ndim == 3: + pooled = logits.mean(dim=1) # (batch, vocab) + elif logits.ndim == 2: + pooled = logits # maybe (batch, vocab) + else: + pooled = logits.view(logits.shape[0], -1).mean(dim=1, keepdim=True) + return pooled, logits, None # third slot reserved for cache + + # fallback: use final residual activation from cache (resid_post of last layer) + try: + last_layer = audio_model.cfg.n_layers - 1 + resid_name = utils.get_act_name("resid_post", last_layer) + # get cache from run_with_cache (we ran above) + cache = audio_model.run_with_cache(frames, attention_mask=frame_mask, remove_batch_dim=True) + resid = cache[resid_name] # e.g. 
(pos, d) or (batch,pos,d) + # mean-pool across pos dimension + if resid.ndim == 3: + pooled = resid.mean(dim=1) # (batch, d) + elif resid.ndim == 2: + pooled = resid.mean(dim=0, keepdim=True) + else: + raise RuntimeError("Unexpected resid_post shape") + return pooled, None, cache + except Exception as e: + raise RuntimeError("Couldn't extract logits or resid_post; adapt the extraction to your model's output format.") from e + + # Get baseline representation + baseline_repr, baseline_logits, baseline_cache = run_and_get_repr(frames, frame_mask, hooks=None) + print("Baseline representation shape:", tuple(baseline_repr.shape)) + + # --- Run with ablation hook and get representation --- + hooks = [(v_act_name, head_ablation_hook)] + ablated_repr, ablated_logits, ablated_cache = run_and_get_repr(frames, frame_mask, hooks=hooks) + print("Ablated representation shape:", tuple(ablated_repr.shape)) + + # --- Compare representations (cosine similarity) --- + cos = torch.nn.functional.cosine_similarity(baseline_repr, ablated_repr, dim=-1) + print("Cosine similarity baseline vs ablated:", cos.detach().cpu().numpy()) + + # If you have logits, you can also compare token sequences (argmax) or loss increase + if baseline_logits is not None and ablated_logits is not None: + b_ids = baseline_logits.argmax(dim=-1) # (batch, time) + a_ids = ablated_logits.argmax(dim=-1) + print("Sample argmax token ids (baseline):", b_ids[0][:40].cpu().numpy().tolist()) + print("Sample argmax token ids (ablated): ", a_ids[0][:40].cpu().numpy().tolist()) + + print("Done. Interpret the results:") + print(" - A large drop in cosine similarity (or large change in argmax tokens / increase in loss) means the ablated head mattered.") + print(" - If ablation causes little change, that head may be redundant or not used for this example.") + +if __name__ == "__main__": + # create/instantiate your model here: replace the placeholder below + # Example: + # audio_model = HookedAudioEncoder.from_pretrained("...").to(DEVICE) + # audio_model.cfg.device = DEVICE + # For wrapper that exposes to_frames: + # audio_model = YourWrapperClass(...) 
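# A minimal instantiation sketch matching the placeholders above (the checkpoint
# name and device handling are assumptions; adjust to whatever you actually load
# before calling main()):
#
#     audio_model = HookedAudioEncoder.from_pretrained(
#         "facebook/hubert-base-ls960", device=DEVICE
#     )
#     frames, frame_mask = audio_model.to_frames([make_sine()], sampling_rate=SAMPLE_RATE)
#     _, cache = audio_model.run_with_cache(frames, remove_batch_dim=True)
#     # in TransformerLens the cached pattern is usually (n_heads, q_pos, k_pos)
#     print(cache[utils.get_act_name("pattern", 0)].shape)
#
# main() below assumes a global `audio_model` of this kind already exists.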
+ main() From cff50b3930b74e67cf0b55cb24b49ecd68740775 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 14:15:07 -0500 Subject: [PATCH 22/82] Update hubert_hook_test.py --- hubert_hook_test.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/hubert_hook_test.py b/hubert_hook_test.py index 457f19f5f..70595ce4d 100644 --- a/hubert_hook_test.py +++ b/hubert_hook_test.py @@ -4,18 +4,15 @@ import math import circuitsvis as cv -# transformer-lens utils used in your LLaMA example: -try: - from transformer_lens import utils -except Exception: - # if you put utils somewhere else, import accordingly - from transformer_lens import utils - -# ---- Replace these imports with your implementations ---- -# from my_hubert_module import HookedAudioEncoder -# from my_hubert_module import YourWrapperClass # whichever class exposes run_with_cache/run_with_hooks and to_frames -# --------------------------------------------------------- - +from tqdm import tqdm +from jaxtyping import Float + +import transformer_lens +import transformer_lens.utils as utils +from transformer_lens.hook_points import ( + HookPoint, +) # Hooking utilities +from transformer_lens import HookedAudioEncoder # ---- Simple sine audio generator ---- SAMPLE_RATE = 16000 DURATION_S = 1.0 @@ -27,10 +24,7 @@ def make_sine(sr=SAMPLE_RATE, duration=DURATION_S, freq=440.0, amp=0.1): # ---- Adapt these to your instantiated model/wrapper ---- # instantiate your HookedAudioEncoder (example) -# model = HookedAudioEncoder.from_pretrained("...").to(DEVICE) -# or if you have a wrapper with to_frames(), instantiate wrapper -# audio_wrapper = YourWrapperClass(...) -# audio_wrapper.hubert_model etc. +audio_model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960", device="cuda") # For this template I'll assume: # - `audio_model` is the object that implements .to_frames(raw_inputs...) -> (frames, frame_mask) From 764810ab40b1b977dcdc75006b73576a10353e6e Mon Sep 17 00:00:00 2001 From: jiankunwei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 01:29:41 +0000 Subject: [PATCH 23/82] done --- hubert_ctc_test.py | 6 +- hubert_test.py | 2 +- transformer_lens/HookedAudioEncoder.py | 49 ++++-- transformer_lens/__init__.py | 1 + .../pretrained/weight_conversions/__init__.py | 1 + .../pretrained/weight_conversions/hubert.py | 161 ++++++++++++------ 6 files changed, 146 insertions(+), 74 deletions(-) diff --git a/hubert_ctc_test.py b/hubert_ctc_test.py index a37b833d9..f223e2ef3 100644 --- a/hubert_ctc_test.py +++ b/hubert_ctc_test.py @@ -22,7 +22,7 @@ # If you want to attempt optional decoding with a HF tokenizer, # set TOKENIZER_NAME to a valid tokenizer (e.g. "facebook/wav2vec2-base-960h") # or set to None to skip tokenizer decoding. 
-TOKENIZER_NAME = "facebook/wav2vec2-base-960h" +TOKENIZER_NAME = "facebook/hubert-base-ls960-ft" # ------------------ def make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S, amplitude=0.1): @@ -55,7 +55,7 @@ def print_param_info(module, prefix=""): pass if __name__ == "__main__": - model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960") + model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960-ft") model.to(DEVICE) @@ -211,7 +211,7 @@ def extract_logits(out): has_transformer_grad = False for name, p in model.named_parameters(): if "transformer" in name or "encoder" in name or "block" in name: - print(name) + print(name) if p.grad is not None and torch.isfinite(p.grad).all(): has_transformer_grad = True break diff --git a/hubert_test.py b/hubert_test.py index 29cfdfd45..ef971dbaa 100644 --- a/hubert_test.py +++ b/hubert_test.py @@ -129,7 +129,7 @@ def optional_compare_to_hf(your_model, waveform_np, sr=SAMPLE_RATE): # Instantiate your model # ----------------------- # Example 1: from_pretrained API (if you implemented it) - model = HookedAudioEncoder.from_pretrained("your/checkpoint/name").to(DEVICE) + model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960").to(DEVICE) # Run tests run_basic_sanity_tests(model, wav) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index e1a10fced..383e23dee 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -16,7 +16,7 @@ import numpy as np from einops import repeat from jaxtyping import Float, Int -from transformers import AutoProcessor, HubertModel, HubertForCTC +from transformers import AutoProcessor, HubertModel, HubertForCTC, AutoFeatureExtractor from transformers.models.auto.tokenization_auto import AutoTokenizer import transformer_lens.loading_from_pretrained as loading @@ -65,7 +65,15 @@ def __init__( assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) - processor = AutoProcessor.from_pretrained(model_name) # builds input_values + attention_mask + if model_name.endswith("-ft") and use_ctc: + # fine-tuned model (has CTC head) + use_ctc = True + processor = AutoProcessor.from_pretrained(model_name) # builds input_values + attention_mask + else: + # pretraining-only model (no CTC) + use_ctc = False + processor = AutoFeatureExtractor.from_pretrained(model_name) + if use_ctc: hubert_model = HubertForCTC.from_pretrained(model_name) else: @@ -90,7 +98,7 @@ def __init__( self.setup() - def _ensure_tensor(wave): + def _ensure_tensor(self, wave): """Convert numpy array or python list to 1D torch.float tensor.""" if isinstance(wave, np.ndarray): return torch.from_numpy(wave).float() @@ -101,6 +109,7 @@ def _ensure_tensor(wave): raise TypeError("wave must be torch.Tensor, np.ndarray or list of floats") def to_frames( + self, raw_inputs: Union[torch.Tensor, List[torch.Tensor], List[np.ndarray]], sampling_rate: int = 16000, move_to_device: bool = True, @@ -122,17 +131,18 @@ def to_frames( frame_attention_mask: torch.LongTensor of shape (batch, frames) with 1 for real frames, 0 for padding """ # raw_inputs are arrays/tensors + # print(type(raw_inputs)) if isinstance(raw_inputs, (torch.Tensor, np.ndarray)): - waves = [_ensure_tensor(raw_inputs)] + waves = [self._ensure_tensor(raw_inputs)] elif isinstance(raw_inputs, list): - waves = [_ensure_tensor(w) for w in raw_inputs] + waves = 
[self._ensure_tensor(w) for w in raw_inputs] else: raise TypeError("Unsupported raw_inputs type") # Use HF processor to create input_values (padded) + sample-level attention_mask # Processor will do padding so we can pass a variable-length batch + waves = [w.detach().cpu() for w in waves] proc_out = self.processor(waves, sampling_rate=sampling_rate, return_tensors="pt", padding=True) - input_values = proc_out["input_values"] # (batch, samples), float sample_attention_mask = proc_out.get("attention_mask") # (batch, samples), 1 for valid, 0 for padding; may be None @@ -144,6 +154,10 @@ def to_frames( sample_attention_mask = sample_attention_mask.to(device) # 1) convolutional frontend -> (batch, conv_dim, conv_time) + if input_values.ndim > 2: + input_values = input_values.squeeze() + if input_values.ndim == 1: + input_values = input_values.unsqueeze(0) # (1, T) with torch.no_grad(): conv_feats = self.hubert_model.feature_extractor(input_values) # (B, C, T_conv) @@ -198,7 +212,7 @@ def encoder_output( one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) position_embeddings = self.hubert_model.encoder.pos_conv_embed(frames) - resid = resid + position_embeddings + resid = frames + position_embeddings resid = self.hubert_model.encoder.layer_norm(resid) large_negative_number = -torch.inf @@ -217,7 +231,7 @@ def encoder_output( def forward( self, - input: Union[ + inputs: Union[ torch.Tensor, # waveform (1D) OR precomputed frames (3D) List[Union[torch.Tensor, np.ndarray]], # list of waveforms Tuple[torch.Tensor, torch.Tensor], # (frames, frame_mask) @@ -245,20 +259,20 @@ def forward( # ---------- 1) Normalize input: get (frames, frame_mask) ---------- frames = None frame_mask = None # one_zero_attention_mask: 1 = valid, 0 = padding - + # print(type(inputs)) # If user passed (frames, mask) tuple - if isinstance(input, tuple) and len(input) == 2 and isinstance(input[0], torch.Tensor): - frames, frame_mask = input + if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[0], torch.Tensor): + frames, frame_mask = inputs # If user passed a 3D tensor -> assume (B, T, D) frames (pre-projected) - elif isinstance(input, torch.Tensor) and input.ndim == 3: - frames = input + elif isinstance(inputs, torch.Tensor) and inputs.ndim == 3: + frames = inputs # frame_mask stays whatever was passed as separate argument (None here) # Else treat as raw waveform(s) -> call to_frames else: # allow single 1D tensor or numpy array or list of tensors/arrays - frames, frame_mask = self.to_frames(input, sampling_rate=sampling_rate, move_to_device=move_to_device) + frames, frame_mask = self.to_frames(inputs) # to_frames should already place tensors on device if move_to_device=True # ---------- 2) Ensure device & dtype consistency ---------- @@ -370,6 +384,13 @@ def from_pretrained( dtype = from_pretrained_kwargs["torch_dtype"] official_model_name = loading.get_official_model_name(model_name) + + if model_name.endswith("-ft") and use_ctc: + # fine-tuned model (has CTC head) + use_ctc = True + else: + # pretraining-only model (no CTC) + use_ctc = False cfg = loading.get_pretrained_model_config( official_model_name, diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index 1e2ff1e1a..7e6183c71 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -13,6 +13,7 @@ from .HookedTransformer import HookedTransformer from .SVDInterpreter import SVDInterpreter from .HookedEncoder import HookedEncoder +from .HookedAudioEncoder import HookedAudioEncoder from 
.HookedEncoderDecoder import HookedEncoderDecoder from .BertNextSentencePrediction import BertNextSentencePrediction from . import head_detector diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index c5ea9581b..daaffe472 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,3 +19,4 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .hubert import convert_hubert_weights diff --git a/transformer_lens/pretrained/weight_conversions/hubert.py b/transformer_lens/pretrained/weight_conversions/hubert.py index d8de2dc0a..bfb0c00e8 100644 --- a/transformer_lens/pretrained/weight_conversions/hubert.py +++ b/transformer_lens/pretrained/weight_conversions/hubert.py @@ -2,69 +2,118 @@ import torch from transformer_lens.HookedTransformerConfig import HookedTransformerConfig - def convert_hubert_weights(hf_model, cfg: HookedTransformerConfig): """ - Convert a Hugging Face HuBERT model's transformer encoder weights - into a TransformerLens-compatible state_dict. - - This ignores HuBERT's convolutional feature extractor and feature projection, - since we assume they are handled externally (e.g., via hf_model.feature_extractor - and hf_model.feature_projection). - - Args: - hf_model: A pretrained HuggingFace HuBERT model (e.g., HubertModel.from_pretrained(...)) - cfg: TransformerLens HookedTransformerConfig + Convert transformer encoder weights from a HuggingFace HuBERT model + into the state_dict expected by Transformer-Lens' HookedEncoder. - Returns: - state_dict: a dict mapping TransformerLens parameter names to torch tensors - suitable for model.load_state_dict(state_dict, strict=False) + Notes: + - This intentionally skips the convolutional frontend and feature_projection. + Those are used directly from the HF model (hf_model.feature_extractor, hf_model.feature_projection). + - Use model.load_state_dict(state_dict, strict=False) to load these. 
""" state_dict = {} - # Shortcut to encoder layers - encoder_layers = hf_model.encoder.layers + # Try to find the encoder layer list (different HF variants use .layers or .layer) + encoder = getattr(hf_model, "encoder", None) + if encoder is None: + raise ValueError("hf_model has no .encoder attribute") + + encoder_layers = getattr(encoder, "layers", None) or getattr(encoder, "layer", None) + if encoder_layers is None: + # maybe hf_model itself is the encoder (unlikely), or a wrapped attribute + raise ValueError("Couldn't find encoder.layers or encoder.layer on hf_model.encoder") + + # Use cfg dims for reshaping + d_model = cfg.d_model + n_heads = cfg.n_heads + # d_head = d_model // n_heads # implicit if needed for l, layer in enumerate(encoder_layers): - # --- Self-attention projections --- - q_proj = layer.self_attn.q_proj.weight - k_proj = layer.self_attn.k_proj.weight - v_proj = layer.self_attn.v_proj.weight - out_proj = layer.self_attn.out_proj.weight - - # Reshape Q, K, V into [n_heads, d_model, d_head] - d_model = cfg.d_model - n_heads = cfg.n_heads - d_head = d_model // n_heads - - state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange( - q_proj, "(n h) m -> n m h", n=n_heads - ) - state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange( - k_proj, "(n h) m -> n m h", n=n_heads - ) - state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange( - v_proj, "(n h) m -> n m h", n=n_heads - ) - state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange( - out_proj, "m (n h) -> n h m", n=n_heads - ) - - # --- LayerNorms --- - state_dict[f"blocks.{l}.ln1.w"] = layer.layer_norm.weight - state_dict[f"blocks.{l}.ln1.b"] = layer.layer_norm.bias - state_dict[f"blocks.{l}.ln2.w"] = layer.final_layer_norm.weight - state_dict[f"blocks.{l}.ln2.b"] = layer.final_layer_norm.bias - - # --- Feed-forward (MLP) --- - fc1 = layer.fc1.weight - fc2 = layer.fc2.weight - fc1_bias = layer.fc1.bias - fc2_bias = layer.fc2.bias - - state_dict[f"blocks.{l}.mlp.W_in"] = fc1.T # shape [d_model, d_mlp] - state_dict[f"blocks.{l}.mlp.b_in"] = fc1_bias - state_dict[f"blocks.{l}.mlp.W_out"] = fc2.T # shape [d_mlp, d_model] - state_dict[f"blocks.{l}.mlp.b_out"] = fc2_bias + # --- Attention module --- + # Some HF variants might call it `attention`, others `self_attn` etc. + att = getattr(layer, "attention", None) or getattr(layer, "self_attn", None) + if att is None: + raise AttributeError(f"Encoder layer {l} has no 'attention' or 'self_attn' attribute") + + # q/k/v/out proj names in HuBERT's HubertAttention: q_proj, k_proj, v_proj, out_proj + # fall back to common alternatives if present + q_w = getattr(att, "q_proj", None) + k_w = getattr(att, "k_proj", None) + v_w = getattr(att, "v_proj", None) + o_w = getattr(att, "out_proj", None) or getattr(att, "proj", None) + + if any(x is None for x in (q_w, k_w, v_w, o_w)): + # Try alternate nested attributes like att.q, att.k, att.v, att.o + q_w = q_w or getattr(att, "q", None) + k_w = k_w or getattr(att, "k", None) + v_w = v_w or getattr(att, "v", None) + o_w = o_w or getattr(att, "o", None) + + if any(x is None for x in (q_w, k_w, v_w, o_w)): + raise AttributeError(f"Could not find q/k/v/out projections in layer {l}. 
Found: {att}") + + # weights are Linear modules: weight shape (out, in) => same convention as Bert conversion + # reshape to Transformer-Lens expected shapes using einops + state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(q_w.weight, "(i h) m -> i m h", i=n_heads) + if q_w.bias is not None: + state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange(q_w.bias, "(i h) -> i h", i=n_heads) + + state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange(k_w.weight, "(i h) m -> i m h", i=n_heads) + if k_w.bias is not None: + state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange(k_w.bias, "(i h) -> i h", i=n_heads) + + state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange(v_w.weight, "(i h) m -> i m h", i=n_heads) + if v_w.bias is not None: + state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange(v_w.bias, "(i h) -> i h", i=n_heads) + + state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(o_w.weight, "m (i h) -> i h m", i=n_heads) + if o_w.bias is not None: + state_dict[f"blocks.{l}.attn.b_O"] = o_w.bias + + # --- Layer norms inside the layer --- + # HuBERT layer has `layer.layer_norm` and `layer.final_layer_norm` + ln1 = getattr(layer, "layer_norm", None) + ln2 = getattr(layer, "final_layer_norm", None) + if ln1 is None or ln2 is None: + # try alternative names + ln1 = ln1 or getattr(layer, "attention_norm", None) + ln2 = ln2 or getattr(layer, "output_layer_norm", None) + + if ln1 is not None: + state_dict[f"blocks.{l}.ln1.w"] = ln1.weight + state_dict[f"blocks.{l}.ln1.b"] = ln1.bias + if ln2 is not None: + state_dict[f"blocks.{l}.ln2.w"] = ln2.weight + state_dict[f"blocks.{l}.ln2.b"] = ln2.bias + + # --- Feed-forward / MLP --- + # HuBERT uses `feed_forward` which contains intermediate_dense and output_dense + ff = getattr(layer, "feed_forward", None) or getattr(layer, "feedforward", None) or getattr(layer, "ff", None) + if ff is None: + raise AttributeError(f"Layer {l} has no feed_forward/ff attribute") + + # Many implementations name them intermediate_dense and output_dense + fc1 = getattr(ff, "intermediate_dense", None) or getattr(ff, "fc1", None) or getattr(ff, "linear1", None) + fc2 = getattr(ff, "output_dense", None) or getattr(ff, "fc2", None) or getattr(ff, "linear2", None) + + if fc1 is None or fc2 is None: + raise AttributeError(f"Could not find FFN dense layers in layer {l}: {ff}") + + # fc1.weight shape: (d_mlp, d_model) -> Transformer-Lens expects (d_model, d_mlp) + state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange(fc1.weight, "mlp model -> model mlp") + if fc1.bias is not None: + state_dict[f"blocks.{l}.mlp.b_in"] = fc1.bias + + # fc2.weight shape: (d_model, d_mlp) -> Transformer-Lens expects (d_mlp, d_model) + state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange(fc2.weight, "model mlp -> mlp model") + if fc2.bias is not None: + state_dict[f"blocks.{l}.mlp.b_out"] = fc2.bias + + # --- Optional: encoder-level layer_norm (HubertModel.encoder.layer_norm) --- + if hasattr(hf_model.encoder, "layer_norm"): + ln_final = hf_model.encoder.layer_norm + state_dict["ln_final.w"] = ln_final.weight + state_dict["ln_final.b"] = ln_final.bias return state_dict From 7e844a3065ac36cc2af9fc8ae6fb3a1e2b7fa94e Mon Sep 17 00:00:00 2001 From: jiankunwei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 01:51:30 +0000 Subject: [PATCH 24/82] done --- hubert_ctc_test.py | 4 ++-- transformer_lens/HookedAudioEncoder.py | 14 +++++++++----- transformer_lens/loading_from_pretrained.py | 21 +++++++++++++++++++-- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git 
a/hubert_ctc_test.py b/hubert_ctc_test.py index f223e2ef3..c5dc1577f 100644 --- a/hubert_ctc_test.py +++ b/hubert_ctc_test.py @@ -22,7 +22,7 @@ # If you want to attempt optional decoding with a HF tokenizer, # set TOKENIZER_NAME to a valid tokenizer (e.g. "facebook/wav2vec2-base-960h") # or set to None to skip tokenizer decoding. -TOKENIZER_NAME = "facebook/hubert-base-ls960-ft" +TOKENIZER_NAME = "facebook/hubert-large-ls960-ft" # ------------------ def make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S, amplitude=0.1): @@ -55,7 +55,7 @@ def print_param_info(module, prefix=""): pass if __name__ == "__main__": - model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960-ft") + model = HookedAudioEncoder.from_pretrained("facebook/hubert-large-ls960-ft") model.to(DEVICE) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 383e23dee..4f21a4204 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -87,6 +87,8 @@ def __init__( if use_ctc: self.hubert_model = hubert_model.hubert self.lm_head = hubert_model.lm_head + for p in self.lm_head.parameters(): + p.requires_grad = False else: self.hubert_model = hubert_model self.lm_head = None @@ -237,7 +239,7 @@ def forward( Tuple[torch.Tensor, torch.Tensor], # (frames, frame_mask) ], sampling_rate: int = 16000, - use_proj: bool = False, + use_ctc: bool = False, move_to_device: bool = True, ) -> Optional[torch.Tensor]: """ @@ -285,13 +287,15 @@ def forward( # ---------- 3) Run encoder (respects pos_conv_embed / layer_norm / dropout inside encoder_output) ---------- resid = self.encoder_output(frames, frame_mask) # (B, T, d_model) - if use_proj: + if use_ctc: if self.lm_head is None: logging.warning("HubertForCTC not enabled") return resid - hidden_states = resid[0] # (B, T, d_model) - with torch.no_grad(): - resid = self.lm_head(hidden_states) # (B, T, vocab_size) + if isinstance(resid, tuple): + hidden_states = resid[0] # take last hidden state + else: + hidden_states = resid # already tensor + resid = self.lm_head(hidden_states) # (B, T, vocab_size) return resid diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index a4bb9c7ae..37add5fdd 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -64,7 +64,7 @@ "facebook/hubert-base-ls960", "facebook-hubert/hubert-large-ls960", "facebook-hubert/hubert-xlarge-ls960", - "facebook-hubert/hubert-large-ls960-ft", + "facebook/hubert-large-ls960-ft", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", @@ -619,7 +619,7 @@ "facebook-hubert/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"], "facebook-hubert/hubert-large-ls960": ["facebook/hubert-large-ls960", "hubert-large-ls960"], "facebook-hubert/hubert-xlarge-ls960": ["facebook/hubert-xlarge-ls960", "hubert-xlarge-ls960"], - "facebook-hubert/hubert-large-ls960-ft": ["facebook/hubert-large-ls960-ft", "hubert-large-ls960-ft"], + "facebook/hubert-large-ls960-ft": ["facebook/hubert-large-ls960-ft", "hubert-large-ls960-ft"], "roneneldan/TinyStories-1M": ["tiny-stories-1M"], "roneneldan/TinyStories-3M": ["tiny-stories-3M"], "roneneldan/TinyStories-8M": ["tiny-stories-8M"], @@ -1201,6 +1201,21 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "attention_dir": "bidirectional", "d_vocab": -1, # no text vocabulary } + elif architecture == "HubertForCTC": + # Basic transformer configuration 
+ cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + "eps": hf_config.layer_norm_eps, + "act_fn": "gelu", + "attention_dir": "bidirectional", + # For CTC models: + "d_vocab": hf_config.vocab_size, # text vocab from tokenizer + } elif architecture == "BertForMaskedLM": # All supported Bert architectures have the same config, # so we can use the BertForMaskedLM config for all of them @@ -1994,6 +2009,8 @@ def get_pretrained_state_dict( state_dict = convert_llama_weights(hf_model, cfg) elif cfg.original_architecture == "HubertModel": state_dict = convert_hubert_weights(hf_model, cfg) + elif cfg.original_architecture == "HubertForCTC": + state_dict = convert_hubert_weights(hf_model, cfg) elif cfg.original_architecture == "BertForMaskedLM": state_dict = convert_bert_weights(hf_model, cfg) elif cfg.original_architecture == "T5ForConditionalGeneration": From 9a6bc7a9660229588491b1ce004def4c5849b640 Mon Sep 17 00:00:00 2001 From: jiankunwei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 02:11:12 +0000 Subject: [PATCH 25/82] done --- hubert_hook_test.py | 12 ++++++------ transformer_lens/HookedAudioEncoder.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/hubert_hook_test.py b/hubert_hook_test.py index 70595ce4d..5cf64f0f1 100644 --- a/hubert_hook_test.py +++ b/hubert_hook_test.py @@ -52,7 +52,7 @@ def main(): # --- Run with cache to inspect attention pattern --- # remove_batch_dim=True makes cached activations shaped like (pos, ...) for easier visualization (like LLaMA example) - cache = audio_model.run_with_cache(frames, attention_mask=frame_mask, remove_batch_dim=True) + logits, cache = audio_model.run_with_cache(frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True) # Picking a layer and head for visualization layer_to_visualize = 0 @@ -76,7 +76,7 @@ def main(): print("Layer", layer_to_visualize, "attention pattern shape:", tuple(attention_pattern.shape)) print("Displaying attention patterns (layer", layer_to_visualize, ")") - display(cv.attention.attention_patterns(tokens=frame_tokens, attention=attention_pattern)) + # display(cv.attention.attention_patterns(tokens=frame_tokens, attention=attention_pattern)) # --- Define a head ablation hook (zero out a given head's v output) --- head_index_to_ablate = 0 @@ -113,13 +113,13 @@ def run_and_get_repr(frames, frame_mask, hooks=None): # hooks: list of (act_name, hook_fn) tuples for run_with_hooks if hooks is None: # run_with_cache to gather activations - cache = audio_model.run_with_cache(frames, attention_mask=frame_mask, remove_batch_dim=True) - out = audio_model.run_with_hooks(frames, return_type=None, fwd_hooks=[]) + cache = audio_model.run_with_cache(frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True) + out = audio_model.run_with_hooks(frames, fwd_hooks=[]) # NOTE: if your API returns outputs directly from run_with_cache, adapt as needed. 
else: # run with hooks and also capture cache # run_with_hooks typically returns output (or logits) and optionally a cache depending on your implementation - out = audio_model.run_with_hooks(frames, fwd_hooks=hooks, attention_mask=frame_mask, return_type="both") + out = audio_model.run_with_hooks(frames, fwd_hooks=hooks, one_zero_attention_mask=frame_mask) # If return_type="both" isn't supported, you can run run_with_cache and run_with_hooks separately. # Try to extract CTC logits from `out` first logits = None @@ -148,7 +148,7 @@ def run_and_get_repr(frames, frame_mask, hooks=None): last_layer = audio_model.cfg.n_layers - 1 resid_name = utils.get_act_name("resid_post", last_layer) # get cache from run_with_cache (we ran above) - cache = audio_model.run_with_cache(frames, attention_mask=frame_mask, remove_batch_dim=True) + cache = audio_model.run_with_cache(frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True) resid = cache[resid_name] # e.g. (pos, d) or (batch,pos,d) # mean-pool across pos dimension if resid.ndim == 3: diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 4f21a4204..ee92d3493 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -238,6 +238,7 @@ def forward( List[Union[torch.Tensor, np.ndarray]], # list of waveforms Tuple[torch.Tensor, torch.Tensor], # (frames, frame_mask) ], + one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, sampling_rate: int = 16000, use_ctc: bool = False, move_to_device: bool = True, @@ -276,7 +277,7 @@ def forward( # allow single 1D tensor or numpy array or list of tensors/arrays frames, frame_mask = self.to_frames(inputs) # to_frames should already place tensors on device if move_to_device=True - + frame_mask = frame_mask if one_zero_attention_mask is None else one_zero_attention_mask # ---------- 2) Ensure device & dtype consistency ---------- device = self.cfg.device if frames.device.type != device: From 1ddbf7fab2845d06309d42c631455ed7b3eb3fb0 Mon Sep 17 00:00:00 2001 From: jiankunwei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 02:12:20 +0000 Subject: [PATCH 26/82] done --- requirements.txt | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..07d227d57 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +rich +torch +transformers +datasets +jaxtyping +datasets +einops +better_abc +typeguard +wandb +circuitsvis \ No newline at end of file From c646ee54c2401f8977eddfcdd548c81a96a5f348 Mon Sep 17 00:00:00 2001 From: jiankunwei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 04:21:54 +0000 Subject: [PATCH 27/82] done --- transformer_lens/loading_from_pretrained.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 37add5fdd..e13800327 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -62,9 +62,10 @@ "facebook/opt-30b", "facebook/opt-66b", "facebook/hubert-base-ls960", - "facebook-hubert/hubert-large-ls960", - "facebook-hubert/hubert-xlarge-ls960", + "facebook/hubert-large-ll60k", "facebook/hubert-large-ls960-ft", + "facebook/hubert-xlarge-ll60k", + "facebook/hubert-xlarge-ls960-ft", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", @@ 
-616,10 +617,11 @@ "google-bert/bert-base-uncased": ["bert-base-uncased"], "google-bert/bert-large-cased": ["bert-large-cased"], "google-bert/bert-large-uncased": ["bert-large-uncased"], - "facebook-hubert/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"], - "facebook-hubert/hubert-large-ls960": ["facebook/hubert-large-ls960", "hubert-large-ls960"], - "facebook-hubert/hubert-xlarge-ls960": ["facebook/hubert-xlarge-ls960", "hubert-xlarge-ls960"], + "facebook/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"], + "facebook/hubert-large-ll60k": ["facebook/hubert-large-ll60k", "hubert-large-ll60k"], "facebook/hubert-large-ls960-ft": ["facebook/hubert-large-ls960-ft", "hubert-large-ls960-ft"], + "facebook/hubert-xlarge-ll60k": ["facebook/hubert-xlarge-ll60k", "hubert-xlarge-ll60k"], + "facebook/hubert-xlarge-ls960-ft": ["facebook/hubert-xlarge-ls960-ft", "hubert-xlarge-ls960-ft"], "roneneldan/TinyStories-1M": ["tiny-stories-1M"], "roneneldan/TinyStories-3M": ["tiny-stories-3M"], "roneneldan/TinyStories-8M": ["tiny-stories-8M"], From 7d5fe2aa7c197e77f8cc5ebe254ed044f47938c1 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:31:25 -0500 Subject: [PATCH 28/82] Rename hubert_ctc_test.py to demos/HuBERT_test/hubert_ctc_test.py --- hubert_ctc_test.py => demos/HuBERT_test/hubert_ctc_test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename hubert_ctc_test.py => demos/HuBERT_test/hubert_ctc_test.py (100%) diff --git a/hubert_ctc_test.py b/demos/HuBERT_test/hubert_ctc_test.py similarity index 100% rename from hubert_ctc_test.py rename to demos/HuBERT_test/hubert_ctc_test.py From 21a0256510b8f8aca0cd9eb36f25b4db362fc53e Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:31:50 -0500 Subject: [PATCH 29/82] Rename hubert_hook_test.py to demos/HuBERT_test /hubert_hook_test.py --- hubert_hook_test.py => demos/HuBERT_test /hubert_hook_test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename hubert_hook_test.py => demos/HuBERT_test /hubert_hook_test.py (100%) diff --git a/hubert_hook_test.py b/demos/HuBERT_test /hubert_hook_test.py similarity index 100% rename from hubert_hook_test.py rename to demos/HuBERT_test /hubert_hook_test.py From c9f7c68042c76d367b3e0ef30d686c7e59f14a09 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:32:23 -0500 Subject: [PATCH 30/82] Rename hubert_hook_test.py to hubert_hook_test.py --- demos/{HuBERT_test => HuBERT_test}/hubert_hook_test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename demos/{HuBERT_test => HuBERT_test}/hubert_hook_test.py (100%) diff --git a/demos/HuBERT_test /hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py similarity index 100% rename from demos/HuBERT_test /hubert_hook_test.py rename to demos/HuBERT_test/hubert_hook_test.py From 2f578ce64d8d6788da3ab38ad18929ae5ba1fd7f Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:32:44 -0500 Subject: [PATCH 31/82] Rename hubert_test.py to demos/HuBERT_test/hubert_test.py --- hubert_test.py => demos/HuBERT_test/hubert_test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename hubert_test.py => demos/HuBERT_test/hubert_test.py (100%) diff --git a/hubert_test.py b/demos/HuBERT_test/hubert_test.py similarity index 100% rename from hubert_test.py rename to 
demos/HuBERT_test/hubert_test.py From f76c2ee82060bdbb00aca35d884f1e641055f4f3 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:41:34 -0500 Subject: [PATCH 32/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index ee92d3493..ead56b1e0 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -7,7 +7,6 @@ from __future__ import annotations import logging -import os from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload from typing_extensions import Literal @@ -17,7 +16,6 @@ from einops import repeat from jaxtyping import Float, Int from transformers import AutoProcessor, HubertModel, HubertForCTC, AutoFeatureExtractor -from transformers.models.auto.tokenization_auto import AutoTokenizer import transformer_lens.loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache From 69345b1b2f141591bae39267133beda2ea7c87ec Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:41:57 -0500 Subject: [PATCH 33/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index ead56b1e0..d37d4957e 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -25,7 +25,7 @@ BertBlock, ) from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.hook_points import HookedRootModule from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utilities import devices From 7be3d4e028223369729f0c9180976d0a6b9d46ce Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:42:43 -0500 Subject: [PATCH 34/82] Update hubert.py --- transformer_lens/pretrained/weight_conversions/hubert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_lens/pretrained/weight_conversions/hubert.py b/transformer_lens/pretrained/weight_conversions/hubert.py index bfb0c00e8..fd9e17523 100644 --- a/transformer_lens/pretrained/weight_conversions/hubert.py +++ b/transformer_lens/pretrained/weight_conversions/hubert.py @@ -1,5 +1,4 @@ import einops -import torch from transformer_lens.HookedTransformerConfig import HookedTransformerConfig def convert_hubert_weights(hf_model, cfg: HookedTransformerConfig): From 7e177c42c44a7a4eccf5bacbbea5677c05313ce5 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:43:08 -0500 Subject: [PATCH 35/82] Update hubert_ctc_test.py --- demos/HuBERT_test/hubert_ctc_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/demos/HuBERT_test/hubert_ctc_test.py b/demos/HuBERT_test/hubert_ctc_test.py index c5dc1577f..93fea7735 100644 --- a/demos/HuBERT_test/hubert_ctc_test.py +++ b/demos/HuBERT_test/hubert_ctc_test.py @@ -11,7 +11,6 @@ import torch import numpy as np import math -import sys from transformer_lens import HookedAudioEncoder # ----- CONFIG ----- From 6737ccd810a333126b00a7d02f2329513e7f53b3 Mon Sep 17 00:00:00 2001 From: Jiankun Wei 
<72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:43:58 -0500 Subject: [PATCH 36/82] Update hubert_hook_test.py --- demos/HuBERT_test/hubert_hook_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/HuBERT_test/hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py index 5cf64f0f1..084ecf021 100644 --- a/demos/HuBERT_test/hubert_hook_test.py +++ b/demos/HuBERT_test/hubert_hook_test.py @@ -2,7 +2,7 @@ import torch import numpy as np import math -import circuitsvis as cv +# import circuitsvis as cv from tqdm import tqdm from jaxtyping import Float From e062f38ba8eaa7b09ee53d6fae48544690a9cd21 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:44:42 -0500 Subject: [PATCH 37/82] Update hubert_hook_test.py --- demos/HuBERT_test/hubert_hook_test.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/demos/HuBERT_test/hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py index 084ecf021..1e40edfcf 100644 --- a/demos/HuBERT_test/hubert_hook_test.py +++ b/demos/HuBERT_test/hubert_hook_test.py @@ -4,14 +4,7 @@ import math # import circuitsvis as cv -from tqdm import tqdm -from jaxtyping import Float - -import transformer_lens import transformer_lens.utils as utils -from transformer_lens.hook_points import ( - HookPoint, -) # Hooking utilities from transformer_lens import HookedAudioEncoder # ---- Simple sine audio generator ---- SAMPLE_RATE = 16000 From 340260fe4ae53ba13406459a7753f055bc316970 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:53:24 -0500 Subject: [PATCH 38/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index d37d4957e..352e029a7 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -6,27 +6,26 @@ from __future__ import annotations +# Standard library import logging from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload -from typing_extensions import Literal +# Third-party +import numpy as np import torch import torch.nn as nn -import numpy as np from einops import repeat from jaxtyping import Float, Int -from transformers import AutoProcessor, HubertModel, HubertForCTC, AutoFeatureExtractor +from transformers import AutoFeatureExtractor, AutoProcessor, HubertForCTC, HubertModel +from typing_extensions import Literal +# Local imports import transformer_lens.loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache -from transformer_lens.components import ( - MLP, - Attention, - BertBlock, -) from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.hook_points import HookedRootModule from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.components import Attention, BertBlock, MLP +from transformer_lens.hook_points import HookedRootModule from transformer_lens.utilities import devices T = TypeVar("T", bound="HookedEncoder") From 3c44076adf08b32cd6ec282785040ca93d0858ff Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:55:07 -0500 Subject: [PATCH 39/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 3 +-- 1 file 
changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index e13800327..e1a2ed36d 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -4,7 +4,6 @@ This module contains functions for loading pretrained models from the Hugging Face Hub. """ - import dataclasses import logging import os @@ -18,8 +17,8 @@ AutoConfig, AutoModelForCausalLM, BertForPreTraining, - T5ForConditionalGeneration, HubertModel, + T5ForConditionalGeneration, ) import transformer_lens.utils as utils From 64aeb4c61b942de5ed3458e1d5cc81beede8d169 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:55:22 -0500 Subject: [PATCH 40/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 352e029a7..0cd4c10f6 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -6,11 +6,9 @@ from __future__ import annotations -# Standard library import logging from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload -# Third-party import numpy as np import torch import torch.nn as nn @@ -19,7 +17,6 @@ from transformers import AutoFeatureExtractor, AutoProcessor, HubertForCTC, HubertModel from typing_extensions import Literal -# Local imports import transformer_lens.loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache from transformer_lens.FactoredMatrix import FactoredMatrix From 71a4f51a0a39d17737e36c82dafd1b16244476c7 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:56:08 -0500 Subject: [PATCH 41/82] Update hubert.py --- transformer_lens/pretrained/weight_conversions/hubert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_lens/pretrained/weight_conversions/hubert.py b/transformer_lens/pretrained/weight_conversions/hubert.py index fd9e17523..7d48066da 100644 --- a/transformer_lens/pretrained/weight_conversions/hubert.py +++ b/transformer_lens/pretrained/weight_conversions/hubert.py @@ -1,4 +1,5 @@ import einops + from transformer_lens.HookedTransformerConfig import HookedTransformerConfig def convert_hubert_weights(hf_model, cfg: HookedTransformerConfig): From f0207ca7c7dbae3f03bde74647bdb51f5a46b57c Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:56:46 -0500 Subject: [PATCH 42/82] Update hubert_ctc_test.py --- demos/HuBERT_test/hubert_ctc_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/demos/HuBERT_test/hubert_ctc_test.py b/demos/HuBERT_test/hubert_ctc_test.py index 93fea7735..d5ba8f46c 100644 --- a/demos/HuBERT_test/hubert_ctc_test.py +++ b/demos/HuBERT_test/hubert_ctc_test.py @@ -8,9 +8,11 @@ Change the import to point at your HookedAudioEncoder implementation. 
""" -import torch -import numpy as np import math + +import numpy as np +import torch + from transformer_lens import HookedAudioEncoder # ----- CONFIG ----- From 98f6eac1de90b0e64d5e9dd758fe5901198e8258 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:57:20 -0500 Subject: [PATCH 43/82] Update hubert_hook_test.py --- demos/HuBERT_test/hubert_hook_test.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/demos/HuBERT_test/hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py index 1e40edfcf..67d06b29d 100644 --- a/demos/HuBERT_test/hubert_hook_test.py +++ b/demos/HuBERT_test/hubert_hook_test.py @@ -1,11 +1,13 @@ -# test_hubert_hooks.py -import torch -import numpy as np import math + +import numpy as np +import torch # import circuitsvis as cv -import transformer_lens.utils as utils from transformer_lens import HookedAudioEncoder +import transformer_lens.utils as utils + + # ---- Simple sine audio generator ---- SAMPLE_RATE = 16000 DURATION_S = 1.0 From da84180c9a9ae238496e0fb8edf2790b03325604 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:57:36 -0500 Subject: [PATCH 44/82] Update hubert_hook_test.py --- demos/HuBERT_test/hubert_hook_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/demos/HuBERT_test/hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py index 67d06b29d..5e86a85b8 100644 --- a/demos/HuBERT_test/hubert_hook_test.py +++ b/demos/HuBERT_test/hubert_hook_test.py @@ -7,7 +7,6 @@ from transformer_lens import HookedAudioEncoder import transformer_lens.utils as utils - # ---- Simple sine audio generator ---- SAMPLE_RATE = 16000 DURATION_S = 1.0 From ede04f84c18596f0ddf747eeb4a85e6c8ed7da4e Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:58:14 -0500 Subject: [PATCH 45/82] Update hubert_test.py --- demos/HuBERT_test/hubert_test.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/demos/HuBERT_test/hubert_test.py b/demos/HuBERT_test/hubert_test.py index ef971dbaa..55bd8f5cc 100644 --- a/demos/HuBERT_test/hubert_test.py +++ b/demos/HuBERT_test/hubert_test.py @@ -1,15 +1,10 @@ # test_hubert_hooked.py -import torch -import numpy as np import math -# Replace this with the actual import for your implementation: +import numpy as np +import torch + from transformer_lens import HookedAudioEncoder -# For illustration I assume the same API as HookedEncoder/HookedAudioEncoder: -# - HookedAudioEncoder.from_pretrained(...) OR HookedAudioEncoder(...) to instantiate -# - model(waveform, return_type=...) or model(waveform) returns a tensor -# -# If your class is named differently, change the import and instantiation below. 
# ---------- CONFIG ---------- SAMPLE_RATE = 16000 From 305509a37acf74dc7ddad5fd7b49ff4f84509233 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:16:21 -0500 Subject: [PATCH 46/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index e1a2ed36d..40dc2408d 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -4,6 +4,7 @@ This module contains functions for loading pretrained models from the Hugging Face Hub. """ + import dataclasses import logging import os @@ -15,9 +16,9 @@ from huggingface_hub import HfApi from transformers import ( AutoConfig, + HubertModel, AutoModelForCausalLM, BertForPreTraining, - HubertModel, T5ForConditionalGeneration, ) From 6461e2e706f15db1b2ad12781b876fe7b6d4f62b Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:17:18 -0500 Subject: [PATCH 47/82] Update hubert.py --- transformer_lens/pretrained/weight_conversions/hubert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_lens/pretrained/weight_conversions/hubert.py b/transformer_lens/pretrained/weight_conversions/hubert.py index 7d48066da..a13141725 100644 --- a/transformer_lens/pretrained/weight_conversions/hubert.py +++ b/transformer_lens/pretrained/weight_conversions/hubert.py @@ -2,6 +2,7 @@ from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + def convert_hubert_weights(hf_model, cfg: HookedTransformerConfig): """ Convert transformer encoder weights from a HuggingFace HuBERT model From 5344612d43e8d8289c34e0cef24e209af39e6f90 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:19:53 -0500 Subject: [PATCH 48/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 0cd4c10f6..1a38a8cc0 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -9,8 +9,8 @@ import logging from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload -import numpy as np import torch +import numpy as np import torch.nn as nn from einops import repeat from jaxtyping import Float, Int @@ -20,9 +20,13 @@ import transformer_lens.loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from transformer_lens.components import Attention, BertBlock, MLP +from transformer_lens.components import { + Attention, + BertBlock, + MLP +} from transformer_lens.hook_points import HookedRootModule +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utilities import devices T = TypeVar("T", bound="HookedEncoder") From 32db5d2552249e2fae8f1d1d61701d1a8f82705f Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:20:19 -0500 Subject: [PATCH 49/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 6 +++--- 1 file changed, 
3 insertions(+), 3 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 1a38a8cc0..bd03e68a5 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -21,9 +21,9 @@ from transformer_lens.ActivationCache import ActivationCache from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.components import { - Attention, - BertBlock, - MLP + MLP, + Attention, + BertBlock, } from transformer_lens.hook_points import HookedRootModule from transformer_lens.HookedTransformerConfig import HookedTransformerConfig From 560ffb92a91cba8767748051ae551b54e98cc22a Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:21:29 -0500 Subject: [PATCH 50/82] Update hubert_hook_test.py --- demos/HuBERT_test/hubert_hook_test.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/demos/HuBERT_test/hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py index 5e86a85b8..2f4aec867 100644 --- a/demos/HuBERT_test/hubert_hook_test.py +++ b/demos/HuBERT_test/hubert_hook_test.py @@ -1,11 +1,10 @@ import math -import numpy as np import torch -# import circuitsvis as cv +import numpy as np -from transformer_lens import HookedAudioEncoder import transformer_lens.utils as utils +from transformer_lens import HookedAudioEncoder # ---- Simple sine audio generator ---- SAMPLE_RATE = 16000 @@ -16,16 +15,8 @@ def make_sine(sr=SAMPLE_RATE, duration=DURATION_S, freq=440.0, amp=0.1): t = np.linspace(0, duration, int(sr*duration), endpoint=False, dtype=np.float32) return amp * np.sin(2 * math.pi * freq * t) -# ---- Adapt these to your instantiated model/wrapper ---- -# instantiate your HookedAudioEncoder (example) audio_model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960", device="cuda") -# For this template I'll assume: -# - `audio_model` is the object that implements .to_frames(raw_inputs...) -> (frames, frame_mask) -# - `audio_model` exposes: run_with_cache(frames, attention_mask=frame_mask, remove_batch_dim=True) -# - `audio_model` exposes: run_with_hooks(frames, fwd_hooks=[(act_name, hook_fn)], attention_mask=frame_mask, return_type=...) -# If your API names differ, substitute accordingly. 
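
(Not part of the diff above — a hedged sketch of the hook-based workflow this demo exercises after the kwarg rename in PATCH 25, where run_with_cache / run_with_hooks take one_zero_attention_mask instead of attention_mask. The layer and head indices are illustrative; make_sine is the helper defined in this script.)

import transformer_lens.utils as utils
from transformer_lens import HookedAudioEncoder

audio_model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960")
wav = make_sine()                                   # 1 s, 440 Hz test tone
frames, frame_mask = audio_model.to_frames(wav)     # (batch, frames, d_model), (batch, frames)

# Cache all activations, e.g. to read off an attention pattern
logits, cache = audio_model.run_with_cache(
    frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True
)
pattern = cache["pattern", 0]                       # layer-0 attention pattern

# Zero out head 0's value vectors in layer 0 and re-run
def ablate_head_v(value, hook):
    value[:, :, 0, :] = 0.0                         # (batch, pos, head_index, d_head)
    return value

ablated = audio_model.run_with_hooks(
    frames,
    fwd_hooks=[(utils.get_act_name("v", 0), ablate_head_v)],
    one_zero_attention_mask=frame_mask,
)
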
- def main(): # --- Build a 1s test waveform --- wav = make_sine() From dda10e548d96b5422be0644124f34ae434fe79d0 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:24:18 -0500 Subject: [PATCH 51/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index bd03e68a5..7a44581d5 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -20,11 +20,11 @@ import transformer_lens.loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.components import { +from transformer_lens.components import ( MLP, Attention, BertBlock, -} +) from transformer_lens.hook_points import HookedRootModule from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utilities import devices From 219defbb6b260970b20bcaed718e60c23b0e99f4 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:28:27 -0500 Subject: [PATCH 52/82] Update hubert_hook_test.py --- demos/HuBERT_test/hubert_hook_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/HuBERT_test/hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py index 2f4aec867..a6acaed14 100644 --- a/demos/HuBERT_test/hubert_hook_test.py +++ b/demos/HuBERT_test/hubert_hook_test.py @@ -3,8 +3,8 @@ import torch import numpy as np -import transformer_lens.utils as utils from transformer_lens import HookedAudioEncoder +import transformer_lens.utils as utils # ---- Simple sine audio generator ---- SAMPLE_RATE = 16000 From 2df2d27f50c31b5eee7efd6b7785ec56c7f7a787 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:31:16 -0500 Subject: [PATCH 53/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 7a44581d5..1c31c48a4 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -9,26 +9,28 @@ import logging from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload -import torch import numpy as np +import torch import torch.nn as nn from einops import repeat from jaxtyping import Float, Int -from transformers import AutoFeatureExtractor, AutoProcessor, HubertForCTC, HubertModel +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + HubertForCTC, + HubertModel, +) from typing_extensions import Literal -import transformer_lens.loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.components import ( - MLP, - Attention, - BertBlock, -) -from transformer_lens.hook_points import HookedRootModule from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.components import Attention, BertBlock, MLP +from transformer_lens.hook_points import HookedRootModule +from transformer_lens import loading_from_pretrained as loading from transformer_lens.utilities import devices + 
T = TypeVar("T", bound="HookedEncoder") From 46c3344b51a2ab6e9b9f11c599618d1e8086fe05 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:32:55 -0500 Subject: [PATCH 54/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 1c31c48a4..ce84d548b 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -22,12 +22,16 @@ ) from typing_extensions import Literal +from transformer_lens import loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.components import ( + MLP, + Attention, + BertBlock, +) from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from transformer_lens.components import Attention, BertBlock, MLP from transformer_lens.hook_points import HookedRootModule -from transformer_lens import loading_from_pretrained as loading from transformer_lens.utilities import devices From 6272b9f488c06fbe93615b02147f29fe9d63d0be Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:34:51 -0500 Subject: [PATCH 55/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 40dc2408d..7c10bb363 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -16,9 +16,9 @@ from huggingface_hub import HfApi from transformers import ( AutoConfig, - HubertModel, AutoModelForCausalLM, BertForPreTraining, + HubertModel, T5ForConditionalGeneration, ) From af6163deb2ea3f3cbe73fdb6f21b1b3364d9e023 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:35:30 -0500 Subject: [PATCH 56/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 7c10bb363..d72316bac 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -31,6 +31,7 @@ convert_gemma_weights, convert_gpt2_weights, convert_gptj_weights, + convert_hubert_weights, convert_llama_weights, convert_mingpt_weights, convert_mistral_weights, @@ -45,7 +46,6 @@ convert_qwen3_weights, convert_qwen_weights, convert_t5_weights, - convert_hubert_weights, ) OFFICIAL_MODEL_NAMES = [ From 0b5a86073d87959186dc91e919b979ccebf60b5a Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:37:39 -0500 Subject: [PATCH 57/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index ce84d548b..20d703568 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -26,7 +26,7 @@ from transformer_lens.ActivationCache import ActivationCache from transformer_lens.components import ( 
MLP, - Attention, + Attention, BertBlock, ) from transformer_lens.FactoredMatrix import FactoredMatrix From 817c97f768719689de6ed43867d8ac995e18dfb0 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:44:15 -0500 Subject: [PATCH 58/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 20d703568..68c2eb2f1 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -30,8 +30,8 @@ BertBlock, ) from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.hook_points import HookedRootModule +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utilities import devices From 48920e19b86b04cff12c6803a8df8d7dc3926078 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:44:49 -0500 Subject: [PATCH 59/82] Update hubert_hook_test.py --- demos/HuBERT_test/hubert_hook_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/HuBERT_test/hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py index a6acaed14..818829d86 100644 --- a/demos/HuBERT_test/hubert_hook_test.py +++ b/demos/HuBERT_test/hubert_hook_test.py @@ -1,7 +1,7 @@ import math -import torch import numpy as np +import torch from transformer_lens import HookedAudioEncoder import transformer_lens.utils as utils From 6dcffb231425c9b363828cf80c29ff88c759e704 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:47:35 -0500 Subject: [PATCH 60/82] Update hubert_hook_test.py --- demos/HuBERT_test/hubert_hook_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/HuBERT_test/hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py index 818829d86..79225bdc0 100644 --- a/demos/HuBERT_test/hubert_hook_test.py +++ b/demos/HuBERT_test/hubert_hook_test.py @@ -3,8 +3,8 @@ import numpy as np import torch -from transformer_lens import HookedAudioEncoder import transformer_lens.utils as utils +from transformer_lens import HookedAudioEncoder # ---- Simple sine audio generator ---- SAMPLE_RATE = 16000 From 5f7af854f9b7421a835492eb2de63fe65b6fe3fd Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:50:00 -0500 Subject: [PATCH 61/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 68c2eb2f1..631da94b1 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -9,11 +9,11 @@ import logging from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload +from einops import repeat +from jaxtyping import Float, Int import numpy as np import torch import torch.nn as nn -from einops import repeat -from jaxtyping import Float, Int from transformers import ( AutoFeatureExtractor, AutoProcessor, @@ -25,9 +25,9 @@ from transformer_lens import loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache from transformer_lens.components 
import ( - MLP, Attention, BertBlock, + MLP, ) from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule From fefcea2e70e897f97566c1c77a9f6e1abc0e5fe3 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 9 Nov 2025 14:27:42 -0500 Subject: [PATCH 62/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 631da94b1..d70b191d2 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -54,7 +54,7 @@ def __init__( cfg: Union[HookedTransformerConfig, Dict], move_to_device: bool = True, model_name: str = "facebook/hubert-base-ls960", - use_ctc: bool = True, + use_ctc: bool = False, **kwargs: Any, ): super().__init__() @@ -369,7 +369,7 @@ def from_pretrained( device: Optional[str] = None, move_to_device: bool = True, dtype: torch.dtype = torch.float32, - use_ctc: bool = True, + use_ctc: bool = False, **from_pretrained_kwargs: Any, ) -> HookedEncoder: """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" From b1414e0262f6d0fb60c9adae48008e88055b5663 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 9 Nov 2025 16:25:35 -0500 Subject: [PATCH 63/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index d70b191d2..e0cbd8b95 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -28,6 +28,7 @@ Attention, BertBlock, MLP, + LayerNorm, ) from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule @@ -97,6 +98,8 @@ def __init__( self.hubert_model = hubert_model self.lm_head = None + self.ln_final = LayerNorm(self.cfg) + if move_to_device: if self.cfg.device is None: raise ValueError("Cannot move to device when device is None") @@ -291,6 +294,7 @@ def forward( # ---------- 3) Run encoder (respects pos_conv_embed / layer_norm / dropout inside encoder_output) ---------- resid = self.encoder_output(frames, frame_mask) # (B, T, d_model) + resid = self.ln_final(resid) if use_ctc: if self.lm_head is None: From 94bd3d7d8a9bc8bb5245fdc3962df61220e6db88 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 9 Nov 2025 16:32:13 -0500 Subject: [PATCH 64/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index e0cbd8b95..d70b191d2 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -28,7 +28,6 @@ Attention, BertBlock, MLP, - LayerNorm, ) from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule @@ -98,8 +97,6 @@ def __init__( self.hubert_model = hubert_model self.lm_head = None - self.ln_final = LayerNorm(self.cfg) - if move_to_device: if self.cfg.device is None: raise ValueError("Cannot move to device when device is None") @@ 
-294,7 +291,6 @@ def forward( # ---------- 3) Run encoder (respects pos_conv_embed / layer_norm / dropout inside encoder_output) ---------- resid = self.encoder_output(frames, frame_mask) # (B, T, d_model) - resid = self.ln_final(resid) if use_ctc: if self.lm_head is None: From fbae9c1bb2b049a2364035280e72bdd5102f7564 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Tue, 11 Nov 2025 18:13:39 -0500 Subject: [PATCH 65/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index d70b191d2..aa6431930 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -73,15 +73,28 @@ def __init__( # fine-tuned model (has CTC head) use_ctc = True processor = AutoProcessor.from_pretrained(model_name) # builds input_values + attention_mask + logging.warning( + f"Using AutoProcessor. The model name is {model_name}" + ) + else: # pretraining-only model (no CTC) use_ctc = False processor = AutoFeatureExtractor.from_pretrained(model_name) + logging.warning( + f"Using AutoFeatureExtractor. The model name is {model_name}" + ) - if use_ctc: + if model_name.endswith("-ft") and use_ctc: hubert_model = HubertForCTC.from_pretrained(model_name) + logging.warning( + f"Using HubertForCTC" + ) else: hubert_model = HubertModel.from_pretrained(model_name) + logging.warning( + f"Using HubertModel" + ) if move_to_device: if self.cfg.device is None: raise ValueError("Cannot move to device when device is None") From f23d0d9a298ecc88fb26d32dbdba2ff7453a380a Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Tue, 11 Nov 2025 18:46:48 -0500 Subject: [PATCH 66/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 53 ++++++++++++++------------ 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index aa6431930..0e70b15c1 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -69,36 +69,26 @@ def __init__( assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) + if model_name.endswith("-ft") and use_ctc: # fine-tuned model (has CTC head) use_ctc = True processor = AutoProcessor.from_pretrained(model_name) # builds input_values + attention_mask - logging.warning( - f"Using AutoProcessor. The model name is {model_name}" - ) - else: # pretraining-only model (no CTC) use_ctc = False processor = AutoFeatureExtractor.from_pretrained(model_name) - logging.warning( - f"Using AutoFeatureExtractor. 
The model name is {model_name}" - ) if model_name.endswith("-ft") and use_ctc: hubert_model = HubertForCTC.from_pretrained(model_name) - logging.warning( - f"Using HubertForCTC" - ) else: hubert_model = HubertModel.from_pretrained(model_name) - logging.warning( - f"Using HubertModel" - ) + if move_to_device: if self.cfg.device is None: raise ValueError("Cannot move to device when device is None") hubert_model.to(self.cfg.device) + hubert_model.eval() self.processor = processor if use_ctc: @@ -117,15 +107,29 @@ def __init__( self.setup() - def _ensure_tensor(self, wave): - """Convert numpy array or python list to 1D torch.float tensor.""" - if isinstance(wave, np.ndarray): - return torch.from_numpy(wave).float() - if isinstance(wave, list): - return torch.tensor(wave, dtype=torch.float) + def _ensure_numpy(self, wave): + """ + Convert torch.Tensor / np.ndarray / list -> 1D np.float32 array on CPU. + """ if isinstance(wave, torch.Tensor): - return wave.float() - raise TypeError("wave must be torch.Tensor, np.ndarray or list of floats") + arr = wave.detach().cpu().numpy() + elif isinstance(wave, np.ndarray): + arr = wave + elif isinstance(wave, list): + arr = np.asarray(wave) + else: + raise TypeError("wave must be torch.Tensor, np.ndarray or list of floats") + + # force 1-D (if stereo or shape (N,1) etc) + if arr.ndim > 1: + # if shape (n_samples, n_channels) average channels -> mono + if arr.shape[1] <= arr.shape[0]: + arr = arr.mean(axis=1) + else: + arr = arr.reshape(-1) + + return arr.astype(np.float32, copy=False) + def to_frames( self, @@ -149,12 +153,11 @@ def to_frames( frames: torch.Tensor of shape (batch, frames, hidden_size) <- after feature_projection frame_attention_mask: torch.LongTensor of shape (batch, frames) with 1 for real frames, 0 for padding """ - # raw_inputs are arrays/tensors - # print(type(raw_inputs)) + # AutoFeatureExtractor works better onnumpy array where it pads automatically. 
If passing in tensors, it does not pad properly, giving inhomogeneous arts error if isinstance(raw_inputs, (torch.Tensor, np.ndarray)): - waves = [self._ensure_tensor(raw_inputs)] + waves = [self._ensure_numpy(raw_inputs)] elif isinstance(raw_inputs, list): - waves = [self._ensure_tensor(w) for w in raw_inputs] + waves = [self._ensure_numpy(w) for w in raw_inputs] else: raise TypeError("Unsupported raw_inputs type") From d20ee073e55666de9abf6d4674c7222c2a15f689 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Tue, 11 Nov 2025 19:02:31 -0500 Subject: [PATCH 67/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 0e70b15c1..82cc1c2d3 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -163,7 +163,6 @@ def to_frames( # Use HF processor to create input_values (padded) + sample-level attention_mask # Processor will do padding so we can pass a variable-length batch - waves = [w.detach().cpu() for w in waves] proc_out = self.processor(waves, sampling_rate=sampling_rate, return_tensors="pt", padding=True) input_values = proc_out["input_values"] # (batch, samples), float sample_attention_mask = proc_out.get("attention_mask") # (batch, samples), 1 for valid, 0 for padding; may be None From 00c12cb9944971c1217d79c52bcbaca64f0cb1d6 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Wed, 12 Nov 2025 16:54:42 -0500 Subject: [PATCH 68/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 82cc1c2d3..ff4748556 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -163,7 +163,7 @@ def to_frames( # Use HF processor to create input_values (padded) + sample-level attention_mask # Processor will do padding so we can pass a variable-length batch - proc_out = self.processor(waves, sampling_rate=sampling_rate, return_tensors="pt", padding=True) + proc_out = self.processor(waves, sampling_rate=sampling_rate, return_tensors="pt", padding=True, return_attention_mask=True) input_values = proc_out["input_values"] # (batch, samples), float sample_attention_mask = proc_out.get("attention_mask") # (batch, samples), 1 for valid, 0 for padding; may be None From 14ab5bb4fe30eeab9ada16436fa128786ee73f9f Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 16 Nov 2025 20:10:14 -0500 Subject: [PATCH 69/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index ff4748556..d9db532e9 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -19,6 +19,7 @@ AutoProcessor, HubertForCTC, HubertModel, + Wav2Vec2Model ) from typing_extensions import Literal @@ -81,6 +82,8 @@ def __init__( if model_name.endswith("-ft") and use_ctc: hubert_model = HubertForCTC.from_pretrained(model_name) + elif "wav2vec2" in model_name: + hubert_model = Wav2Vec2Model.from_pretrained(model_name) else: hubert_model = HubertModel.from_pretrained(model_name) From 
41402ba7788d455e02f56108108c2559ef49fd4c Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 16 Nov 2025 20:15:07 -0500 Subject: [PATCH 70/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 39 ++++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index d72316bac..42991359b 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -20,6 +20,7 @@ BertForPreTraining, HubertModel, T5ForConditionalGeneration, + Wav2Vec2Model, ) import transformer_lens.utils as utils @@ -46,6 +47,7 @@ convert_qwen3_weights, convert_qwen_weights, convert_t5_weights, + convert_w2v2_weights ) OFFICIAL_MODEL_NAMES = [ @@ -62,10 +64,7 @@ "facebook/opt-30b", "facebook/opt-66b", "facebook/hubert-base-ls960", - "facebook/hubert-large-ll60k", - "facebook/hubert-large-ls960-ft", - "facebook/hubert-xlarge-ll60k", - "facebook/hubert-xlarge-ls960-ft", + "facebook/wav2vec2-base-960h", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", @@ -618,10 +617,7 @@ "google-bert/bert-large-cased": ["bert-large-cased"], "google-bert/bert-large-uncased": ["bert-large-uncased"], "facebook/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"], - "facebook/hubert-large-ll60k": ["facebook/hubert-large-ll60k", "hubert-large-ll60k"], - "facebook/hubert-large-ls960-ft": ["facebook/hubert-large-ls960-ft", "hubert-large-ls960-ft"], - "facebook/hubert-xlarge-ll60k": ["facebook/hubert-xlarge-ll60k", "hubert-xlarge-ll60k"], - "facebook/hubert-xlarge-ls960-ft": ["facebook/hubert-xlarge-ls960-ft", "hubert-xlarge-ls960-ft"], + "facebook/wav2vec2-base-960h": ["facebook/wav2vec2-base-960h", "wav2vec2-base-960h"], "roneneldan/TinyStories-1M": ["tiny-stories-1M"], "roneneldan/TinyStories-3M": ["tiny-stories-3M"], "roneneldan/TinyStories-8M": ["tiny-stories-8M"], @@ -1203,7 +1199,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "attention_dir": "bidirectional", "d_vocab": -1, # no text vocabulary } - elif architecture == "HubertForCTC": + elif architecture == "Wav2Vec2Model": # Basic transformer configuration cfg_dict = { "d_model": hf_config.hidden_size, @@ -1211,13 +1207,28 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "n_heads": hf_config.num_attention_heads, "d_mlp": hf_config.intermediate_size, "n_layers": hf_config.num_hidden_layers, + # HuBERT operates on audio frames, not tokens — n_ctx is flexible "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), "eps": hf_config.layer_norm_eps, "act_fn": "gelu", "attention_dir": "bidirectional", - # For CTC models: - "d_vocab": hf_config.vocab_size, # text vocab from tokenizer + "d_vocab": -1, # no text vocabulary } + # elif architecture == "HubertForCTC": + # # Basic transformer configuration + # cfg_dict = { + # "d_model": hf_config.hidden_size, + # "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + # "n_heads": hf_config.num_attention_heads, + # "d_mlp": hf_config.intermediate_size, + # "n_layers": hf_config.num_hidden_layers, + # "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + # "eps": hf_config.layer_norm_eps, + # "act_fn": "gelu", + # "attention_dir": "bidirectional", + # # For CTC models: + # "d_vocab": hf_config.vocab_size, # text vocab from tokenizer + # } elif architecture == "BertForMaskedLM": # All supported Bert 
architectures have the same config, # so we can use the BertForMaskedLM config for all of them @@ -2011,8 +2022,10 @@ def get_pretrained_state_dict( state_dict = convert_llama_weights(hf_model, cfg) elif cfg.original_architecture == "HubertModel": state_dict = convert_hubert_weights(hf_model, cfg) - elif cfg.original_architecture == "HubertForCTC": - state_dict = convert_hubert_weights(hf_model, cfg) + elif cfg.original_architecture == "Wav2Vec2Model": + state_dict = convert_w2v2_weights(hf_model, cfg) + # elif cfg.original_architecture == "HubertForCTC": + # state_dict = convert_hubert_weights(hf_model, cfg) elif cfg.original_architecture == "BertForMaskedLM": state_dict = convert_bert_weights(hf_model, cfg) elif cfg.original_architecture == "T5ForConditionalGeneration": From b5cb2e18a61309732f58f3219e98b945228d4a7c Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 16 Nov 2025 20:21:58 -0500 Subject: [PATCH 71/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 37 ++++++++++----------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 42991359b..350d83e2f 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -47,7 +47,6 @@ convert_qwen3_weights, convert_qwen_weights, convert_t5_weights, - convert_w2v2_weights ) OFFICIAL_MODEL_NAMES = [ @@ -1214,21 +1213,21 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "attention_dir": "bidirectional", "d_vocab": -1, # no text vocabulary } - # elif architecture == "HubertForCTC": - # # Basic transformer configuration - # cfg_dict = { - # "d_model": hf_config.hidden_size, - # "d_head": hf_config.hidden_size // hf_config.num_attention_heads, - # "n_heads": hf_config.num_attention_heads, - # "d_mlp": hf_config.intermediate_size, - # "n_layers": hf_config.num_hidden_layers, - # "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), - # "eps": hf_config.layer_norm_eps, - # "act_fn": "gelu", - # "attention_dir": "bidirectional", - # # For CTC models: - # "d_vocab": hf_config.vocab_size, # text vocab from tokenizer - # } + elif architecture == "HubertForCTC": + # Basic transformer configuration + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + "eps": hf_config.layer_norm_eps, + "act_fn": "gelu", + "attention_dir": "bidirectional", + # For CTC models: + "d_vocab": hf_config.vocab_size, # text vocab from tokenizer + } elif architecture == "BertForMaskedLM": # All supported Bert architectures have the same config, # so we can use the BertForMaskedLM config for all of them @@ -2023,9 +2022,9 @@ def get_pretrained_state_dict( elif cfg.original_architecture == "HubertModel": state_dict = convert_hubert_weights(hf_model, cfg) elif cfg.original_architecture == "Wav2Vec2Model": - state_dict = convert_w2v2_weights(hf_model, cfg) - # elif cfg.original_architecture == "HubertForCTC": - # state_dict = convert_hubert_weights(hf_model, cfg) + state_dict = convert_hubert_weights(hf_model, cfg) + elif cfg.original_architecture == "HubertForCTC": + state_dict = convert_hubert_weights(hf_model, cfg) elif cfg.original_architecture == 
"BertForMaskedLM": state_dict = convert_bert_weights(hf_model, cfg) elif cfg.original_architecture == "T5ForConditionalGeneration": From f8200bc2e4d33dd670964cd5ba7ffaf95b875023 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 16 Nov 2025 20:24:05 -0500 Subject: [PATCH 72/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 350d83e2f..740435c58 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1194,7 +1194,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): # HuBERT operates on audio frames, not tokens — n_ctx is flexible "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), "eps": hf_config.layer_norm_eps, - "act_fn": "gelu", + "act_fn": getattr(hf_cfg, "hidden_act", "gelu"), "attention_dir": "bidirectional", "d_vocab": -1, # no text vocabulary } @@ -1209,7 +1209,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): # HuBERT operates on audio frames, not tokens — n_ctx is flexible "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), "eps": hf_config.layer_norm_eps, - "act_fn": "gelu", + "act_fn": getattr(hf_cfg, "hidden_act", "gelu"), "attention_dir": "bidirectional", "d_vocab": -1, # no text vocabulary } @@ -1223,7 +1223,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "n_layers": hf_config.num_hidden_layers, "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), "eps": hf_config.layer_norm_eps, - "act_fn": "gelu", + "act_fn": getattr(hf_cfg, "hidden_act", "gelu"), "attention_dir": "bidirectional", # For CTC models: "d_vocab": hf_config.vocab_size, # text vocab from tokenizer From 6926e2bb97c2c9451ed329dce4a363bbcd627476 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 16 Nov 2025 20:25:54 -0500 Subject: [PATCH 73/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 740435c58..6fa90a824 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1980,6 +1980,13 @@ def get_pretrained_state_dict( token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) + elif "wav2vec2" in official_model_name: + hf_model = Wav2Vec2Model.from_pretrained( + official_model_name, + torch_dtype=dtype, + token=huggingface_token if len(huggingface_token) > 0 else None, + **kwargs, + ) elif "bert" in official_model_name: hf_model = BertForPreTraining.from_pretrained( official_model_name, From e8e958cbb48aae253563bd21da354b9f47ab13f1 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:04:31 -0500 Subject: [PATCH 74/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 6fa90a824..6a2671316 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1194,7 +1194,7 @@ def convert_hf_model_config(model_name: str, 
From e8e958cbb48aae253563bd21da354b9f47ab13f1 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:04:31 -0500
Subject: [PATCH 74/82] Update loading_from_pretrained.py

---
 transformer_lens/loading_from_pretrained.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 6fa90a824..6a2671316 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -1194,7 +1194,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             # HuBERT operates on audio frames, not tokens — n_ctx is flexible
             "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
             "eps": hf_config.layer_norm_eps,
-            "act_fn": getattr(hf_cfg, "hidden_act", "gelu"),
+            "act_fn": getattr(hf_config, "hidden_act", "gelu"),
             "attention_dir": "bidirectional",
             "d_vocab": -1,  # no text vocabulary
         }
@@ -1209,7 +1209,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             # HuBERT operates on audio frames, not tokens — n_ctx is flexible
             "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
             "eps": hf_config.layer_norm_eps,
-            "act_fn": getattr(hf_cfg, "hidden_act", "gelu"),
+            "act_fn": getattr(hf_config, "hidden_act", "gelu"),
             "attention_dir": "bidirectional",
             "d_vocab": -1,  # no text vocabulary
         }
@@ -1223,7 +1223,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "n_layers": hf_config.num_hidden_layers,
             "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
             "eps": hf_config.layer_norm_eps,
-            "act_fn": getattr(hf_cfg, "hidden_act", "gelu"),
+            "act_fn": getattr(hf_config, "hidden_act", "gelu"),
             "attention_dir": "bidirectional",
             # For CTC models:
             "d_vocab": hf_config.vocab_size,  # text vocab from tokenizer

From fa8932157a8e2e903d80b6ff1b84162fceab08e9 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:12:53 -0500
Subject: [PATCH 75/82] Update requirements.txt

---
 requirements.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 07d227d57..54065e35e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 rich
-torch
+torch==2.9.0
 transformers
 datasets
 jaxtyping
@@ -8,4 +8,4 @@ einops
 better_abc
 typeguard
 wandb
-circuitsvis
\ No newline at end of file
+circuitsvis

From 5a7c5c78a286315bd324b1eb3df2da33aee5c741 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:13:35 -0500
Subject: [PATCH 76/82] Update requirements.txt

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 54065e35e..f90bac6e9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ torch==2.9.0
 transformers
 datasets
 jaxtyping
-datasets
+datasets<3.0.0
 einops
 better_abc
 typeguard
 wandb
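
For the act_fn lines touched in PATCH 72 and corrected to hf_config in PATCH 74 above, the getattr fallback simply prefers the checkpoint's own hidden_act field and only defaults to GELU when that field is absent. A small self-contained illustration; the checkpoint name is an example, and for the HuBERT and wav2vec 2.0 base configs hidden_act is "gelu" anyway:

    from transformers import AutoConfig

    hf_config = AutoConfig.from_pretrained("facebook/hubert-base-ls960")
    # Falls back to "gelu" only if the config has no hidden_act attribute.
    act_fn = getattr(hf_config, "hidden_act", "gelu")
    print(act_fn)
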
["facebook/wav2vec2-base", "wav2vec2-base", "w2v2-base"], + "facebook/wav2vec2-large": ["facebook/wav2vec2-large", "wav2vec2-large", "w2v2-large"], "roneneldan/TinyStories-1M": ["tiny-stories-1M"], "roneneldan/TinyStories-3M": ["tiny-stories-3M"], "roneneldan/TinyStories-8M": ["tiny-stories-8M"], From 77285ba7441cca93c1f1437aa12c5a8686b91d19 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Fri, 21 Nov 2025 22:09:42 -0500 Subject: [PATCH 78/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 7879091a9..ceaa1d8ea 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1200,7 +1200,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "attention_dir": "bidirectional", "d_vocab": -1, # no text vocabulary } - elif architecture == "Wav2Vec2Model": + elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name: # Basic transformer configuration cfg_dict = { "d_model": hf_config.hidden_size, From 9fa6464c371f7be9c1b0c265765e7c66897da53b Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 23 Nov 2025 19:54:00 -0500 Subject: [PATCH 79/82] Update loading_from_pretrained.py --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index ceaa1d8ea..c0c49d09e 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -2030,7 +2030,7 @@ def get_pretrained_state_dict( state_dict = convert_llama_weights(hf_model, cfg) elif cfg.original_architecture == "HubertModel": state_dict = convert_hubert_weights(hf_model, cfg) - elif cfg.original_architecture == "Wav2Vec2Model": + elif cfg.original_architecture == "Wav2Vec2Model" or cfg.original_architecture == "Wav2Vec2ForPreTraining": state_dict = convert_hubert_weights(hf_model, cfg) elif cfg.original_architecture == "HubertForCTC": state_dict = convert_hubert_weights(hf_model, cfg) From fc9327e0e111e42fb19d572b814437829b656e5b Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Sun, 23 Nov 2025 22:38:18 -0500 Subject: [PATCH 80/82] Update HookedAudioEncoder.py --- transformer_lens/HookedAudioEncoder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index d9db532e9..5d7f5a0a4 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -299,6 +299,8 @@ def forward( # allow single 1D tensor or numpy array or list of tensors/arrays frames, frame_mask = self.to_frames(inputs) # to_frames should already place tensors on device if move_to_device=True + if isinstance(frames, tuple): + frames = frames[0] frame_mask = frame_mask if one_zero_attention_mask is None else one_zero_attention_mask # ---------- 2) Ensure device & dtype consistency ---------- device = self.cfg.device From c6a43a768356682e8ae58bc0e36d295ad917c2c8 Mon Sep 17 00:00:00 2001 From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com> Date: Mon, 24 Nov 2025 15:37:57 -0500 Subject: [PATCH 81/82] Update bert_pooler.py --- 
From c6a43a768356682e8ae58bc0e36d295ad917c2c8 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Mon, 24 Nov 2025 15:37:57 -0500
Subject: [PATCH 81/82] Update bert_pooler.py

---
 transformer_lens/components/bert_pooler.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/transformer_lens/components/bert_pooler.py b/transformer_lens/components/bert_pooler.py
index cd205bf7f..4f0b4250e 100644
--- a/transformer_lens/components/bert_pooler.py
+++ b/transformer_lens/components/bert_pooler.py
@@ -24,7 +24,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
         self.cfg = HookedTransformerConfig.unwrap(cfg)
         self.W = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype))
         self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
-        self.activation = nn.Tanh()
+        # self.activation = nn.Tanh()
         self.hook_pooler_out = HookPoint()

     def forward(
@@ -32,5 +32,6 @@ def forward(
     ) -> Float[torch.Tensor, "batch d_model"]:
         first_token_tensor = resid[:, 0]
         pooled_output = torch.matmul(first_token_tensor, self.W) + self.b
-        pooled_output = self.hook_pooler_out(self.activation(pooled_output))
+        # pooled_output = self.hook_pooler_out(self.activation(pooled_output))
+        pooled_output = self.hook_pooler_out(pooled_output)
         return pooled_output

From 942706819eefae758b8d4c004ae16f04b3bec3c2 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Mon, 24 Nov 2025 15:45:40 -0500
Subject: [PATCH 82/82] Update bert_pooler.py

---
 transformer_lens/components/bert_pooler.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/transformer_lens/components/bert_pooler.py b/transformer_lens/components/bert_pooler.py
index 4f0b4250e..4f23bba14 100644
--- a/transformer_lens/components/bert_pooler.py
+++ b/transformer_lens/components/bert_pooler.py
@@ -24,7 +24,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
         self.cfg = HookedTransformerConfig.unwrap(cfg)
         self.W = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype))
         self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
-        # self.activation = nn.Tanh()
+        self.activation = nn.Tanh()
         self.hook_pooler_out = HookPoint()

     def forward(
@@ -32,6 +32,6 @@ def forward(
     ) -> Float[torch.Tensor, "batch d_model"]:
         first_token_tensor = resid[:, 0]
         pooled_output = torch.matmul(first_token_tensor, self.W) + self.b
-        # pooled_output = self.hook_pooler_out(self.activation(pooled_output))
-        pooled_output = self.hook_pooler_out(pooled_output)
+        pooled_output = self.hook_pooler_out(self.activation(pooled_output))
+        # pooled_output = self.hook_pooler_out(pooled_output)
        return pooled_output
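
PATCH 82 restores the tanh in the pooler, so the pooled output is tanh of an affine map of the first position's residual vector. Written out standalone with toy shapes; the sizes below are illustrative only:

    import torch

    d_model = 8
    resid = torch.randn(2, 5, d_model)        # (batch, pos, d_model)
    W = torch.randn(d_model, d_model)
    b = torch.zeros(d_model)
    pooled = torch.tanh(resid[:, 0] @ W + b)  # (batch, d_model), matches the restored forward pass
    print(pooled.shape)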