diff --git a/tests/acceptance/model_bridge/compatibility/test_activation_cache.py b/tests/acceptance/model_bridge/compatibility/test_activation_cache.py index aff5b03f3..cb4e39d76 100644 --- a/tests/acceptance/model_bridge/compatibility/test_activation_cache.py +++ b/tests/acceptance/model_bridge/compatibility/test_activation_cache.py @@ -27,14 +27,14 @@ def bridge_model(self, gpt2_bridge): def sample_cache(self, bridge_model): """Create a sample cache for testing.""" prompt = "The quick brown fox jumps over the lazy dog." - output, cache = bridge_model.run_with_cache(prompt) + output, cache = bridge_model.run_with_cache(input=prompt) return cache def test_cache_creation(self, bridge_model): """Test that caches can be created from TransformerBridge.""" prompt = "Test cache creation." - output, cache = bridge_model.run_with_cache(prompt, return_cache_object=True) + output, cache = bridge_model.run_with_cache(input=prompt, return_cache_object=True) assert isinstance(output, torch.Tensor) assert isinstance(cache, (dict, ActivationCache)) @@ -92,7 +92,7 @@ def test_cache_with_names_filter(self, bridge_model): filter_names = list(hook_dict.keys())[:3] try: - output, cache = bridge_model.run_with_cache(prompt, names_filter=filter_names) + output, cache = bridge_model.run_with_cache(input=prompt, names_filter=filter_names) if hasattr(cache, "cache_dict"): cache_dict = cache.cache_dict @@ -153,7 +153,7 @@ def test_cache_batch_dimension_handling(self, bridge_model): prompts = ["First prompt for batch testing.", "Second prompt for batch testing."] try: - output, cache = bridge_model.run_with_cache(prompts) + output, cache = bridge_model.run_with_cache(input=prompts) if hasattr(cache, "cache_dict"): cache_dict = cache.cache_dict @@ -173,7 +173,7 @@ def test_cache_device_consistency(self, bridge_model): prompt = "Test device consistency." model_cpu = bridge_model.cpu() - output, cache = model_cpu.run_with_cache(prompt) + output, cache = model_cpu.run_with_cache(input=prompt) if hasattr(cache, "cache_dict"): cache_dict = cache.cache_dict @@ -192,7 +192,7 @@ def test_cache_memory_efficiency(self, bridge_model): initial_memory = torch.cuda.memory_allocated() for _ in range(3): - output, cache = bridge_model.run_with_cache(prompt) + output, cache = bridge_model.run_with_cache(input=prompt) del output, cache import gc @@ -209,10 +209,10 @@ def test_cache_memory_efficiency(self, bridge_model): def test_cache_with_different_inputs(self, bridge_model): """Test that cache works with different input types.""" - output1, cache1 = bridge_model.run_with_cache("String input test.") + output1, cache1 = bridge_model.run_with_cache(input="String input test.") tokens = bridge_model.to_tokens("Token input test.") - output2, cache2 = bridge_model.run_with_cache(tokens) + output2, cache2 = bridge_model.run_with_cache(input=tokens) assert isinstance(output1, torch.Tensor) assert isinstance(output2, torch.Tensor) diff --git a/tests/acceptance/model_bridge/test_run_with_cache_batch.py b/tests/acceptance/model_bridge/test_run_with_cache_batch.py index 3600c42ce..cffae9d4a 100644 --- a/tests/acceptance/model_bridge/test_run_with_cache_batch.py +++ b/tests/acceptance/model_bridge/test_run_with_cache_batch.py @@ -48,7 +48,7 @@ def capture_individual(tensor, hook): for p in prompts: gpt2_bridge.run_with_hooks( - p, + input=p, fwd_hooks=[("blocks.11.hook_resid_post", capture_individual)], ) @@ -61,7 +61,7 @@ def capture_batched(tensor, hook): captured_batched.append(tensor[i, -1, :].detach().clone()) gpt2_bridge.run_with_hooks( - prompts, + input=prompts, fwd_hooks=[("blocks.11.hook_resid_post", capture_batched)], ) diff --git a/tests/acceptance/test_hook_tokens.py b/tests/acceptance/test_hook_tokens.py index 73690ae1a..b80871e5e 100644 --- a/tests/acceptance/test_hook_tokens.py +++ b/tests/acceptance/test_hook_tokens.py @@ -39,7 +39,7 @@ def hook_fn(tokens: Int[t.Tensor, "batch seq"], hook: HookPoint, new_first_token # Run with hooks out_from_hook = model.run_with_hooks( - prompt, + input=prompt, prepend_bos=False, fwd_hooks=[("hook_tokens", functools.partial(hook_fn, new_first_token=new_first_token))], ) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 8020c1923..cd6f281fb 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -585,7 +585,7 @@ def remove_pos_embed(z, hook): z[:] = 0.0 return z - _ = model.run_with_hooks("Hello, world", fwd_hooks=[("hook_pos_embed", remove_pos_embed)]) + _ = model.run_with_hooks(input="Hello, world", fwd_hooks=[("hook_pos_embed", remove_pos_embed)]) # Check that pos embed has not been permanently changed assert (model.W_pos == initial_W_pos).all() @@ -600,7 +600,7 @@ def edit_pos_embed(z, hook): return z _ = model.run_with_hooks( - ["Hello, world", "Goodbye, world"], + input=["Hello, world", "Goodbye, world"], fwd_hooks=[("hook_pos_embed", edit_pos_embed)], ) diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py index ad407eb6e..28574ec64 100644 --- a/tests/acceptance/test_multi_gpu.py +++ b/tests/acceptance/test_multi_gpu.py @@ -50,10 +50,10 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices): gpt2_tokens = model_1_device.to_tokens(gpt2_text) gpt2_logits_1_device, gpt2_cache_1_device = model_1_device.run_with_cache( - gpt2_tokens, remove_batch_dim=True + input=gpt2_tokens, remove_batch_dim=True ) gpt2_logits_n_devices, gpt2_cache_n_devices = model_n_devices.run_with_cache( - gpt2_tokens, remove_batch_dim=True + input=gpt2_tokens, remove_batch_dim=True ) # Make sure the tensors in cache remain on their respective devices @@ -106,16 +106,16 @@ def test_load_model_on_target_device(): def test_cache_device(): model = HookedTransformer.from_pretrained("gpt2-small", device="cuda:1") - logits, cache = model.run_with_cache("Hello there") + logits, cache = model.run_with_cache(input="Hello there") assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device( torch.device("cuda:1") ) - logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu")) + logits, cache = model.run_with_cache(input="Hello there", device=torch.device("cpu")) assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu")) model.to("cuda") - logits, cache = model.run_with_cache("Hello there") + logits, cache = model.run_with_cache(input="Hello there") assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(logits.device) diff --git a/tests/integration/model_bridge/compatibility/test_bridge_cache_behavior.py b/tests/integration/model_bridge/compatibility/test_bridge_cache_behavior.py index 0a347cac4..06ffa0ec0 100644 --- a/tests/integration/model_bridge/compatibility/test_bridge_cache_behavior.py +++ b/tests/integration/model_bridge/compatibility/test_bridge_cache_behavior.py @@ -56,20 +56,20 @@ class TestCacheBasics: def test_run_with_cache_returns_nonempty(self, bridge_compat): """run_with_cache returns a non-empty cache.""" with torch.no_grad(): - _, cache = bridge_compat.run_with_cache("Hello world") + _, cache = bridge_compat.run_with_cache(input="Hello world") assert len(cache) > 0 def test_cache_contains_residual_hooks(self, bridge_compat): """Cache should contain residual stream hooks.""" with torch.no_grad(): - _, cache = bridge_compat.run_with_cache("Hello world") + _, cache = bridge_compat.run_with_cache(input="Hello world") cache_keys = list(cache.keys()) assert any("hook_resid" in k for k in cache_keys) def test_cache_values_are_tensors(self, bridge_compat): """All cached values should be tensors with correct batch dimension.""" with torch.no_grad(): - _, cache = bridge_compat.run_with_cache("Hello") + _, cache = bridge_compat.run_with_cache(input="Hello") for key, value in cache.items(): assert isinstance(value, torch.Tensor), f"Cache[{key}] is {type(value)}" assert value.shape[0] == 1, f"Cache[{key}] batch dim is {value.shape[0]}" @@ -81,9 +81,9 @@ class TestCacheNamesFilter: def test_names_filter_returns_subset(self, bridge_compat): """names_filter should return only matching keys.""" with torch.no_grad(): - _, full_cache = bridge_compat.run_with_cache("Hello") + _, full_cache = bridge_compat.run_with_cache(input="Hello") _, filtered_cache = bridge_compat.run_with_cache( - "Hello", + input="Hello", names_filter=lambda name: "hook_resid_pre" in name, ) @@ -98,7 +98,7 @@ class TestCacheCompleteness: def test_all_expected_hooks_in_cache(self, bridge_compat): """Cache should contain all expected hook names.""" - _, cache = bridge_compat.run_with_cache("Hello World!") + _, cache = bridge_compat.run_with_cache(input="Hello World!") actual_keys = set(cache.keys()) missing = set(EXPECTED_HOOKS) - actual_keys assert len(missing) == 0, f"Missing expected hooks: {sorted(missing)}" @@ -133,8 +133,8 @@ def test_cache_values_match(self, bridge_compat, reference_ht): Unmasked scores and resulting patterns should still match. """ prompt = "Hello World!" - _, bridge_cache = bridge_compat.run_with_cache(prompt) - _, ht_cache = reference_ht.run_with_cache(prompt) + _, bridge_cache = bridge_compat.run_with_cache(input=prompt) + _, ht_cache = reference_ht.run_with_cache(input=prompt) for hook in EXPECTED_HOOKS: if hook not in bridge_cache or hook not in ht_cache: diff --git a/tests/integration/model_bridge/compatibility/test_bridge_hook_behavior.py b/tests/integration/model_bridge/compatibility/test_bridge_hook_behavior.py index 8df4599ef..38f854d20 100644 --- a/tests/integration/model_bridge/compatibility/test_bridge_hook_behavior.py +++ b/tests/integration/model_bridge/compatibility/test_bridge_hook_behavior.py @@ -44,7 +44,7 @@ def hook_fn(tensor, hook): return tensor bridge.run_with_hooks( - "Hello world", + input="Hello world", fwd_hooks=[("blocks.0.hook_resid_pre", hook_fn)], ) assert count == 1 @@ -58,7 +58,7 @@ def hook_fn(tensor, hook): return tensor bridge.run_with_hooks( - "Hello", + input="Hello", fwd_hooks=[("blocks.0.hook_resid_pre", hook_fn)], ) assert len(captured["shape"]) >= 2 @@ -76,7 +76,7 @@ def hook_fn(tensor, hook): return hook_fn bridge.run_with_hooks( - "Hello", + input="Hello", fwd_hooks=[ ("blocks.0.hook_resid_pre", make_hook("resid_pre_0")), ("blocks.0.hook_resid_post", make_hook("resid_post_0")), @@ -116,7 +116,7 @@ def zero_hook(tensor, hook): return torch.zeros_like(tensor) modified_output = bridge.run_with_hooks( - "Hello world", + input="Hello world", fwd_hooks=[("blocks.0.hook_resid_pre", zero_hook)], ) @@ -132,7 +132,7 @@ def ablation_hook(activation, hook): return activation ablated_loss = bridge_compat.run_with_hooks( - test_text, + input=test_text, return_type="loss", fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)], ) @@ -154,14 +154,14 @@ def ablation_hook(activation, hook): ht_baseline = reference_ht(test_text, return_type="loss") ht_ablated = reference_ht.run_with_hooks( - test_text, + input=test_text, return_type="loss", fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)], ) bridge_baseline = bridge_compat(test_text, return_type="loss") bridge_ablated = bridge_compat.run_with_hooks( - test_text, + input=test_text, return_type="loss", fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)], ) @@ -190,7 +190,7 @@ def hook_fn(activation, hook): return hook_fn bridge_compat.run_with_hooks( - "The quick brown fox", + input="The quick brown fox", return_type="logits", fwd_hooks=[("hook_embed", capture("embed"))], ) @@ -209,7 +209,7 @@ def hook_fn(activation, hook): return hook_fn bridge_compat.run_with_hooks( - "The quick brown fox", + input="The quick brown fox", return_type="logits", fwd_hooks=[("blocks.0.attn.hook_v", capture("v"))], ) @@ -259,7 +259,7 @@ def hook_fn(tensor, hook): with torch.no_grad(): bridge.run_with_hooks( - "Hello", + input="Hello", fwd_hooks=[("blocks.0.hook_resid_pre", hook_fn)], ) assert count == 1 diff --git a/tests/integration/model_bridge/generalized_components/test_joint_qkv_attention_bridge_integration.py b/tests/integration/model_bridge/generalized_components/test_joint_qkv_attention_bridge_integration.py index 13e2379a1..36eea11b5 100644 --- a/tests/integration/model_bridge/generalized_components/test_joint_qkv_attention_bridge_integration.py +++ b/tests/integration/model_bridge/generalized_components/test_joint_qkv_attention_bridge_integration.py @@ -83,7 +83,7 @@ def v_ablation_hook(value, hook): original_loss = model(tokens, return_type="loss") # Use the correct hook name for Bridge architecture (v.hook_out instead of hook_v) hooked_loss = model.run_with_hooks( - tokens, + input=tokens, return_type="loss", fwd_hooks=[("blocks.0.attn.v.hook_out", v_ablation_hook)], ) diff --git a/tests/integration/model_bridge/test_bridge_stop_at_layer.py b/tests/integration/model_bridge/test_bridge_stop_at_layer.py index 20e6e2298..a8a8ba477 100644 --- a/tests/integration/model_bridge/test_bridge_stop_at_layer.py +++ b/tests/integration/model_bridge/test_bridge_stop_at_layer.py @@ -180,7 +180,7 @@ def count_hook(activation, hook): # Hook at blocks.0 should fire # Hook at blocks.1 should NOT fire (stop_at_layer=1) output = bridge_default.run_with_hooks( - rand_input, + input=rand_input, stop_at_layer=1, fwd_hooks=[ ("embed.hook_out", count_hook), @@ -505,7 +505,7 @@ def count_hook(activation, hook): # Hook at blocks.0 should fire # Hook at blocks.1 should NOT fire (stop_at_layer=1) output = bridge_with_compat_and_processing.run_with_hooks( - rand_input, + input=rand_input, stop_at_layer=1, fwd_hooks=[ ("embed.hook_out", count_hook), diff --git a/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py b/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py index dcecd45ee..c2182f7a0 100644 --- a/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py +++ b/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py @@ -94,7 +94,7 @@ def _h(_m, _i, o): for i in range(n_layers) ] with torch.inference_mode(): - bridge.run_with_hooks(tokens, fwd_hooks=fwd_hooks) + bridge.run_with_hooks(input=tokens, fwd_hooks=fwd_hooks) for i in range(n_layers): d = (hf_layer_out[i] - bridge_layer_out[i]).abs().max().item() @@ -114,7 +114,7 @@ def test_bridge_attention_reconstruction_actually_runs(bridge, tokenize): tokens = tokenize("Hello, world!") attn_scores_fired: list[bool] = [] bridge.run_with_hooks( - tokens, + input=tokens, fwd_hooks=[ ("blocks.0.attn.hook_attn_scores", lambda v, hook: attn_scores_fired.append(True)), ], diff --git a/tests/integration/model_bridge/test_mamba_adapter.py b/tests/integration/model_bridge/test_mamba_adapter.py index 3f72449a4..5fc8d61d9 100644 --- a/tests/integration/model_bridge/test_mamba_adapter.py +++ b/tests/integration/model_bridge/test_mamba_adapter.py @@ -168,7 +168,7 @@ def zero_resid(t, hook): with torch.no_grad(): ablated = mamba_bridge.run_with_hooks( - tokens, + input=tokens, fwd_hooks=[("blocks.12.hook_in", zero_resid)], ) # Zeroing a mid-layer residual stream should change the output diff --git a/tests/integration/model_bridge/test_qwen3_5_multimodal_bridge.py b/tests/integration/model_bridge/test_qwen3_5_multimodal_bridge.py index 02c26c8d8..8c1f5ad30 100644 --- a/tests/integration/model_bridge/test_qwen3_5_multimodal_bridge.py +++ b/tests/integration/model_bridge/test_qwen3_5_multimodal_bridge.py @@ -38,7 +38,7 @@ def fn(tensor, hook): with torch.no_grad(): bridge.run_with_hooks( - "The quick brown fox", fwd_hooks=[(n, capture(n)) for n in gate_hooks] + input="The quick brown fox", fwd_hooks=[(n, capture(n)) for n in gate_hooks] ) gate = captured.get("blocks.1.attn.hook_q_gate") diff --git a/tests/integration/model_bridge/test_smollm3_adapter.py b/tests/integration/model_bridge/test_smollm3_adapter.py index 79945efeb..38ae17f69 100644 --- a/tests/integration/model_bridge/test_smollm3_adapter.py +++ b/tests/integration/model_bridge/test_smollm3_adapter.py @@ -142,7 +142,7 @@ def _hook(_module, _inputs, output): for i in range(n_layers) ] with torch.inference_mode(): - bridge.run_with_hooks(tokens, fwd_hooks=fwd_hooks) + bridge.run_with_hooks(input=tokens, fwd_hooks=fwd_hooks) for i, layer in enumerate(hf_eager.model.layers): drift = (hf_layer_out[i] - bridge_layer_out[i]).abs().max().item() @@ -163,7 +163,7 @@ def test_bridge_runs_its_own_attention_reconstruction( """ fired: list[bool] = [] bridge.run_with_hooks( - tokens, + input=tokens, fwd_hooks=[ ("blocks.0.attn.hook_attn_scores", lambda value, hook: fired.append(True)), ], diff --git a/tests/integration/model_bridge/test_weight_processing.py b/tests/integration/model_bridge/test_weight_processing.py index ce22f6423..8213468a5 100644 --- a/tests/integration/model_bridge/test_weight_processing.py +++ b/tests/integration/model_bridge/test_weight_processing.py @@ -95,7 +95,7 @@ def ablation_hook(activation, hook): return activation ref_ablated_loss = reference_ht.run_with_hooks( - test_text, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)] + input=test_text, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)] ) ref_ablation_effect = ref_ablated_loss - ref_loss @@ -115,7 +115,7 @@ def ablation_hook(activation, hook): # Test ablation with bridge bridge_ablated_loss = bridge.run_with_hooks( - test_text, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)] + input=test_text, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)] ) bridge_ablation_effect = bridge_ablated_loss - bridge_loss @@ -193,7 +193,7 @@ def ablation_hook( hook_name = utils.get_act_name("v", layer) orig = model(tokens, return_type="loss").item() ablated = model.run_with_hooks( - tokens, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)] + input=tokens, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)] ).item() return orig, ablated diff --git a/tests/integration/model_bridge/test_weight_processing_perfect_match.py b/tests/integration/model_bridge/test_weight_processing_perfect_match.py index 0307bf6cf..54288100e 100644 --- a/tests/integration/model_bridge/test_weight_processing_perfect_match.py +++ b/tests/integration/model_bridge/test_weight_processing_perfect_match.py @@ -90,11 +90,11 @@ def head_ablation_hook(value, hook): hook_name = utils.get_act_name("v", layer_to_ablate) hooked_ablated = hooked_processed.run_with_hooks( - tokens, return_type="loss", fwd_hooks=[(hook_name, head_ablation_hook)] + input=tokens, return_type="loss", fwd_hooks=[(hook_name, head_ablation_hook)] ) corrected_ablated = corrected_processed.run_with_hooks( - tokens, return_type="loss", fwd_hooks=[(hook_name, head_ablation_hook)] + input=tokens, return_type="loss", fwd_hooks=[(hook_name, head_ablation_hook)] ) print(f"HookedTransformer: {hooked_ablated.item():.6f}") diff --git a/tests/integration/test_grouped_query_attention.py b/tests/integration/test_grouped_query_attention.py index a8cf6f397..2b23b8fe5 100644 --- a/tests/integration/test_grouped_query_attention.py +++ b/tests/integration/test_grouped_query_attention.py @@ -225,7 +225,7 @@ def test_ungroup_grouped_query_attention_flag_changes_k_v_hooks_shape(): x = torch.arange(1, 9).unsqueeze(0) flag_off_output, flag_off_cache = model.run_with_cache( - x, + input=x, names_filter=[ "blocks.0.attn.hook_k", "blocks.0.attn.hook_v", @@ -238,7 +238,7 @@ def test_ungroup_grouped_query_attention_flag_changes_k_v_hooks_shape(): assert model.cfg.ungroup_grouped_query_attention is True flag_on_output, flag_on_cache = model.run_with_cache( - x, + input=x, names_filter=[ "blocks.0.attn.hook_k", "blocks.0.attn.hook_v", diff --git a/tests/integration/test_hooks.py b/tests/integration/test_hooks.py index f3d3989c4..e606747f8 100644 --- a/tests/integration/test_hooks.py +++ b/tests/integration/test_hooks.py @@ -68,7 +68,7 @@ def test_context_manager_run_with_cache(): c = Counter() with model.hooks(fwd_hooks=[(embed, c.inc)]): assert len(model.hook_dict["hook_embed"].fwd_hooks) == 1 - model.run_with_cache(prompt) + model.run_with_cache(input=prompt) assert len(model.hook_dict["hook_embed"].fwd_hooks) == 1 assert len(model.hook_dict["hook_embed"].fwd_hooks) == 0 assert c.count == 1 @@ -224,7 +224,7 @@ def identity_hook(z, hook): set_use_hook_function(True) cache = model.run_with_cache( - prompt, + input=prompt, names_filter=lambda x: x == hook_name, )[1] diff --git a/tests/integration/test_main_demo_pattern_hooks.py b/tests/integration/test_main_demo_pattern_hooks.py index 4539cb42a..1ecc69ee3 100644 --- a/tests/integration/test_main_demo_pattern_hooks.py +++ b/tests/integration/test_main_demo_pattern_hooks.py @@ -60,7 +60,7 @@ def induction_score_hook(pattern, hook): # Run with hooks (should not raise any errors) model.run_with_hooks( - repeated_tokens, + input=repeated_tokens, return_type=None, # For efficiency, don't calculate logits fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)], ) @@ -97,7 +97,7 @@ def tracking_hook(pattern, hook): pattern_filter = lambda name: name.endswith("pattern") # Run with hooks - model.run_with_hooks(tokens, return_type=None, fwd_hooks=[(pattern_filter, tracking_hook)]) + model.run_with_hooks(input=tokens, return_type=None, fwd_hooks=[(pattern_filter, tracking_hook)]) # Verify each pattern hook was called exactly once for name, count in hook_calls.items(): @@ -130,7 +130,7 @@ def layer_tracking_hook(pattern, hook): # Run with hooks model.run_with_hooks( - tokens, return_type=None, fwd_hooks=[(pattern_filter, layer_tracking_hook)] + input=tokens, return_type=None, fwd_hooks=[(pattern_filter, layer_tracking_hook)] ) # Verify we got layer indices for all layers @@ -183,7 +183,7 @@ def induction_score_hook(pattern, hook): # Run with hooks model.run_with_hooks( - repeated_tokens, + input=repeated_tokens, return_type=None, fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)], ) diff --git a/tests/integration/test_start_at_layer.py b/tests/integration/test_start_at_layer.py index f1d007829..ec2ae8139 100644 --- a/tests/integration/test_start_at_layer.py +++ b/tests/integration/test_start_at_layer.py @@ -52,7 +52,7 @@ def count_hook(activation, hook): return None output = model.run_with_hooks( - rand_embed, + input=rand_embed, start_at_layer=1, fwd_hooks=[ ("hook_embed", count_hook), diff --git a/tests/integration/test_stop_at_layer.py b/tests/integration/test_stop_at_layer.py index 2692c8f49..279663aee 100644 --- a/tests/integration/test_stop_at_layer.py +++ b/tests/integration/test_stop_at_layer.py @@ -56,7 +56,7 @@ def count_hook(activation, hook): return None output = model.run_with_hooks( - rand_input, + input=rand_input, stop_at_layer=1, fwd_hooks=[ ("hook_embed", count_hook), diff --git a/tests/mps/test_mps_basic.py b/tests/mps/test_mps_basic.py index 63f6e6b5d..5fc3eeba0 100644 --- a/tests/mps/test_mps_basic.py +++ b/tests/mps/test_mps_basic.py @@ -178,7 +178,7 @@ def test_mps_run_with_cache(): model = _load_tiny_model(device="mps") tokens = model.to_tokens("The quick brown fox") - logits, cache = model.run_with_cache(tokens) + logits, cache = model.run_with_cache(input=tokens) assert logits.device.type == "mps" @@ -204,7 +204,7 @@ def capture_hook(value, hook): return value model.run_with_hooks( - tokens, + input=tokens, fwd_hooks=[ ("blocks.0.attn.hook_q", capture_hook), ("blocks.0.mlp.hook_post", capture_hook), diff --git a/tests/unit/model_bridge/compatibility/test_next_sentence_prediction.py b/tests/unit/model_bridge/compatibility/test_next_sentence_prediction.py index 59fa5c372..b41fe2849 100644 --- a/tests/unit/model_bridge/compatibility/test_next_sentence_prediction.py +++ b/tests/unit/model_bridge/compatibility/test_next_sentence_prediction.py @@ -179,7 +179,7 @@ def test_run_with_cache(bert_nsp, mock_transformer_bridge): # Run with cache output, cache = bert_nsp.run_with_cache( - input_data, return_type="logits", return_cache_object=True + input=input_data, return_type="logits", return_cache_object=True ) # Verify output shape and values diff --git a/tests/unit/model_bridge/compatibility/test_split_qkv.py b/tests/unit/model_bridge/compatibility/test_split_qkv.py index 4e36dd72f..6932d8b83 100644 --- a/tests/unit/model_bridge/compatibility/test_split_qkv.py +++ b/tests/unit/model_bridge/compatibility/test_split_qkv.py @@ -54,7 +54,7 @@ def _hook(tensor, hook): try: gpt2_bridge.run_with_hooks( - x, + input=x, fwd_hooks=[ ("blocks.0.attn.hook_q_input", cap("q")), ("blocks.0.attn.hook_k_input", cap("k")), @@ -93,7 +93,7 @@ def _hook(tensor, hook): return _hook gpt2_bridge.run_with_hooks( - x, + input=x, fwd_hooks=[ ("blocks.0.attn.k.hook_out", cap_baseline("k")), ("blocks.0.attn.v.hook_out", cap_baseline("v")), @@ -113,7 +113,7 @@ def zero_q_input(tensor, hook): return torch.zeros_like(tensor) gpt2_bridge.run_with_hooks( - x, + input=x, fwd_hooks=[ ("blocks.0.attn.hook_q_input", zero_q_input), ("blocks.0.attn.k.hook_out", cap_patched("k")), diff --git a/tests/unit/model_bridge/compatibility/test_use_attn_result.py b/tests/unit/model_bridge/compatibility/test_use_attn_result.py index 521a5a7bf..8eb44ec34 100644 --- a/tests/unit/model_bridge/compatibility/test_use_attn_result.py +++ b/tests/unit/model_bridge/compatibility/test_use_attn_result.py @@ -43,7 +43,7 @@ def _hook(tensor, hook): try: gpt2_bridge.run_with_hooks( - x, + input=x, fwd_hooks=[ ("blocks.0.attn.hook_result", cap("result")), ("blocks.0.attn.hook_out", cap("out")), @@ -80,7 +80,7 @@ def _hook(tensor, hook): fired["result"] = True return tensor - gpt2_bridge.run_with_hooks(x, fwd_hooks=[("blocks.0.attn.hook_result", _hook)]) + gpt2_bridge.run_with_hooks(input=x, fwd_hooks=[("blocks.0.attn.hook_result", _hook)]) assert fired["result"] is False, ( "hook_result fired when use_attn_result was False; the flag is " "supposed to skip the per-head computation." diff --git a/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py index 21b7a5a89..715d0e100 100644 --- a/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py @@ -803,7 +803,7 @@ def _capture(tensor: torch.Tensor, hook: object) -> torch.Tensor: tokens = torch.randint(0, 512, (1, 4)) with torch.no_grad(): bridge.run_with_hooks( - tokens, + input=tokens, fwd_hooks=[(name, capture_hook(name)) for name in hook_names], ) diff --git a/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_mlp_in_patching.py b/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_mlp_in_patching.py index e643905a1..ebc2deba7 100644 --- a/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_mlp_in_patching.py +++ b/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_mlp_in_patching.py @@ -72,8 +72,8 @@ def _inner(tensor: torch.Tensor, hook: object) -> torch.Tensor: return _inner - bridge.run_with_hooks(prompt_a, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _cap_bridge)]) - ht.run_with_hooks(prompt_a, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _cap_ht)]) + bridge.run_with_hooks(input=prompt_a, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _cap_bridge)]) + ht.run_with_hooks(input=prompt_a, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _cap_ht)]) # Pins down a silent-miss in the ln2 pre-hook (the #1317 bug class). assert bridge_fire_count["n"] == 1, ( @@ -89,10 +89,10 @@ def _inner(tensor: torch.Tensor, hook: object) -> torch.Tensor: ) bridge_logits = bridge.run_with_hooks( - prompt_b, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _patch(cache_a_bridge))] + input=prompt_b, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _patch(cache_a_bridge))] ) ht_logits = ht.run_with_hooks( - prompt_b, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _patch(cache_a_ht))] + input=prompt_b, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _patch(cache_a_ht))] ) baseline_diff = _baseline_logit_diff(model, no_processing, prompt_b) @@ -117,7 +117,7 @@ def _counter(tensor: torch.Tensor, hook: object) -> torch.Tensor: return tensor prompt = torch.arange(1, 9).unsqueeze(0) - bridge.run_with_hooks(prompt, fwd_hooks=[("blocks.0.hook_mlp_in", _counter)]) + bridge.run_with_hooks(input=prompt, fwd_hooks=[("blocks.0.hook_mlp_in", _counter)]) assert fire_count["n"] == 0, ( f"hook_mlp_in fired {fire_count['n']} times with use_hook_mlp_in=False; " "should not fire when the flag is off" diff --git a/tests/unit/model_bridge/test_component_inspection.py b/tests/unit/model_bridge/test_component_inspection.py index d0f76ac30..3554c42bc 100644 --- a/tests/unit/model_bridge/test_component_inspection.py +++ b/tests/unit/model_bridge/test_component_inspection.py @@ -91,7 +91,7 @@ def test_forward_returns_loss(self, bridge): def test_run_with_cache_returns_activations(self, bridge): """run_with_cache should return non-empty cache.""" with torch.no_grad(): - _, cache = bridge.run_with_cache("Hello") + _, cache = bridge.run_with_cache(input="Hello") assert len(cache) > 0 # Should have block-level hooks block_keys = [k for k in cache.keys() if "blocks.0" in k] diff --git a/tests/unit/test_hubert_hooks.py b/tests/unit/test_hubert_hooks.py index 0ed316992..e7b138662 100644 --- a/tests/unit/test_hubert_hooks.py +++ b/tests/unit/test_hubert_hooks.py @@ -47,7 +47,7 @@ class TestHubertRunWithCache: def test_cache_contains_attention_pattern(self, audio_model, frames_and_mask): frames, frame_mask = frames_and_mask _, cache = audio_model.run_with_cache( - frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True + inputs=frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True ) layer = 0 pattern_name = utils.get_act_name("pattern", layer) @@ -61,7 +61,7 @@ def test_cache_contains_attention_pattern(self, audio_model, frames_and_mask): def test_cache_attention_pattern_shape(self, audio_model, frames_and_mask): frames, frame_mask = frames_and_mask _, cache = audio_model.run_with_cache( - frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True + inputs=frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True ) pattern_name = utils.get_act_name("pattern", 0) if pattern_name in cache: @@ -91,13 +91,13 @@ def head_ablation_hook(value, hook): # Baseline baseline_out = audio_model.run_with_hooks( - frames, fwd_hooks=[], one_zero_attention_mask=frame_mask + inputs=frames, fwd_hooks=[], one_zero_attention_mask=frame_mask ) baseline_tensor = _get_output_tensor(baseline_out) # Ablated ablated_out = audio_model.run_with_hooks( - frames, + inputs=frames, fwd_hooks=[(v_act_name, head_ablation_hook)], one_zero_attention_mask=frame_mask, ) diff --git a/tests/unit/test_next_sentence_prediction.py b/tests/unit/test_next_sentence_prediction.py index 3c355743d..46143a2a5 100644 --- a/tests/unit/test_next_sentence_prediction.py +++ b/tests/unit/test_next_sentence_prediction.py @@ -247,7 +247,7 @@ def test_run_with_cache(bert_nsp, mock_hooked_encoder): # Run with cache output, cache = bert_nsp.run_with_cache( - input_data, return_type="logits", return_cache_object=True + input=input_data, return_type="logits", return_cache_object=True ) # Verify output shape and values diff --git a/transformer_lens/BertNextSentencePrediction.py b/transformer_lens/BertNextSentencePrediction.py index eb38e6879..dd5d5325a 100644 --- a/transformer_lens/BertNextSentencePrediction.py +++ b/transformer_lens/BertNextSentencePrediction.py @@ -217,19 +217,18 @@ def forward( @overload def run_with_cache( - self, *model_args, return_cache_object: Literal[True] = True, **kwargs + self, return_cache_object: Literal[True] = True, **kwargs ) -> Tuple[Float[torch.Tensor, "batch 2"], ActivationCache,]: ... @overload def run_with_cache( - self, *model_args, return_cache_object: Literal[False], **kwargs + self, return_cache_object: Literal[False], **kwargs ) -> Tuple[Float[torch.Tensor, "batch 2"], Dict[str, torch.Tensor],]: ... def run_with_cache( self, - *model_args, return_cache_object: bool = True, remove_batch_dim: bool = False, **kwargs, @@ -263,7 +262,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): with ForwardWrapper(self): out, cache_dict = self.model.run_with_cache( - *model_args, remove_batch_dim=remove_batch_dim, **kwargs + remove_batch_dim=remove_batch_dim, **kwargs ) if return_cache_object: cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index 472425ec5..ba6cb0983 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -296,19 +296,18 @@ def forward( @overload def run_with_cache( - self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any + self, return_cache_object: Literal[True] = True, **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: ... @overload def run_with_cache( - self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any + self, return_cache_object: Literal[False], **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: ... def run_with_cache( self, - *model_args: Any, return_cache_object: bool = True, remove_batch_dim: bool = False, **kwargs: Any, @@ -320,7 +319,7 @@ def run_with_cache( Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. """ out, cache_dict = super().run_with_cache( - *model_args, remove_batch_dim=remove_batch_dim, **kwargs + remove_batch_dim=remove_batch_dim, **kwargs ) if return_cache_object: cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py index f71a9c75a..2c6a33e89 100644 --- a/transformer_lens/HookedEncoder.py +++ b/transformer_lens/HookedEncoder.py @@ -317,19 +317,18 @@ def forward( @overload def run_with_cache( - self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any + self, return_cache_object: Literal[True] = True, **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: ... @overload def run_with_cache( - self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any + self, return_cache_object: Literal[False], **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: ... def run_with_cache( self, - *model_args: Any, return_cache_object: bool = True, remove_batch_dim: bool = False, **kwargs: Any, @@ -341,7 +340,7 @@ def run_with_cache( Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. """ out, cache_dict = super().run_with_cache( - *model_args, remove_batch_dim=remove_batch_dim, **kwargs + remove_batch_dim=remove_batch_dim, **kwargs ) if return_cache_object: cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) diff --git a/transformer_lens/HookedEncoderDecoder.py b/transformer_lens/HookedEncoderDecoder.py index 7bc56df9e..5dc6faa05 100644 --- a/transformer_lens/HookedEncoderDecoder.py +++ b/transformer_lens/HookedEncoderDecoder.py @@ -487,19 +487,18 @@ def generate( @overload # type: ignore[overload-overlap] def run_with_cache( - self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any + self, return_cache_object: Literal[True] = True, **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: ... @overload # type: ignore[overload-overlap] def run_with_cache( - self, *model_args: Any, return_cache_object: Literal[False] = False, **kwargs: Any + self, return_cache_object: Literal[False] = False, **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: ... def run_with_cache( self, - *model_args: Any, return_cache_object: bool = True, remove_batch_dim: bool = False, **kwargs: Any, @@ -511,7 +510,7 @@ def run_with_cache( Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. """ out, cache_dict = super().run_with_cache( - *model_args, remove_batch_dim=remove_batch_dim, **kwargs + remove_batch_dim=remove_batch_dim, **kwargs ) if return_cache_object: cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) diff --git a/transformer_lens/HookedRootModule.py b/transformer_lens/HookedRootModule.py index 262886057..08fd08975 100644 --- a/transformer_lens/HookedRootModule.py +++ b/transformer_lens/HookedRootModule.py @@ -278,7 +278,6 @@ def hooks( def run_with_hooks( self, - *model_args: Any, # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific? fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], reset_hooks_end: bool = True, @@ -299,7 +298,6 @@ def run_with_hooks( during this run. Default is True. clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is False. - *model_args: Positional arguments for the model. **model_kwargs: Keyword arguments for the model's forward function. See your related models forward pass for details as to what sort of arguments you can pass through. @@ -313,7 +311,7 @@ def run_with_hooks( ) with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model: - return hooked_model.forward(*model_args, **model_kwargs) + return hooked_model.forward(**model_kwargs) def add_caching_hooks( self, @@ -372,7 +370,6 @@ def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool): def run_with_cache( self, - *model_args: Any, names_filter: NamesFilter = None, device: DeviceType = None, remove_batch_dim: bool = False, @@ -386,7 +383,6 @@ def run_with_cache( Runs the model and returns the model output and a Cache object. Args: - *model_args: Positional arguments for the model. names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str, list of str, or a function that takes a string and returns a bool. Defaults to None, which means cache everything. @@ -428,7 +424,7 @@ def run_with_cache( reset_hooks_end=reset_hooks_end, clear_contexts=clear_contexts, ): - model_out = self(*model_args, **model_kwargs) + model_out = self(**model_kwargs) if incl_bwd: model_out.backward() diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 9baf479e5..77949c8cd 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -705,18 +705,18 @@ def loss_fn( @overload def run_with_cache( - self, *model_args, return_cache_object: Literal[True] = True, **kwargs + self, return_cache_object: Literal[True] = True, **kwargs ) -> Tuple[Output, ActivationCache]: ... @overload def run_with_cache( - self, *model_args, return_cache_object: Literal[False], **kwargs + self, return_cache_object: Literal[False], **kwargs ) -> Tuple[Output, Dict[str, torch.Tensor]]: ... def run_with_cache( - self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs + self, return_cache_object=True, remove_batch_dim=False, **kwargs ) -> Tuple[ Union[ None, @@ -733,7 +733,7 @@ def run_with_cache( activations as in HookedRootModule. """ out, cache_dict = super().run_with_cache( - *model_args, remove_batch_dim=remove_batch_dim, **kwargs + remove_batch_dim=remove_batch_dim, **kwargs ) if return_cache_object: cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)