diff --git a/gemma/gm/text/_sampler.py b/gemma/gm/text/_sampler.py
index 98e0f9d3..993866d7 100644
--- a/gemma/gm/text/_sampler.py
+++ b/gemma/gm/text/_sampler.py
@@ -407,15 +407,28 @@ def _tokenize_prompts(
       prompt: str | Sequence[str],
       *,
       add_bos: bool,
-      pad_length: int | None = None,
+      pad_length: int | tuple[int, ...] | None = None,
   ) -> Float['B L']:
     """Encode the prompts."""
     prompt = _normalize_prompt(prompt)
     tokens = [self.tokenizer.encode(p, add_bos=add_bos) for p in prompt]
-    # Notice that if pad_length exceeds the maximum length of the prompts,
-    # an error will be raised by the `.pad` function below.
-    max_prompt_len = pad_length or max(len(t) for t in tokens)
+    # Calculate the maximum prompt length, handling pad_length buckets.
+    actual_max = max(len(t) for t in tokens)
+    if pad_length is None:
+      max_prompt_len = actual_max
+    elif isinstance(pad_length, tuple):
+      # Handle tuple buckets: pick the smallest bucket that fits.
+      for bucket_size in pad_length:
+        if actual_max <= bucket_size:
+          max_prompt_len = bucket_size
+          break
+      else:
+        # No bucket fits, use the actual max.
+        max_prompt_len = actual_max
+    else:
+      # pad_length is an int.
+      max_prompt_len = pad_length
 
     # In multi-host, each host read different data, so sync to the max length
     # across all hosts.
     max_prompt_len = _max_across_hosts(max_prompt_len)
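
For reference, a minimal standalone sketch of the bucket-selection behavior
this patch introduces. The helper name `_resolve_pad_length` is hypothetical
and not part of the patch; it just isolates the new logic so it can be
exercised without a Sampler:

    def _resolve_pad_length(
        actual_max: int,
        pad_length: int | tuple[int, ...] | None,
    ) -> int:
      # Mirrors the logic added to `_tokenize_prompts` above.
      if pad_length is None:
        return actual_max
      if isinstance(pad_length, tuple):
        # Smallest bucket that fits; fall back to the actual max if none does.
        for bucket_size in pad_length:
          if actual_max <= bucket_size:
            return bucket_size
        return actual_max
      return pad_length

    assert _resolve_pad_length(100, None) == 100
    assert _resolve_pad_length(100, 512) == 512
    assert _resolve_pad_length(100, (128, 256, 512)) == 128
    assert _resolve_pad_length(600, (128, 256, 512)) == 600  # No bucket fits.

Note that the loop takes the first fitting bucket in iteration order, so
callers should pass buckets sorted ascending, e.g. pad_length=(128, 256, 512),
for "smallest bucket that fits" to hold.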