From 149ecafd01975e77b61e26d930eb88953ee2dbfb Mon Sep 17 00:00:00 2001
From: arJ-V
Date: Sat, 13 Dec 2025 02:01:06 -0500
Subject: [PATCH] Add cache usage property to ChatSampler and Sampler

- Add cache_usage property to ChatSampler that returns a (used, total) tuple
- Add get_cache_usage() static method to Sampler for stateless usage
- Provides an easy way to monitor cache usage during multi-turn conversations
- Returns None if no sampling has been performed yet
- Includes example usage in docstrings
---
 gemma/gm/text/_chat_sampler.py | 26 +++++++++++++++++++++++++-
 gemma/gm/text/_sampler.py      | 30 ++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/gemma/gm/text/_chat_sampler.py b/gemma/gm/text/_chat_sampler.py
index 672c2b93..72851c13 100644
--- a/gemma/gm/text/_chat_sampler.py
+++ b/gemma/gm/text/_chat_sampler.py
@@ -74,6 +74,8 @@ class ChatSampler:
       but exposed for power users to access the logits, cache, ... or
       initialize the sampler.
     turns: The current conversation.
+    cache_usage: Property that returns the current cache usage as (used, total)
+      tokens. Returns `None` if no sampling has been performed yet.
   """
 
   # TODO(epot): Custom repr to avoid displaying the full weights.
@@ -88,7 +90,6 @@ class ChatSampler:
   forbidden_tokens: Sequence[str | int] | None = None
   stop_tokens: Sequence[str | int] | None = None
   # TODO(epot): Support and test rolling cache.
-  # TODO(epot): Add a property to show how much of the cache is used.
   cache_length: int | None = 4096
   max_out_length: int = 2048
 
@@ -127,6 +128,29 @@ def sampler(self) -> _sampler.Sampler:
         max_out_length=self.max_out_length,
     )
 
+  @property
+  def cache_usage(self) -> tuple[int, int] | None:
+    """Returns the current cache usage as (used, total).
+
+    Returns:
+      A tuple of (used_cache_length, total_cache_length) if sampling has been
+      performed, otherwise `None`. The `used_cache_length` includes the prompt,
+      previous turns, and generated tokens so far.
+
+    Example:
+      ```python
+      sampler = ChatSampler(model=model, params=params)
+      sampler.chat('Hello')
+      used, total = sampler.cache_usage
+      print(f'Cache: {used}/{total} tokens ({100*used/total:.1f}%)')
+      ```
+    """
+    if self.last_state is None or self.cache_length is None:
+      return None
+    used = int(self.last_state.used_cache_length)
+    total = self.cache_length
+    return (used, total)
+
   def chat(
       self,
       prompt: str,
diff --git a/gemma/gm/text/_sampler.py b/gemma/gm/text/_sampler.py
index 98e0f9d3..fe083441 100644
--- a/gemma/gm/text/_sampler.py
+++ b/gemma/gm/text/_sampler.py
@@ -128,6 +128,36 @@ class Sampler:
   max_out_length: int = 2048
   pad_length: None | int | tuple[int, ...] = (256, 512, 1024)
 
+  @staticmethod
+  def get_cache_usage(
+      state: _sampler_loop.SamplingState | None,
+      cache_length: int,
+  ) -> tuple[int, int] | None:
+    """Returns the cache usage from a sampling state.
+
+    Args:
+      state: The sampling state from a previous `sample()` call with
+        `return_state=True`. If `None`, returns `None`.
+      cache_length: The total cache length configured for the sampler.
+
+    Returns:
+      A tuple of (used_cache_length, total_cache_length) if state is provided,
+      otherwise `None`. The `used_cache_length` includes the prompt and
+      generated tokens.
+
+    Example:
+      ```python
+      sampler = Sampler(model=model, params=params)
+      output = sampler.sample('Hello', return_state=True)
+      used, total = Sampler.get_cache_usage(output.state, sampler.cache_length)
+      print(f'Cache: {used}/{total} tokens ({100*used/total:.1f}%)')
+      ```
+    """
+    if state is None:
+      return None
+    used = int(state.used_cache_length)
+    return (used, cache_length)
+
   def __post_init__(self):
     # If not provided, initialize the tokenizer.
     if self.tokenizer is None:
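Reviewer note (not part of the patch): a minimal sketch of how the new `ChatSampler.cache_usage` property could drive an overflow warning across a multi-turn conversation. It assumes `model` and `params` are already loaded, as in the docstring examples; the prompts, the `cache_length` value, and the 90% threshold are illustrative only.

```python
from gemma import gm

# `model` and `params` are assumed to already be loaded (as in the
# docstring examples above); how to load them is out of scope here.
sampler = gm.text.ChatSampler(model=model, params=params, cache_length=4096)

for prompt in ['Hello!', 'Tell me more.', 'Summarize our chat so far.']:
  reply = sampler.chat(prompt)  # Each turn appends to the shared cache.
  usage = sampler.cache_usage   # `None` only before the first `chat()` call.
  if usage is not None:
    used, total = usage
    # Arbitrary 90% threshold: warn before the fixed-size cache overflows.
    if used / total > 0.9:
      print(f'Warning: cache nearly full ({used}/{total} tokens).')
```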
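A matching sketch for the stateless path via `Sampler.get_cache_usage()`, relying only on the `return_state=True` / `output.state` flow documented in the docstring above; `model` and `params` are again assumed to be in scope, and the prompt and explicit `cache_length` are illustrative.

```python
from gemma import gm

sampler = gm.text.Sampler(model=model, params=params, cache_length=1024)

# `return_state=True` returns the final sampling state alongside the text,
# so cache usage can be inspected without keeping a ChatSampler around.
out = sampler.sample('Write a haiku about autumn.', return_state=True)
usage = gm.text.Sampler.get_cache_usage(out.state, sampler.cache_length)
if usage is not None:
  used, total = usage
  print(f'Cache: {used}/{total} tokens ({100 * used / total:.1f}%)')
```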