diff --git a/tests/unit/test_generate_batch.py b/tests/unit/test_generate_batch.py
new file mode 100644
index 000000000..a17b6ec59
--- /dev/null
+++ b/tests/unit/test_generate_batch.py
@@ -0,0 +1,20 @@
+from transformer_lens import HookedTransformer
+
+
+def test_generate_batch():
+    """
+    Test that batched and individual prompt generation produce the same outputs.
+    """
+    model = HookedTransformer.from_pretrained("gpt2")
+    input_prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"]
+
+    # Sampling is disabled (do_sample=False) so the two runs can be compared exactly.
+    orig_outputs = [
+        model.generate(prompt, verbose=False, do_sample=False) for prompt in input_prompts
+    ]
+    batched_outputs = model.generate(input_prompts, verbose=False, do_sample=False)
+
+    # Check the count explicitly: zip() would silently drop unmatched outputs.
+    assert len(batched_outputs) == len(orig_outputs)
+    for prompt, orig, batched in zip(input_prompts, orig_outputs, batched_outputs):
+        assert orig == batched, f"Batched generation diverged for prompt {prompt!r}"
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 025b43793..2e55ef194 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -2095,7 +2095,7 @@ def generate(
         freq_penalty: float = 0.0,
         use_past_kv_cache: bool = True,
         prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
-        padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
+        padding_side: Optional[Literal["left", "right"]] = "left",
         return_type: Optional[str] = "input",
         verbose: bool = True,
     ) -> Union[
@@ -2139,9 +2139,9 @@ def generate(
                 the BOS token to the input (applicable when input is a string). Defaults to None,
                 implying usage of self.cfg.default_prepend_bos (default is True unless specified
                 otherwise). Pass True or False to override the default.
-            padding_side (Union[Literal["left", "right"], None], optional): Overrides
-                self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
-                strings of different lengths.
+            padding_side (Union[Literal["left", "right"], None], optional): Specifies which side to
+                pad when tokenizing multiple strings of different lengths. Defaults to "left" for
+                correct batched generation. If None, uses self.tokenizer.padding_side.
             return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'),
                 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the
                 input was ('input').
@@ -2240,7 +2240,12 @@ def generate(
             for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
                 pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
 
-                tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int)
+                # Use the real prompt + sampled tokens (not zeros) so the mask sees actual padding.
+                if sampled_tokens_list:
+                    sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
+                    tokens = torch.cat((input_tokens, sampled_tokens), dim=1)
+                else:
+                    tokens = input_tokens
                 attention_mask = utils.get_attention_mask(
                     self.tokenizer, tokens, False if prepend_bos is None else prepend_bos
                 ).to(device)
@@ -2267,6 +2272,7 @@ def generate(
                             past_kv_cache=past_kv_cache,
                             start_at_layer=start_at_layer,
                             shortformer_pos_embed=shortformer_pos_embed,
+                            attention_mask=attention_mask,
                         )
                     else:
                         logits = self.forward(
@@ -2277,6 +2283,7 @@ def generate(
                             past_kv_cache=past_kv_cache,
                             start_at_layer=start_at_layer,
                             shortformer_pos_embed=shortformer_pos_embed,
+                            attention_mask=attention_mask,
                         )
                 else:
                     # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
@@ -2288,6 +2295,7 @@ def generate(
                         padding_side=padding_side,
                         start_at_layer=start_at_layer,
                         shortformer_pos_embed=shortformer_pos_embed,
+                        attention_mask=attention_mask,
                     )
                 final_logits = logits[:, -1, :]
 