From 2be123d8908556ac5dad58953827f9574ca70a2d Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Sun, 16 Nov 2025 11:00:09 -0800
Subject: [PATCH 01/11] Address device/dtype mismatches that caused failures in various contexts.

We also update .gitignore to exclude .env (a commonly used local file
exclusion), e.g. to allow collaborators to add their own HF_TOKEN for the
test suite.

Core Fixes:
-----------
transformer_lens/components/abstract_attention.py:
- Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
  where tensors are upcast to float32 for numerical stability while cfg.dtype
  remains float16/bfloat16
- Add explicit device/dtype synchronization for output projection:
  * Move weights (W_O) and bias (b_O) to match input device (z.device)
  * Ensure z matches weight dtype before final linear operation

transformer_lens/model_bridge/bridge.py:
- Replace direct original_model.to() call with move_to_and_update_config()
  utility to ensure:
  * All bridge components (not just original_model) are moved to target device
  * cfg.device and cfg.dtype stay synchronized with actual model state
  * Multi-GPU cache tensors remain on correct devices

Test Fixes:
-----------
tests/acceptance/test_hooked_encoder.py:
- Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'

tests/acceptance/test_multi_gpu.py:
- Update test_cache_device() to pass torch.device("cpu") instead of string
  "cpu" for proper device type validation

tests/unit/components/test_attention.py:
- Add test_attention_forward_half_precisions() to validate attention works
  correctly with bfloat16/float16 dtypes on CUDA devices

tests/unit/factored_matrix/test_multiply_by_scalar.py:
- Add test IDs to parametrize decorators to avoid pytest cache issues when
  random numbers appear in test names

Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]
---
 .gitignore                               |  2 +-
 .vscode/settings.json                    |  2 +-
 tests/acceptance/test_hooked_encoder.py  |  4 ++--
 tests/acceptance/test_multi_gpu.py       |  2 +-
 tests/unit/components/test_attention.py  | 19 +++++++++++++++++
 .../test_multiply_by_scalar.py           |  1 +
 .../components/abstract_attention.py     | 21 ++++++++++++-------
 transformer_lens/model_bridge/bridge.py  |  6 +++++-
 8 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/.gitignore b/.gitignore
index 32da76df8..be879b728 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,4 +22,4 @@ docs/source/generated
 # docs/source/_static/model_table
 **.orig
 .venv
-
+.env

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 63e6e310a..86d448657 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -33,7 +33,7 @@
     "notebook.formatOnSave.enabled": true,
     "pylint.importStrategy": "fromEnvironment",
     "python.testing.pytestArgs": [
-
"transformer_lens", + "tests" ], "python.testing.pytestEnabled": true, "rewrap.autoWrap.enabled": true, diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py index e62d35574..3afa69561 100644 --- a/tests/acceptance/test_hooked_encoder.py +++ b/tests/acceptance/test_hooked_encoder.py @@ -222,6 +222,6 @@ def test_input_list_of_strings_mlm(our_bert, huggingface_bert, tokenizer): @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device") -def test_cuda(mlm_tokens): +def test_cuda(tokens): model = HookedEncoder.from_pretrained(MODEL_NAME) - model(mlm_tokens) + model(tokens) diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py index 3af5eeeb2..ad407eb6e 100644 --- a/tests/acceptance/test_multi_gpu.py +++ b/tests/acceptance/test_multi_gpu.py @@ -111,7 +111,7 @@ def test_cache_device(): torch.device("cuda:1") ) - logits, cache = model.run_with_cache("Hello there", device="cpu") + logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu")) assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu")) model.to("cuda") diff --git a/tests/unit/components/test_attention.py b/tests/unit/components/test_attention.py index b386660c6..c473cc491 100644 --- a/tests/unit/components/test_attention.py +++ b/tests/unit/components/test_attention.py @@ -80,6 +80,25 @@ def test_attention_load_in_4bit(): assert torch.all(attn.b_V == 0) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for half/bfloat16 tests") +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_attention_forward_half_precisions(dtype): + # Construct a small attention block + cfg = HookedTransformerConfig( + d_model=64, d_head=16, n_heads=4, n_layers=1, n_ctx=8, dtype=dtype + ) + attn = Attention(cfg) + # Random inputs in the matching dtype + batch = 1 + seq = 4 + x = torch.rand((batch, seq, cfg.d_model), dtype=dtype).to("cuda") + # Run forward through attention (q,k,v = x) + out = attn(x, x, x) + # Should not raise and return a tensor on cuda with same dtype as cfg or compatible + assert isinstance(out, torch.Tensor) + assert out.device.type == "cuda" + + def test_attention_config_dict(): cfg = { "n_layers": 12, diff --git a/tests/unit/factored_matrix/test_multiply_by_scalar.py b/tests/unit/factored_matrix/test_multiply_by_scalar.py index 85d0bfbe7..d5fbf29ba 100644 --- a/tests/unit/factored_matrix/test_multiply_by_scalar.py +++ b/tests/unit/factored_matrix/test_multiply_by_scalar.py @@ -23,6 +23,7 @@ ), # Non-scalar Tensor. AssertionError expected. (torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected. 
], + ids=["tensor", "float", "int", "tensor_2d", "tensor_1d"], ) @pytest.mark.parametrize("leading_dim", [False, True]) @pytest.mark.parametrize("multiply_from_left", [False, True]) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 5f026f493..7894d9a84 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -280,8 +280,7 @@ def forward( raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}") pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] - pattern = pattern.to(self.cfg.dtype) - pattern = pattern.to(v.device) + pattern = pattern.to(device=v.device, dtype=v.dtype) z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] if not self.cfg.use_attn_result: if self.cfg.load_in_4bit: @@ -301,15 +300,21 @@ def forward( self.W_O, "head_index d_head d_model -> d_model (head_index d_head)" ) - if self.b_O.device != w.device: - w = w.to(self.b_O.device) - if self.b_O.device != z.device: - z = z.to(self.b_O.device) + # Move output projection weights and bias to the same device as z + # so that the final linear operation occurs on the device of the inputs + if w.device != z.device: + w = w.to(z.device) + b_O = self.b_O + if b_O.device != z.device: + b_O = b_O.to(z.device) + # Ensure z has the same dtype as weights used in the output projection + if z.dtype != w.dtype: + z = z.to(w.dtype) out = F.linear( z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), w, - self.b_O, + b_O, ) else: # Explicitly calculate the attention result so it can be accessed by a hook @@ -329,6 +334,8 @@ def forward( self.W_O, "head_index d_head d_model -> 1 1 head_index d_head d_model", ) + if w.device != z.device: + w = w.to(z.device) z = einops.rearrange( z, "batch pos head_index d_head -> batch pos head_index d_head 1" ) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index e1cde6c0e..3920f8402 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -6082,7 +6082,11 @@ def to(self, *args, **kwargs) -> "TransformerBridge": Returns: Self for chaining """ - self.original_model = self.original_model.to(*args, **kwargs) + # Use the shared utility which also updates `cfg` on device/dtype changes + from transformer_lens.utilities.devices import move_to_and_update_config + + # Move underlying model (and update config) instead of directly calling nn.Module.to + move_to_and_update_config(self, *args, **kwargs) return self def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge": From 84a2540e6c4e1e6db198c92ebfafaa19a6e13bae Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Sat, 22 Nov 2025 13:12:45 -0800 Subject: [PATCH 02/11] Add HuggingFace ModelOutput support to TransformerLens generate() API Enhance both HookedTransformer and TransformerBridge generate() methods to support HuggingFace-style generation outputs for improved interoperability: - HookedTransformer: Add output_logits flag to return ModelOutput with sequences and logits - TransformerBridge: Forward HF dict flags (output_scores, output_logits, output_attentions, output_hidden_states) to underlying HF model - Maintain full backward compatibility with existing generate() usage patterns - Add 17 integration tests covering ModelOutput behavior and flag 
handling - Use Any type for ModelOutput returns due to beartype forward reference limitations (#546) This enables downstream libraries to leverage HF's standard generation output format for advanced mechanistic interpretability workflows. --- .../test_generation_modeloutput.py | 375 ++++++++++++++++++ transformer_lens/HookedTransformer.py | 77 +++- transformer_lens/model_bridge/bridge.py | 71 +++- 3 files changed, 510 insertions(+), 13 deletions(-) create mode 100644 tests/integration/test_generation_modeloutput.py diff --git a/tests/integration/test_generation_modeloutput.py b/tests/integration/test_generation_modeloutput.py new file mode 100644 index 000000000..eb89465aa --- /dev/null +++ b/tests/integration/test_generation_modeloutput.py @@ -0,0 +1,375 @@ +"""Integration tests for generation API with ModelOutput support. + +This module tests the new generation API features that support HuggingFace-style +ModelOutput return. +""" + +import warnings + +import pytest +import torch + +from transformer_lens import HookedTransformer +from transformer_lens.model_bridge import TransformerBridge + + +@pytest.fixture(scope="module") +def gpt2_ht(): + """Load GPT-2 HookedTransformer once per module.""" + return HookedTransformer.from_pretrained("gpt2", device="cpu") + + +@pytest.fixture(scope="module") +def gpt2_bridge(): + """Load GPT-2 TransformerBridge once per module.""" + bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") + if bridge.tokenizer.pad_token is None: + bridge.tokenizer.pad_token = bridge.tokenizer.eos_token + return bridge + + +class TestHookedTransformerGenerationModelOutput: + """Tests for HookedTransformer generation with ModelOutput returns.""" + + def test_generate_with_output_logits_returns_modeloutput(self, gpt2_ht): + """Test that output_logits=True returns a ModelOutput with sequences and logits.""" + prompt = "The quick brown" + max_new_tokens = 5 + + result = gpt2_ht.generate( + prompt, + max_new_tokens=max_new_tokens, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Check that we got a ModelOutput-like object + assert hasattr(result, "sequences"), "Result should have sequences attribute" + assert hasattr(result, "logits"), "Result should have logits attribute" + + # Check sequences shape and type + assert isinstance(result.sequences, torch.Tensor), "sequences should be a tensor" + assert result.sequences.ndim == 2, "sequences should be 2D [batch, pos]" + + # Check logits structure and shape + assert isinstance(result.logits, tuple), "logits should be a tuple" + assert ( + len(result.logits) == max_new_tokens + ), f"logits tuple should have {max_new_tokens} elements" + + # Each logit tensor should be [batch, vocab] + for i, logit in enumerate(result.logits): + assert isinstance(logit, torch.Tensor), f"logits[{i}] should be a tensor" + assert logit.ndim == 2, f"logits[{i}] should be 2D [batch, vocab]" + assert ( + logit.shape[0] == result.sequences.shape[0] + ), f"logits[{i}] batch size should match sequences" + assert ( + logit.shape[1] == gpt2_ht.cfg.d_vocab + ), f"logits[{i}] vocab size should match model config" + + def test_generate_without_output_logits_returns_normal(self, gpt2_ht): + """Test that without output_logits flag, generation returns normal format.""" + prompt = "The quick brown" + + result = gpt2_ht.generate( + prompt, + max_new_tokens=5, + do_sample=False, + verbose=False, + ) + + # Should return a string (default return_type="input" with string input) + assert isinstance(result, str), "Result should be a string" + 
assert len(result) > len(prompt), "Generated text should be longer than prompt" + + def test_generate_output_logits_with_return_type_tokens(self, gpt2_ht): + """Test output_logits with return_type='tokens' returns ModelOutput with token sequences.""" + prompt = "Hello world" + max_new_tokens = 3 + + result = gpt2_ht.generate( + prompt, + max_new_tokens=max_new_tokens, + return_type="tokens", + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Check ModelOutput structure + assert hasattr(result, "sequences"), "Result should have sequences" + assert hasattr(result, "logits"), "Result should have logits" + + # Sequences should be tokens + assert isinstance(result.sequences, torch.Tensor), "sequences should be a tensor" + assert result.sequences.dtype in [ + torch.long, + torch.int, + torch.int64, + ], "sequences should be integer tokens" + + # Check logits + assert len(result.logits) == max_new_tokens, "logits should match max_new_tokens" + + def test_return_dict_in_generate_silently_ignored(self, gpt2_ht): + """Test that return_dict_in_generate is silently ignored without warnings.""" + prompt = "Test" + + # Should not raise any warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = gpt2_ht.generate( + prompt, + max_new_tokens=2, + verbose=False, + return_dict_in_generate=True, # Should be silently ignored + ) + + # Check no warnings were raised + assert len(w) == 0, "return_dict_in_generate should be silently ignored" + + # Result should still be normal (string) + assert isinstance(result, str), "Result should be a string" + + def test_unsupported_hf_flags_trigger_warning(self, gpt2_ht): + """Test that unsupported HF generation kwargs trigger UserWarning.""" + prompt = "Test" + + with pytest.warns(UserWarning, match="unsupported generation kwargs"): + result = gpt2_ht.generate( + prompt, + max_new_tokens=2, + verbose=False, + output_scores=True, # Unsupported flag + output_attentions=True, # Unsupported flag + ) + + # Result should still work (string) + assert isinstance(result, str), "Result should be a string despite warnings" + + def test_logits_consistency_with_forward_pass(self, gpt2_ht): + """Test that logits from generate match those from forward pass.""" + prompt = "Hello" + tokens = gpt2_ht.to_tokens(prompt) + + # Generate with output_logits + result = gpt2_ht.generate( + prompt, + max_new_tokens=1, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Get first generated token from sequences + first_new_token = result.sequences[0, -1] + + # Get logits for that token + first_logits = result.logits[0][0] + + # The argmax of logits should match the generated token (since do_sample=False) + assert first_logits.argmax() == first_new_token, "Greedy token should match logits argmax" + + def test_output_logits_batch_generation(self, gpt2_ht): + """Test output_logits works with batch inputs.""" + prompts = ["Hello", "World"] + max_new_tokens = 3 + + result = gpt2_ht.generate( + prompts, + max_new_tokens=max_new_tokens, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Check batch dimension + assert result.sequences.shape[0] == len( + prompts + ), "Batch dimension should match number of prompts" + + # Check logits batch dimension + for logit in result.logits: + assert logit.shape[0] == len(prompts), "Logits batch dimension should match prompts" + + +class TestTransformerBridgeGenerationModelOutput: + """Tests for TransformerBridge generation with HF-style flags.""" + + def 
test_generate_with_output_logits_forwards_to_hf(self, gpt2_bridge): + """Test that output_logits is forwarded to HF and returns ModelOutput.""" + prompt = "The quick brown" + + result = gpt2_bridge.generate( + prompt, + max_new_tokens=5, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # When using HF ModelOutput, result should either be a string (decoded) or ModelOutput + # depending on return_type. With return_type="input" and string input, we get string back + # But the underlying HF call should have received output_logits=True + assert isinstance(result, str), "Result should be decoded string with return_type='input'" + + def test_generate_with_output_scores_forwards_to_hf(self, gpt2_bridge): + """Test that output_scores is forwarded to HF model.""" + prompt = "Test" + + # output_scores should be forwarded without error + result = gpt2_bridge.generate( + prompt, + max_new_tokens=3, + do_sample=False, + verbose=False, + output_scores=True, + ) + + # Should return a string (default behavior with string input) + assert isinstance(result, str), "Result should be a string" + + def test_hf_dict_flags_set_return_dict_in_generate(self, gpt2_bridge): + """Test that hf_dict_flags automatically set return_dict_in_generate=True.""" + prompt = "Hello" + + # When we pass output_logits, return_dict_in_generate should be auto-set + # We can't directly inspect the HF call, but we can verify it doesn't error + result = gpt2_bridge.generate( + prompt, + max_new_tokens=2, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Should work without error + assert isinstance(result, str), "Result should be generated successfully" + + def test_multiple_hf_flags_simultaneously(self, gpt2_bridge): + """Test that multiple HF-style flags can be passed simultaneously.""" + prompt = "Test" + + result = gpt2_bridge.generate( + prompt, + max_new_tokens=2, + do_sample=False, + verbose=False, + output_logits=True, + output_attentions=True, + output_hidden_states=True, + ) + + # Should work and return a result + assert isinstance(result, str), "Result should be generated with multiple flags" + + def test_return_type_tokens_with_hf_flags(self, gpt2_bridge): + """Test return_type='tokens' works with HF flags.""" + prompt = "Hello" + + result = gpt2_bridge.generate( + prompt, + max_new_tokens=2, + return_type="tokens", + do_sample=False, + verbose=False, + output_logits=True, + ) + + # With return_type='tokens', we should get either tokens tensor or ModelOutput + # The implementation returns the raw HF output for tokens + assert result is not None, "Result should not be None" + + def test_hf_flags_coerced_to_bool(self, gpt2_bridge): + """Test that HF flags are properly coerced to boolean values.""" + prompt = "Test" + + # Pass non-boolean values that should be coerced to bool + result = gpt2_bridge.generate( + prompt, + max_new_tokens=2, + do_sample=False, + verbose=False, + output_logits=1, # Should be coerced to True + output_scores=0, # Should be coerced to False (but we pass explicitly so it's truthy) + ) + + # Should work without error + assert isinstance(result, str) or result is not None, "Result should be generated" + + def test_batch_generation_with_hf_flags(self, gpt2_bridge): + """Test batch generation works with HF-style flags.""" + prompts = ["Hello", "World"] + + result = gpt2_bridge.generate( + prompts, + max_new_tokens=2, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Should return list of strings for batch input + assert isinstance(result, list), "Batch 
input should return list" + assert len(result) == len(prompts), "Output list should match input length" + + +class TestGenerationBackwardCompatibility: + """Tests to ensure backward compatibility with existing generation usage.""" + + def test_hooked_transformer_basic_generation_unchanged(self, gpt2_ht): + """Test that basic generation without new flags works as before.""" + prompt = "Hello world" + + result = gpt2_ht.generate( + prompt, + max_new_tokens=5, + do_sample=False, + verbose=False, + ) + + assert isinstance(result, str), "Basic generation should return string" + assert len(result) > len(prompt), "Generated text should be longer" + + def test_bridge_basic_generation_unchanged(self, gpt2_bridge): + """Test that basic bridge generation without new flags works as before.""" + prompt = "Hello world" + + result = gpt2_bridge.generate( + prompt, + max_new_tokens=5, + do_sample=False, + verbose=False, + ) + + assert isinstance(result, str), "Basic generation should return string" + assert len(result) > len(prompt), "Generated text should be longer" + + def test_hooked_transformer_return_types_unchanged(self, gpt2_ht): + """Test that all return_type options still work.""" + prompt = "Test" + + # Test return_type='str' + result_str = gpt2_ht.generate( + prompt, max_new_tokens=2, return_type="str", verbose=False, do_sample=False + ) + assert isinstance(result_str, str), "return_type='str' should return string" + + # Test return_type='tokens' + result_tokens = gpt2_ht.generate( + prompt, max_new_tokens=2, return_type="tokens", verbose=False, do_sample=False + ) + assert isinstance(result_tokens, torch.Tensor), "return_type='tokens' should return tensor" + + # Test return_type='embeds' + result_embeds = gpt2_ht.generate( + prompt, max_new_tokens=2, return_type="embeds", verbose=False, do_sample=False + ) + assert isinstance(result_embeds, torch.Tensor), "return_type='embeds' should return tensor" + assert result_embeds.ndim == 3, "Embeddings should be 3D" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 4bcece7b1..2cccf1e4f 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -14,6 +14,7 @@ import logging import os from typing import ( + Any, Dict, List, NamedTuple, @@ -1840,11 +1841,15 @@ def generate( padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, return_type: Optional[str] = "input", verbose: bool = True, + **generation_kwargs, ) -> Union[ str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"], + Any, # transformers.utils.ModelOutput to accommodate output_logits=True. + # Using Any due to beartype's forward reference resolution limitations. + # See: https://github.com/beartype/beartype/issues/546 ]: """Sample Tokens from the Model. @@ -1943,6 +1948,34 @@ def generate( else: past_kv_cache = None + # We only support a single HF style generation kwarg: `output_logits` which will cause + # the function to return a ModelOutput-like object containing `sequences` and `logits`. + # Any other HF-style generation kwargs are rejected to avoid supporting the full HF API here. 
+ output_logits_flag = False + if generation_kwargs: + if "output_logits" in generation_kwargs: + output_logits_flag = bool(generation_kwargs.pop("output_logits")) + # Identify keys to warn about: anything other than allowed/silently ignored + accepted_keys = {"output_logits", "return_dict_in_generate"} + unsupported_keys = [k for k in generation_kwargs.keys() if k not in accepted_keys] + # silently ignore `return_dict_in_generate` + if "return_dict_in_generate" in generation_kwargs: + generation_kwargs.pop("return_dict_in_generate") + # If any unsupported keys remain, warn and ignore them + if unsupported_keys: + import warnings + + warnings.warn( + f"HookedTransformer.generate received unsupported generation kwargs; ignoring: {unsupported_keys}", + UserWarning, + ) + # Clear unsupported keys + for k in unsupported_keys: + generation_kwargs.pop(k, None) + + # Optionally collect logits at each generation step for downstream tooling/tests + logits_seq_list: Optional[List[torch.Tensor]] = [] if output_logits_flag else None + shortformer_pos_embed = None embeds = input if input_type == "embeds" else self.embed(input) @@ -2033,6 +2066,10 @@ def generate( ) final_logits = logits[:, -1, :] + if output_logits_flag: + assert logits_seq_list is not None + logits_seq_list.append(final_logits.unsqueeze(1)) + if do_sample: if input_type in [ "str", @@ -2089,11 +2126,45 @@ def generate( self.tokenizer.decode(tokens, skip_special_tokens=True) for tokens in output_tokens ] - return decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts + result = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts elif return_type == "tokens": - return output_tokens + result = output_tokens + else: + result = embeds + + if output_logits_flag: + # Adhere to HF ModelOutput format with sequences (tokens) and logits (per-step) + try: + from transformers.generation.utils import ( + GenerateDecoderOnlyOutput, # type: ignore + ) + except Exception: + from transformers.utils import ModelOutput # type: ignore + + # Use a ModelOutput-like object + logits_tensor = ( + torch.cat(logits_seq_list, dim=1) if logits_seq_list is not None else None + ) + logits_tuple = ( + tuple(logits_tensor[:, i, :] for i in range(logits_tensor.shape[1])) + if logits_tensor is not None + else None + ) + # `sequences` expects a tensor of token ids + return ModelOutput(sequences=output_tokens, logits=logits_tuple) # type: ignore[arg-type] + else: + assert logits_seq_list is not None + logits_tensor = torch.cat(logits_seq_list, dim=1) + # Convert to HF's expected output shape: a tuple of [batch, vocab] per step + logits_tuple = tuple( + logits_tensor[:, i, :] for i in range(logits_tensor.shape[1]) + ) + sequences = ( + output_tokens if isinstance(output_tokens, torch.Tensor) else output_tokens + ) + return GenerateDecoderOnlyOutput(sequences=sequences, logits=logits_tuple) else: - return embeds + return result # Give access to all weights as properties. 
@property diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 3920f8402..e30f2c91d 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -105,6 +105,8 @@ class TransformerBridge(nn.Module): "hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"], "hook_unembed": "unembed.hook_out", } + # HF-style flags that commonly request a dict/ModelOutput from HF `generate` + hf_dict_flags = ("output_scores", "output_logits", "output_attentions", "output_hidden_states") def __init__( self, @@ -6003,7 +6005,10 @@ def generate( padding_side: Optional[str] = None, return_type: Optional[str] = "input", verbose: bool = True, - ) -> Union[str, List[str], torch.Tensor]: + **generation_kwargs, + ) -> Union[str, List[str], torch.Tensor, Any]: # Any to support transformers.utils.ModelOutput + # Using Any due to beartype's forward reference resolution limitations. + # See: https://github.com/beartype/beartype/issues/546 """Generate text from the model using the underlying HuggingFace model.""" # Handle string input by tokenizing it if isinstance(input, str): @@ -6024,13 +6029,16 @@ def generate( if input_ids.device != self.cfg.device: input_ids = input_ids.to(self.cfg.device) - # Set up generation parameters for HuggingFace - generation_kwargs = { - "max_new_tokens": max_new_tokens, - "do_sample": do_sample, - "temperature": temperature, - "pad_token_id": self.tokenizer.eos_token_id, - } + # explicit args supplied will override values in generation_kwargs. + generation_kwargs = dict(generation_kwargs) if generation_kwargs is not None else {} + generation_kwargs.update( + { + "max_new_tokens": max_new_tokens, + "do_sample": do_sample, + "temperature": temperature, + "pad_token_id": self.tokenizer.eos_token_id, + } + ) if top_k is not None: generation_kwargs["top_k"] = top_k @@ -6044,17 +6052,45 @@ def generate( if use_past_kv_cache: generation_kwargs["use_cache"] = True + # If callers provide HF-style output flags, pass them through. 
+ any_flag_set = False + for f in type(self).hf_dict_flags: + if f in generation_kwargs and generation_kwargs.get(f) is not None: + # coerce value to bool if appropriate and pass through to HF generate + generation_kwargs[f] = bool(generation_kwargs[f]) + any_flag_set = True + if any_flag_set: + # Ensure HF returns a ModelOutput for these flags by default + generation_kwargs.setdefault("return_dict_in_generate", True) + # Generate using the original HuggingFace model with torch.no_grad(): outputs = self.original_model.generate(input_ids, **generation_kwargs) # type: ignore[operator] + try: + # We import ModelOutput lazily to avoid import cycles; this is a duck-typing check + from transformers.utils import ModelOutput # type: ignore + + is_model_output = isinstance(outputs, ModelOutput) + except Exception: + is_model_output = False + # Return based on return_type and input format if return_type == "input" or return_type is None: if isinstance(input, str): # Decode the full output back to string + # If we have a ModelOutput with sequences, decode those + if is_model_output and hasattr(outputs, "sequences"): + seqs = outputs.sequences + return self.tokenizer.decode(seqs[0], skip_special_tokens=True) return self.tokenizer.decode(outputs[0], skip_special_tokens=True) elif isinstance(input, list): # Decode each sequence in the batch + if is_model_output and hasattr(outputs, "sequences"): + return [ + self.tokenizer.decode(seq, skip_special_tokens=True) + for seq in outputs.sequences + ] return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] else: # Return the full token sequence including input @@ -6064,8 +6100,15 @@ def generate( else: # For other return types, default to the decoded text if isinstance(input, str): + if is_model_output and hasattr(outputs, "sequences"): + return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) return self.tokenizer.decode(outputs[0], skip_special_tokens=True) elif isinstance(input, list): + if is_model_output and hasattr(outputs, "sequences"): + return [ + self.tokenizer.decode(seq, skip_special_tokens=True) + for seq in outputs.sequences + ] return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] else: return outputs @@ -6263,10 +6306,18 @@ def get_caching_hooks( cache = {} if names_filter is None: - names_filter = lambda name: True + + def _names_filter_all(name: str) -> bool: + return True + + names_filter = _names_filter_all elif isinstance(names_filter, str): filter_str = names_filter - names_filter = lambda name: filter_str in name + + def _names_filter_contains(name: str) -> bool: + return filter_str in name + + names_filter = _names_filter_contains elif callable(names_filter): pass # Already a function else: From caf437ef87369b28dc197dbc0ab55b8b5a2ea421 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Wed, 26 Nov 2025 10:50:50 -0800 Subject: [PATCH 03/11] fix(BlockBridge): Return tuple for HF generation compatibility Fix batch dimension bug where BlockBridge.forward() returned bare tensors instead of tuples, causing HuggingFace generation to incorrectly index into batch dimensions. 
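As a minimal illustration of the failure mode (a standalone sketch with
made-up shapes, not code taken from HF or this repo):

    import torch

    hidden = torch.randn(2, 8, 768)         # [batch, seq, d_model]

    outputs = hidden                        # bare tensor return
    assert outputs[0].shape == (8, 768)     # indexes into the batch dimension

    outputs = (hidden,)                     # tuple return
    assert outputs[0].shape == (2, 8, 768)  # recovers the full hidden states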
Changes: - BlockBridge: Always return (first,) tuple for single-element outputs - Add regression test for batch dimension preservation during generation - Rename test_generation_modeloutput.py -> test_generation_compatibility.py The bug manifested when HF's GPT2Model did `hidden_states = outputs[0]` expecting a tuple but got a tensor, causing it to index the batch dimension instead of extracting the first tuple element. Fixes batch generation for TransformerBridge with multiple blocks. --- ...ut.py => test_generation_compatibility.py} | 81 ++++++++++++++++++- .../generalized_components/block.py | 4 +- 2 files changed, 81 insertions(+), 4 deletions(-) rename tests/integration/{test_generation_modeloutput.py => test_generation_compatibility.py} (80%) diff --git a/tests/integration/test_generation_modeloutput.py b/tests/integration/test_generation_compatibility.py similarity index 80% rename from tests/integration/test_generation_modeloutput.py rename to tests/integration/test_generation_compatibility.py index eb89465aa..3d7604cc4 100644 --- a/tests/integration/test_generation_modeloutput.py +++ b/tests/integration/test_generation_compatibility.py @@ -1,7 +1,7 @@ -"""Integration tests for generation API with ModelOutput support. +"""Integration tests for generation API compatibility. -This module tests the new generation API features that support HuggingFace-style -ModelOutput return. +This module tests generation API features including HuggingFace-style ModelOutput +support and TransformerBridge batch dimension compatibility. """ import warnings @@ -371,5 +371,80 @@ def test_hooked_transformer_return_types_unchanged(self, gpt2_ht): assert result_embeds.ndim == 3, "Embeddings should be 3D" +class TestBlockBridgeBatchCompatibility: + """Tests for BlockBridge tuple return format and batch dimension preservation.""" + + def test_block_bridge_batched_generation_compatibility(self, gpt2_bridge): + """Test BlockBridge maintains tuple format and batch dimensions during generation. + + This test exercises two critical aspects of improved HF compatibility: + 1. BlockBridge.forward() always returns tuples (not bare tensors) + 2. Batch dimensions are preserved through multi-block generation pipeline + """ + # Test 1: Direct block forward returns tuple with preserved batch dimension + batch_size = 2 + seq_len = 8 + hidden_dim = gpt2_bridge.cfg.d_model + hidden_states = torch.randn(batch_size, seq_len, hidden_dim) + + # Get first transformer block (BlockBridge component) + first_block = gpt2_bridge.original_model.transformer.h[0] + + # Call forward - this is what HF's GPT2Model does in its loop + block_output = first_block(hidden_states) + + # BlockBridge must return tuple + assert isinstance( + block_output, tuple + ), f"BlockBridge must return tuple for HF compatibility, got {type(block_output)}" + + # Verify first element is a tensor + assert isinstance( + block_output[0], torch.Tensor + ), "First element of BlockBridge output must be a tensor" + + # Batch dimension must be preserved + # Without tuple wrapping, outputs[0] idx op would turn [batch, seq, dim] -> [seq, dim] + assert block_output[0].shape == ( + batch_size, + seq_len, + hidden_dim, + ), f"Expected shape [{batch_size}, {seq_len}, {hidden_dim}], got {block_output[0].shape}" + + assert ( + block_output[0].shape[0] == batch_size + ), f"Batch dimension lost! 
Expected {batch_size}, got {block_output[0].shape[0]}" + + # Test 2: Batched generation works end-to-end through multiple blocks + prompts = ["Hello world", "Goodbye world"] + + # Tokenize with left padding + tokens = gpt2_bridge.to_tokens(prompts, prepend_bos=False, padding_side="left") + + # Generate tokens - this exercises the full HF generation loop with multiple blocks + output = gpt2_bridge.generate( + tokens, + max_new_tokens=4, + do_sample=False, # Deterministic for testing + use_past_kv_cache=True, + verbose=False, + ) + + # Verify output preserves batch dimension + assert output.shape[0] == len( + prompts + ), f"Batch size must be preserved through generation. Expected {len(prompts)}, got {output.shape[0]}" + + # Verify we actually generated new tokens + assert ( + output.shape[1] > tokens.shape[1] + ), "Generation should produce longer sequences than input" + + # Verify batch items remain independent (not collapsed into single item) + assert not torch.equal( + output[0], output[1] + ), "Batch items should be independent - different prompts should produce different outputs" + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index 61a4d49c1..75672b9a0 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -271,8 +271,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: first = output[0] if isinstance(first, torch.Tensor): first = self.hook_out(first) + # Always return tuple to maintain consistency with HF's expected format + # e.g. GPT2Model does hidden_states = outputs[0], it expects outputs to be a tuple if len(output) == 1: - return first + return (first,) output = (first,) + output[1:] return output From f2efb3752625aba082f8719506624ad5ede49a37 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Wed, 26 Nov 2025 14:55:21 -0800 Subject: [PATCH 04/11] Align TransformerBridge.to() with PyTorch nn.Module semantics Enhance to() method to properly handle both device and dtype arguments in all supported PyTorch formats (positional, keyword, combined). Separately invoke move_to_and_update_config for device/dtype to update cfg while delegating the actual tensor movement to original_model.to() with original args/kwargs. This ensures TransformerBridge respects standard PyTorch behavior for model.to() calls. --- transformer_lens/model_bridge/bridge.py | 40 +++++++++++++++++++++---- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index f83169da8..08dd2e001 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -39,6 +39,7 @@ ) from transformer_lens.model_bridge.get_params_util import get_bridge_params from transformer_lens.utilities.aliases import resolve_alias +from transformer_lens.utilities.devices import move_to_and_update_config if TYPE_CHECKING: from transformer_lens.ActivationCache import ActivationCache @@ -1754,7 +1755,7 @@ def generate( return output_tokens def to(self, *args, **kwargs) -> "TransformerBridge": - """Move model to device or change dtype. + """Move model to device and/or change dtype. 
Args: args: Positional arguments for nn.Module.to @@ -1763,11 +1764,38 @@ def to(self, *args, **kwargs) -> "TransformerBridge": Returns: Self for chaining """ - # Use the shared utility which also updates `cfg` on device/dtype changes - from transformer_lens.utilities.devices import move_to_and_update_config - - # Move underlying model (and update config) instead of directly calling nn.Module.to - move_to_and_update_config(self, *args, **kwargs) + # Extract print_details if provided + print_details = kwargs.pop("print_details", True) + + # Handle both device and dtype changes + # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype), + # to(device=...), to(dtype=...), to(device=..., dtype=...) + target_device, target_dtype = None, None + + if len(args) >= 1: + first_arg = args[0] + if isinstance(first_arg, (torch.device, str)): + target_device = first_arg + elif isinstance(first_arg, torch.dtype): + target_dtype = first_arg + if len(args) >= 2: + second_arg = args[1] + if isinstance(second_arg, torch.dtype): + target_dtype = second_arg + + # these override positional args + if "device" in kwargs: + target_device = kwargs["device"] + if "dtype" in kwargs: + target_dtype = kwargs["dtype"] + + if target_device is not None: + move_to_and_update_config(self, target_device, print_details) + if target_dtype is not None: + move_to_and_update_config(self, target_dtype, print_details) + + # Move the original model with all original args/kwargs (with print_details removed) + self.original_model = self.original_model.to(*args, **kwargs) return self def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge": From b9ea75ffcc200152827bf158006fbffab3380f48 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Wed, 26 Nov 2025 17:18:51 -0800 Subject: [PATCH 05/11] Add HuggingFace ModelOutput support to TransformerLens generation API Enhance HF compatibility of HookedTransformer and TransformerBridge Core changes: - HookedTransformer: Add output_logits flag returning ModelOutput with sequences and logits - TransformerBridge.generate(): Add output_logits flag for consistency with HookedTransformer - TransformerBridge.hf_generate(): New method for full HF API passthrough (output_scores, output_logits, output_attentions, output_hidden_states) - Maintain unified API: Both classes share same generate() signature per upstream design - hf_generate() for users needing full HF features, evaluate possibility of making it the default generate option in the future. Architecture: - Respects API consistency vision (unified generate() across both HookedTransformer/TransformerBridge classes) - Adds escape hatch for advanced HF use cases without compromising clean API - Clear separation: .generate() = TL-style, .hf_generate() = full HF Testing: - Comprehensive test suite (20 tests) covering ModelOutput behavior and flag handling - Full backward compatibility maintained with existing generate() usage This enables downstream libraries to leverage HF's standard generation output format for advanced workflows while maintaining TransformerLens's clean, consistent API. 
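Usage sketch (illustrative only; it mirrors the new integration tests, and the
model/prompt choices are arbitrary):

    from transformer_lens.model_bridge import TransformerBridge

    bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")

    # TL-style generate() with the new output_logits flag
    out = bridge.generate(
        "The quick brown", max_new_tokens=5,
        do_sample=False, verbose=False, output_logits=True,
    )
    print(out.sequences.shape)  # [batch, pos] token ids
    print(len(out.logits))      # one [batch, d_vocab] tensor per generated step

    # Full HF passthrough, including flags generate() does not expose
    text = bridge.hf_generate(
        "The quick brown", max_new_tokens=5,
        do_sample=False, output_scores=True,
    )

Both calls keep the existing string-in/string-out behaviour when no output
flags are passed.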
--- .../test_generation_compatibility.py | 112 ++++++--- transformer_lens/model_bridge/bridge.py | 218 +++++++++++++++++- 2 files changed, 296 insertions(+), 34 deletions(-) diff --git a/tests/integration/test_generation_compatibility.py b/tests/integration/test_generation_compatibility.py index 3d7604cc4..50292e3f6 100644 --- a/tests/integration/test_generation_compatibility.py +++ b/tests/integration/test_generation_compatibility.py @@ -196,10 +196,48 @@ def test_output_logits_batch_generation(self, gpt2_ht): class TestTransformerBridgeGenerationModelOutput: - """Tests for TransformerBridge generation with HF-style flags.""" + """Tests for TransformerBridge generation with ModelOutput returns.""" - def test_generate_with_output_logits_forwards_to_hf(self, gpt2_bridge): - """Test that output_logits is forwarded to HF and returns ModelOutput.""" + def test_generate_with_output_logits_returns_modeloutput(self, gpt2_bridge): + """Test that output_logits=True returns a ModelOutput with sequences and logits.""" + prompt = "The quick brown" + max_new_tokens = 5 + + result = gpt2_bridge.generate( + prompt, + max_new_tokens=max_new_tokens, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Check that we got a ModelOutput-like object + assert hasattr(result, "sequences"), "Result should have sequences attribute" + assert hasattr(result, "logits"), "Result should have logits attribute" + + # Check sequences shape and type + assert isinstance(result.sequences, torch.Tensor), "sequences should be a tensor" + assert result.sequences.ndim == 2, "sequences should be 2D [batch, pos]" + + # Check logits structure and shape + assert isinstance(result.logits, tuple), "logits should be a tuple" + assert ( + len(result.logits) == max_new_tokens + ), f"logits tuple should have {max_new_tokens} elements" + + # Each logit tensor should be [batch, vocab] + for i, logit in enumerate(result.logits): + assert isinstance(logit, torch.Tensor), f"logits[{i}] should be a tensor" + assert logit.ndim == 2, f"logits[{i}] should be 2D [batch, vocab]" + assert ( + logit.shape[0] == result.sequences.shape[0] + ), f"logits[{i}] batch size should match sequences" + assert ( + logit.shape[1] == gpt2_bridge.cfg.d_vocab + ), f"logits[{i}] vocab size should match model config" + + def test_generate_without_output_logits_returns_normal(self, gpt2_bridge): + """Test that without output_logits flag, generation returns normal format.""" prompt = "The quick brown" result = gpt2_bridge.generate( @@ -207,56 +245,81 @@ def test_generate_with_output_logits_forwards_to_hf(self, gpt2_bridge): max_new_tokens=5, do_sample=False, verbose=False, + ) + + # Should return a string (default return_type="input" with string input) + assert isinstance(result, str), "Result should be a string" + assert len(result) > len(prompt), "Generated text should be longer than prompt" + + def test_generate_output_logits_batch(self, gpt2_bridge): + """Test output_logits works with batch inputs.""" + prompts = ["Hello", "World"] + max_new_tokens = 3 + + result = gpt2_bridge.generate( + prompts, + max_new_tokens=max_new_tokens, + do_sample=False, + verbose=False, output_logits=True, ) - # When using HF ModelOutput, result should either be a string (decoded) or ModelOutput - # depending on return_type. 
With return_type="input" and string input, we get string back - # But the underlying HF call should have received output_logits=True - assert isinstance(result, str), "Result should be decoded string with return_type='input'" + # Check ModelOutput structure + assert hasattr(result, "sequences"), "Result should have sequences" + assert hasattr(result, "logits"), "Result should have logits" - def test_generate_with_output_scores_forwards_to_hf(self, gpt2_bridge): + # Check batch dimension + assert result.sequences.shape[0] == len( + prompts + ), "Batch dimension should match number of prompts" + + # Check logits batch dimension + for logit in result.logits: + assert logit.shape[0] == len(prompts), "Logits batch dimension should match prompts" + + +class TestTransformerBridgeHFGenerate: + """Tests for TransformerBridge.hf_generate() with full HF API support.""" + + def test_hf_generate_with_output_scores(self, gpt2_bridge): """Test that output_scores is forwarded to HF model.""" prompt = "Test" # output_scores should be forwarded without error - result = gpt2_bridge.generate( + result = gpt2_bridge.hf_generate( prompt, max_new_tokens=3, do_sample=False, - verbose=False, output_scores=True, ) # Should return a string (default behavior with string input) assert isinstance(result, str), "Result should be a string" - def test_hf_dict_flags_set_return_dict_in_generate(self, gpt2_bridge): + def test_hf_generate_sets_return_dict_in_generate(self, gpt2_bridge): """Test that hf_dict_flags automatically set return_dict_in_generate=True.""" prompt = "Hello" # When we pass output_logits, return_dict_in_generate should be auto-set # We can't directly inspect the HF call, but we can verify it doesn't error - result = gpt2_bridge.generate( + result = gpt2_bridge.hf_generate( prompt, max_new_tokens=2, do_sample=False, - verbose=False, output_logits=True, ) # Should work without error assert isinstance(result, str), "Result should be generated successfully" - def test_multiple_hf_flags_simultaneously(self, gpt2_bridge): + def test_hf_generate_multiple_flags_simultaneously(self, gpt2_bridge): """Test that multiple HF-style flags can be passed simultaneously.""" prompt = "Test" - result = gpt2_bridge.generate( + result = gpt2_bridge.hf_generate( prompt, max_new_tokens=2, do_sample=False, - verbose=False, output_logits=True, output_attentions=True, output_hidden_states=True, @@ -265,16 +328,15 @@ def test_multiple_hf_flags_simultaneously(self, gpt2_bridge): # Should work and return a result assert isinstance(result, str), "Result should be generated with multiple flags" - def test_return_type_tokens_with_hf_flags(self, gpt2_bridge): + def test_hf_generate_return_type_tokens(self, gpt2_bridge): """Test return_type='tokens' works with HF flags.""" prompt = "Hello" - result = gpt2_bridge.generate( + result = gpt2_bridge.hf_generate( prompt, max_new_tokens=2, return_type="tokens", do_sample=False, - verbose=False, output_logits=True, ) @@ -282,32 +344,30 @@ def test_return_type_tokens_with_hf_flags(self, gpt2_bridge): # The implementation returns the raw HF output for tokens assert result is not None, "Result should not be None" - def test_hf_flags_coerced_to_bool(self, gpt2_bridge): + def test_hf_generate_flags_coerced_to_bool(self, gpt2_bridge): """Test that HF flags are properly coerced to boolean values.""" prompt = "Test" # Pass non-boolean values that should be coerced to bool - result = gpt2_bridge.generate( + result = gpt2_bridge.hf_generate( prompt, max_new_tokens=2, do_sample=False, - verbose=False, 
output_logits=1, # Should be coerced to True - output_scores=0, # Should be coerced to False (but we pass explicitly so it's truthy) + output_scores=0, # Should be coerced to False but still triggers flag ) # Should work without error assert isinstance(result, str) or result is not None, "Result should be generated" - def test_batch_generation_with_hf_flags(self, gpt2_bridge): + def test_hf_generate_batch_generation(self, gpt2_bridge): """Test batch generation works with HF-style flags.""" prompts = ["Hello", "World"] - result = gpt2_bridge.generate( + result = gpt2_bridge.hf_generate( prompts, max_new_tokens=2, do_sample=False, - verbose=False, output_logits=True, ) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 08dd2e001..7102b5059 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1622,7 +1622,10 @@ def generate( padding_side: Optional[str] = None, return_type: Optional[str] = "input", verbose: bool = True, - ) -> Union[str, List[str], torch.Tensor]: + output_logits: bool = False, + ) -> str | list[str] | torch.Tensor | Any: # Any for transformers.utils.ModelOutput + # Using Any due to beartype's forward reference resolution limitations. + # See: https://github.com/beartype/beartype/issues/546 """Sample tokens from the model. Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. @@ -1643,9 +1646,11 @@ def generate( padding_side: Not used in Bridge (kept for API compatibility) return_type: The type of output to return - 'input', 'str', or 'tokens' verbose: Not used in Bridge (kept for API compatibility) + output_logits: If True, return a ModelOutput with sequences and logits tuple Returns: - Generated sequence as string, list of strings, or tensor depending on input type and return_type + Generated sequence as string, list of strings, or tensor depending on input type and return_type. + If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes. 
""" # Convert input to tokens if isinstance(input, str): @@ -1695,6 +1700,9 @@ def generate( # Track which sequences have finished finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) + # Optionally collect logits at each generation step for downstream tooling/tests + logits_seq_list: list[torch.Tensor] | None = [] if output_logits else None + # Generate tokens current_tokens = input_tokens.clone() sampled_tokens_list = [] @@ -1705,6 +1713,10 @@ def generate( logits = self(current_tokens, return_type="logits") final_logits = logits[:, -1, :] + # Collect logits if requested + if logits_seq_list is not None: + logits_seq_list.append(final_logits.clone()) + # Sample next token if do_sample: sampled_tokens = utils.sample_logits( @@ -1741,6 +1753,27 @@ def generate( sampled_tokens = torch.cat(sampled_tokens_list, dim=1) output_tokens = torch.cat([input_tokens, sampled_tokens], dim=1) + # Return ModelOutput if output_logits was requested + if output_logits and logits_seq_list is not None: + try: + from transformers.generation.utils import GenerateDecoderOnlyOutput + from transformers.utils import ModelOutput # type: ignore + + # Return a HF-compatible ModelOutput structure + # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional) + return GenerateDecoderOnlyOutput( + sequences=output_tokens, + logits=tuple(logits_seq_list), + ) + except ImportError: + # Fallback if HF not available or old version + from transformers.utils import ModelOutput + + return ModelOutput( + sequences=output_tokens, + logits=tuple(logits_seq_list), + ) + # Format output if return_type == "str": if input_type == "str": @@ -1754,6 +1787,175 @@ def generate( else: # return_type == "tokens" return output_tokens + def hf_generate( + self, + input: str | list[str] | torch.Tensor = "", + max_new_tokens: int = 10, + stop_at_eos: bool = True, + eos_token_id: int | None = None, + do_sample: bool = True, + top_k: int | None = None, + top_p: float | None = None, + temperature: float = 1.0, + use_past_kv_cache: bool = True, + return_type: str | None = "input", + **generation_kwargs, + ) -> str | list[str] | torch.Tensor | Any: # Any for HF ModelOutput types + # Using Any due to beartype's forward reference resolution limitations. + # See: https://github.com/beartype/beartype/issues/546 + """Generate text using the underlying HuggingFace model with full HF API support. + + This method provides direct access to HuggingFace's generation API, forwarding all + generation parameters (including output_scores, output_logits, output_attentions, + output_hidden_states) directly to the underlying HF model. Use this when you need + full HuggingFace generation features not supported by the standard generate() method. + + For standard generation compatible with HookedTransformer, use generate() instead. 
+ + Args: + input: Text string, list of strings, or tensor of tokens + max_new_tokens: Maximum number of tokens to generate + stop_at_eos: If True, stop generating tokens when the model outputs eos_token + eos_token_id: The token ID to use for end of sentence + do_sample: If True, sample from the model's output distribution + top_k: Number of tokens to sample from + top_p: Probability mass to sample from + temperature: Temperature for sampling + use_past_kv_cache: If True, use KV caching for faster generation + return_type: The type of output to return - 'input', 'str', or 'tokens' + **generation_kwargs: Additional HuggingFace generation parameters including: + - output_scores: Return generation scores + - output_logits: Return generation logits + - output_attentions: Return attention weights + - output_hidden_states: Return hidden states + - return_dict_in_generate: Return ModelOutput object + - And any other HF generation parameters + + Returns: + Generated sequence as string, list of strings, tensor, or HF ModelOutput + depending on input type, return_type, and generation_kwargs. + + Example: + >>> # Get full HF ModelOutput with logits and attentions + >>> result = model.hf_generate( + ... "Hello world", + ... max_new_tokens=5, + ... output_logits=True, + ... output_attentions=True, + ... return_dict_in_generate=True + ... ) + >>> print(result.sequences) # Generated tokens + >>> print(result.logits) # Logits for each generation step + >>> print(result.attentions) # Attention weights + """ + # Handle string input by tokenizing it + if isinstance(input, str): + inputs = self.tokenizer(input, return_tensors="pt", padding=False, truncation=False).to( + self.cfg.device + ) + input_ids = inputs["input_ids"] + input_type = "str" + elif isinstance(input, list): + inputs = self.tokenizer(input, return_tensors="pt", padding=True, truncation=False).to( + self.cfg.device + ) + input_ids = inputs["input_ids"] + input_type = "list" + else: + input_ids = input + if input_ids.device != self.cfg.device: + input_ids = input_ids.to(self.cfg.device) + input_type = "tokens" + + # Build generation_kwargs from explicit args and kwargs + generation_kwargs = dict(generation_kwargs) if generation_kwargs is not None else {} + generation_kwargs.update( + { + "max_new_tokens": max_new_tokens, + "do_sample": do_sample, + "temperature": temperature, + "pad_token_id": self.tokenizer.eos_token_id, + } + ) + + if top_k is not None: + generation_kwargs["top_k"] = top_k + if top_p is not None: + generation_kwargs["top_p"] = top_p + if eos_token_id is not None: + generation_kwargs["eos_token_id"] = eos_token_id + elif stop_at_eos and self.tokenizer.eos_token_id is not None: + generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id + + if use_past_kv_cache: + generation_kwargs["use_cache"] = True + + # HF dict flags that trigger ModelOutput returns + hf_dict_flags = ( + "output_scores", + "output_logits", + "output_attentions", + "output_hidden_states", + ) + + # If any HF-style output flags are provided, ensure return_dict_in_generate is set + any_flag_set = False + for f in hf_dict_flags: + if f in generation_kwargs and generation_kwargs.get(f) is not None: + generation_kwargs[f] = bool(generation_kwargs[f]) + any_flag_set = True + + if any_flag_set: + generation_kwargs.setdefault("return_dict_in_generate", True) + + # Generate using the original HuggingFace model + with torch.no_grad(): + outputs = self.original_model.generate(input_ids, **generation_kwargs) # type: ignore[operator] + + # Check if output is a 
ModelOutput + try: + from transformers.utils import ModelOutput # type: ignore + + is_model_output = isinstance(outputs, ModelOutput) + except Exception: + is_model_output = False + + # Return based on return_type and input format + if return_type == "input" or return_type is None: + if input_type == "str": + # Decode the full output back to string + if is_model_output and hasattr(outputs, "sequences"): + return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) + return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + elif input_type == "list": + # Decode each sequence in the batch + if is_model_output and hasattr(outputs, "sequences"): + return [ + self.tokenizer.decode(seq, skip_special_tokens=True) + for seq in outputs.sequences + ] + return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] + else: + # Return the full token sequence including input + return outputs + elif return_type == "tokens": + return outputs + else: + # For other return types, default to the decoded text + if input_type == "str": + if is_model_output and hasattr(outputs, "sequences"): + return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) + return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + elif input_type == "list": + if is_model_output and hasattr(outputs, "sequences"): + return [ + self.tokenizer.decode(seq, skip_special_tokens=True) + for seq in outputs.sequences + ] + return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] + else: + return outputs + def to(self, *args, **kwargs) -> "TransformerBridge": """Move model to device and/or change dtype. @@ -1766,12 +1968,12 @@ def to(self, *args, **kwargs) -> "TransformerBridge": """ # Extract print_details if provided print_details = kwargs.pop("print_details", True) - + # Handle both device and dtype changes - # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype), + # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype), # to(device=...), to(dtype=...), to(device=..., dtype=...) 
target_device, target_dtype = None, None - + if len(args) >= 1: first_arg = args[0] if isinstance(first_arg, (torch.device, str)): @@ -1782,18 +1984,18 @@ def to(self, *args, **kwargs) -> "TransformerBridge": second_arg = args[1] if isinstance(second_arg, torch.dtype): target_dtype = second_arg - + # these override positional args if "device" in kwargs: target_device = kwargs["device"] if "dtype" in kwargs: target_dtype = kwargs["dtype"] - + if target_device is not None: move_to_and_update_config(self, target_device, print_details) if target_dtype is not None: move_to_and_update_config(self, target_dtype, print_details) - + # Move the original model with all original args/kwargs (with print_details removed) self.original_model = self.original_model.to(*args, **kwargs) return self From b4660f88a4a5d1c79b0e9c0e2c7e8e87098861d1 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Sat, 29 Nov 2025 13:16:27 -0800 Subject: [PATCH 06/11] minor formatting and type fix --- transformer_lens/components/abstract_attention.py | 3 ++- transformer_lens/model_bridge/bridge.py | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 7894d9a84..a0db051a3 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor from better_abc import abstract_attribute from jaxtyping import Float, Int from transformers.utils.import_utils import is_bitsandbytes_available @@ -304,7 +305,7 @@ def forward( # so that the final linear operation occurs on the device of the inputs if w.device != z.device: w = w.to(z.device) - b_O = self.b_O + b_O: Tensor = self.b_O if b_O.device != z.device: b_O = b_O.to(z.device) # Ensure z has the same dtype as weights used in the output projection diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 08dd2e001..81fbfe406 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1766,12 +1766,12 @@ def to(self, *args, **kwargs) -> "TransformerBridge": """ # Extract print_details if provided print_details = kwargs.pop("print_details", True) - + # Handle both device and dtype changes - # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype), + # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype), # to(device=...), to(dtype=...), to(device=..., dtype=...) 
target_device, target_dtype = None, None - + if len(args) >= 1: first_arg = args[0] if isinstance(first_arg, (torch.device, str)): @@ -1782,18 +1782,18 @@ def to(self, *args, **kwargs) -> "TransformerBridge": second_arg = args[1] if isinstance(second_arg, torch.dtype): target_dtype = second_arg - + # these override positional args if "device" in kwargs: target_device = kwargs["device"] if "dtype" in kwargs: target_dtype = kwargs["dtype"] - + if target_device is not None: move_to_and_update_config(self, target_device, print_details) if target_dtype is not None: move_to_and_update_config(self, target_dtype, print_details) - + # Move the original model with all original args/kwargs (with print_details removed) self.original_model = self.original_model.to(*args, **kwargs) return self From ab280f18fc15b4f9e768776983d509a006ee23af Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Sat, 29 Nov 2025 13:21:56 -0800 Subject: [PATCH 07/11] rerun isort fix --- transformer_lens/components/abstract_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index a0db051a3..02400f89f 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -6,9 +6,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor from better_abc import abstract_attribute from jaxtyping import Float, Int +from torch import Tensor from transformers.utils.import_utils import is_bitsandbytes_available from transformer_lens.cache.key_value_cache_entry import ( From 0ec75eddcceda01836f0ea1ce134f789fc82467c Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Sat, 29 Nov 2025 14:26:43 -0800 Subject: [PATCH 08/11] type fixes and docstring test fix --- transformer_lens/HookedTransformer.py | 12 ++++++--- transformer_lens/model_bridge/bridge.py | 33 ++++++++++++++----------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 2cccf1e4f..9000a02cd 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -2126,11 +2126,11 @@ def generate( self.tokenizer.decode(tokens, skip_special_tokens=True) for tokens in output_tokens ] - result = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts + result: Any = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts elif return_type == "tokens": - result = output_tokens + result = cast(Any, output_tokens) else: - result = embeds + result = cast(Any, embeds) if output_logits_flag: # Adhere to HF ModelOutput format with sequences (tokens) and logits (per-step) @@ -2162,7 +2162,11 @@ def generate( sequences = ( output_tokens if isinstance(output_tokens, torch.Tensor) else output_tokens ) - return GenerateDecoderOnlyOutput(sequences=sequences, logits=logits_tuple) + return GenerateDecoderOnlyOutput( + sequences=cast(torch.LongTensor, sequences), + # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...] 
+ logits=logits_tuple, # type: ignore[arg-type] + ) else: return result diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 7102b5059..a267a35de 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1762,8 +1762,10 @@ def generate( # Return a HF-compatible ModelOutput structure # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional) return GenerateDecoderOnlyOutput( - sequences=output_tokens, - logits=tuple(logits_seq_list), + sequences=cast(torch.LongTensor, output_tokens), + # HF's type hint says tuple[FloatTensor] but should be tuple[FloatTensor, ...] + # (variable-length tuple with one element per generated token) + logits=tuple(logits_seq_list), # type: ignore[arg-type] ) except ImportError: # Fallback if HF not available or old version @@ -1835,18 +1837,21 @@ def hf_generate( Generated sequence as string, list of strings, tensor, or HF ModelOutput depending on input type, return_type, and generation_kwargs. - Example: - >>> # Get full HF ModelOutput with logits and attentions - >>> result = model.hf_generate( - ... "Hello world", - ... max_new_tokens=5, - ... output_logits=True, - ... output_attentions=True, - ... return_dict_in_generate=True - ... ) - >>> print(result.sequences) # Generated tokens - >>> print(result.logits) # Logits for each generation step - >>> print(result.attentions) # Attention weights + Example:: + + # Get full HF ModelOutput with logits and attentions + from transformer_lens import HookedTransformer + model = HookedTransformer.from_pretrained("tiny-stories-1M") + result = model.hf_generate( + "Hello world", + max_new_tokens=5, + output_logits=True, + output_attentions=True, + return_dict_in_generate=True + ) + print(result.sequences) # Generated tokens + print(result.logits) # Logits for each generation step + print(result.attentions) # Attention weights """ # Handle string input by tokenizing it if isinstance(input, str): From 9d4e643589f1c527e470e5be2395dd6064fbfcd7 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Sat, 29 Nov 2025 15:10:24 -0800 Subject: [PATCH 09/11] minor sync enhancement --- transformer_lens/components/abstract_attention.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 02400f89f..0d144a741 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -337,6 +337,9 @@ def forward( ) if w.device != z.device: w = w.to(z.device) + # Ensure z has the same dtype as w before multiplication + if z.dtype != w.dtype: + z = z.to(w.dtype) z = einops.rearrange( z, "batch pos head_index d_head -> batch pos head_index d_head 1" ) From 264237bfb71f50e085228d66211ede55a7d18b18 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Sat, 29 Nov 2025 15:48:57 -0800 Subject: [PATCH 10/11] apply minor fixes triggered by copilot review --- .../test_generation_compatibility.py | 3 +- transformer_lens/HookedTransformer.py | 31 +++++++------------ transformer_lens/model_bridge/bridge.py | 12 +++---- 3 files changed, 18 insertions(+), 28 deletions(-) diff --git a/tests/integration/test_generation_compatibility.py b/tests/integration/test_generation_compatibility.py index 50292e3f6..5af4af9a7 100644 --- a/tests/integration/test_generation_compatibility.py +++ b/tests/integration/test_generation_compatibility.py @@ -152,7 +152,6 @@ def 
test_unsupported_hf_flags_trigger_warning(self, gpt2_ht): def test_logits_consistency_with_forward_pass(self, gpt2_ht): """Test that logits from generate match those from forward pass.""" prompt = "Hello" - tokens = gpt2_ht.to_tokens(prompt) # Generate with output_logits result = gpt2_ht.generate( @@ -354,7 +353,7 @@ def test_hf_generate_flags_coerced_to_bool(self, gpt2_bridge): max_new_tokens=2, do_sample=False, output_logits=1, # Should be coerced to True - output_scores=0, # Should be coerced to False but still triggers flag + output_scores=0, # 0 is not None, so flag is set (coerces to False) ) # Should work without error diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 9000a02cd..7cc1babb1 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -2068,7 +2068,7 @@ def generate( if output_logits_flag: assert logits_seq_list is not None - logits_seq_list.append(final_logits.unsqueeze(1)) + logits_seq_list.append(final_logits.clone()) if do_sample: if input_type in [ @@ -2134,31 +2134,16 @@ def generate( if output_logits_flag: # Adhere to HF ModelOutput format with sequences (tokens) and logits (per-step) + from transformers.utils import ModelOutput # type: ignore + try: from transformers.generation.utils import ( GenerateDecoderOnlyOutput, # type: ignore ) - except Exception: - from transformers.utils import ModelOutput # type: ignore - # Use a ModelOutput-like object - logits_tensor = ( - torch.cat(logits_seq_list, dim=1) if logits_seq_list is not None else None - ) - logits_tuple = ( - tuple(logits_tensor[:, i, :] for i in range(logits_tensor.shape[1])) - if logits_tensor is not None - else None - ) - # `sequences` expects a tensor of token ids - return ModelOutput(sequences=output_tokens, logits=logits_tuple) # type: ignore[arg-type] - else: assert logits_seq_list is not None - logits_tensor = torch.cat(logits_seq_list, dim=1) - # Convert to HF's expected output shape: a tuple of [batch, vocab] per step - logits_tuple = tuple( - logits_tensor[:, i, :] for i in range(logits_tensor.shape[1]) - ) + # Convert list of [batch, vocab] tensors to tuple + logits_tuple = tuple(logits_seq_list) sequences = ( output_tokens if isinstance(output_tokens, torch.Tensor) else output_tokens ) @@ -2167,6 +2152,12 @@ def generate( # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...] 
logits=logits_tuple, # type: ignore[arg-type] ) + except (ImportError, AttributeError): + # Fallback if GenerateDecoderOnlyOutput not available in this transformers version + assert logits_seq_list is not None + logits_tuple = tuple(logits_seq_list) + # `sequences` expects a tensor of token ids + return ModelOutput(sequences=output_tokens, logits=logits_tuple) # type: ignore[arg-type] else: return result diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index a267a35de..e58526caa 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1755,9 +1755,10 @@ def generate( # Return ModelOutput if output_logits was requested if output_logits and logits_seq_list is not None: + from transformers.utils import ModelOutput # type: ignore + try: from transformers.generation.utils import GenerateDecoderOnlyOutput - from transformers.utils import ModelOutput # type: ignore # Return a HF-compatible ModelOutput structure # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional) @@ -1767,10 +1768,8 @@ def generate( # (variable-length tuple with one element per generated token) logits=tuple(logits_seq_list), # type: ignore[arg-type] ) - except ImportError: - # Fallback if HF not available or old version - from transformers.utils import ModelOutput - + except (ImportError, AttributeError): + # Fallback if GenerateDecoderOnlyOutput not available in this transformers version return ModelOutput( sequences=output_tokens, logits=tuple(logits_seq_list), @@ -1906,7 +1905,7 @@ def hf_generate( # If any HF-style output flags are provided, ensure return_dict_in_generate is set any_flag_set = False for f in hf_dict_flags: - if f in generation_kwargs and generation_kwargs.get(f) is not None: + if generation_kwargs.get(f) is not None: generation_kwargs[f] = bool(generation_kwargs[f]) any_flag_set = True @@ -1967,6 +1966,7 @@ def to(self, *args, **kwargs) -> "TransformerBridge": Args: args: Positional arguments for nn.Module.to kwargs: Keyword arguments for nn.Module.to + print_details: Whether to print details about device/dtype changes (default: True) Returns: Self for chaining From d7b5cd9b9dcca6c9c29a1025d949f05903fde53b Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Sat, 29 Nov 2025 16:24:58 -0800 Subject: [PATCH 11/11] minor cleanup of generate() methods for HookedTransformer/TransformerBridge --- transformer_lens/HookedTransformer.py | 23 +++++++++-------------- transformer_lens/model_bridge/bridge.py | 9 +++++++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 7cc1babb1..0b8181a91 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -2136,28 +2136,23 @@ def generate( # Adhere to HF ModelOutput format with sequences (tokens) and logits (per-step) from transformers.utils import ModelOutput # type: ignore + def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + assert logits_list is not None + # Convert list of [batch, vocab] tensors to tuple + return tuple(logits_list) + try: - from transformers.generation.utils import ( - GenerateDecoderOnlyOutput, # type: ignore - ) + from transformers.generation.utils import GenerateDecoderOnlyOutput - assert logits_seq_list is not None - # Convert list of [batch, vocab] tensors to tuple - logits_tuple = tuple(logits_seq_list) - sequences = ( - output_tokens if isinstance(output_tokens, 
torch.Tensor) else output_tokens - ) return GenerateDecoderOnlyOutput( - sequences=cast(torch.LongTensor, sequences), + sequences=cast(torch.LongTensor, output_tokens), # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...] - logits=logits_tuple, # type: ignore[arg-type] + logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type] ) except (ImportError, AttributeError): # Fallback if GenerateDecoderOnlyOutput not available in this transformers version - assert logits_seq_list is not None - logits_tuple = tuple(logits_seq_list) # `sequences` expects a tensor of token ids - return ModelOutput(sequences=output_tokens, logits=logits_tuple) # type: ignore[arg-type] + return ModelOutput(sequences=output_tokens, logits=_logits_to_tuple(logits_seq_list)) # type: ignore[arg-type] else: return result diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index e58526caa..ae44f294d 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1757,6 +1757,11 @@ def generate( if output_logits and logits_seq_list is not None: from transformers.utils import ModelOutput # type: ignore + def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + assert logits_list is not None + # Convert list of [batch, vocab] tensors to tuple + return tuple(logits_list) + try: from transformers.generation.utils import GenerateDecoderOnlyOutput @@ -1766,13 +1771,13 @@ def generate( sequences=cast(torch.LongTensor, output_tokens), # HF's type hint says tuple[FloatTensor] but should be tuple[FloatTensor, ...] # (variable-length tuple with one element per generated token) - logits=tuple(logits_seq_list), # type: ignore[arg-type] + logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type] ) except (ImportError, AttributeError): # Fallback if GenerateDecoderOnlyOutput not available in this transformers version return ModelOutput( sequences=output_tokens, - logits=tuple(logits_seq_list), + logits=_logits_to_tuple(logits_seq_list), ) # Format output
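A minimal usage sketch of the per-step logits contract that the generate() changes above implement (not part of the patch; the `output_logits` flag name and the HF-style `GenerateDecoderOnlyOutput` return are assumptions taken from the hunks in this series)::

    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")
    result = model.generate(
        "Hello world",
        max_new_tokens=3,
        do_sample=False,
        output_logits=True,  # flag assumed from the bridge.py / HookedTransformer.py hunks above
    )
    # HF-style output: prompt + generated token ids, plus one [batch, d_vocab]
    # logits tensor per generated step.
    print(result.sequences.shape)
    print(len(result.logits), result.logits[0].shape)

Under the patched code path, result.logits has one entry per generated step (up to max_new_tokens, fewer if generation stops early at EOS).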