From 0b54ad53ab232c3da0fd9fb6c8d4e01c15e68e43 Mon Sep 17 00:00:00 2001
From: samreedh bhuyan
Date: Thu, 4 Dec 2025 10:46:38 +0530
Subject: [PATCH] Fix pad_length tuple handling in _tokenize_prompts

---
 gemma/gm/text/_sampler.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/gemma/gm/text/_sampler.py b/gemma/gm/text/_sampler.py
index 98e0f9d3..993866d7 100644
--- a/gemma/gm/text/_sampler.py
+++ b/gemma/gm/text/_sampler.py
@@ -407,15 +407,28 @@
   def _tokenize_prompts(
       prompt: str | Sequence[str],
       *,
       add_bos: bool,
-      pad_length: int | None = None,
+      pad_length: int | tuple[int, ...] | None = None,
   ) -> Float['B L']:
     """Encode the prompts."""
     prompt = _normalize_prompt(prompt)
     tokens = [self.tokenizer.encode(p, add_bos=add_bos) for p in prompt]
-    # Notice that if pad_length exceeds the maximum length of the prompts,
-    # an error will be raised by the `.pad` function below.
-    max_prompt_len = pad_length or max(len(t) for t in tokens)
+    # Calculate the maximum prompt length, handling pad_length buckets.
+    actual_max = max(len(t) for t in tokens)
+    if pad_length is None:
+      max_prompt_len = actual_max
+    elif isinstance(pad_length, tuple):
+      # Pick the smallest fitting bucket (sorted handles unordered tuples).
+      for bucket_size in sorted(pad_length):
+        if actual_max <= bucket_size:
+          max_prompt_len = bucket_size
+          break
+      else:
+        # No bucket fits, use actual max.
+        max_prompt_len = actual_max
+    else:
+      # pad_length is an int.
+      max_prompt_len = pad_length
     # In multi-host, each host read different data, so sync to the max length
     # across all hosts.
     max_prompt_len = _max_across_hosts(max_prompt_len)