26 changes: 25 additions & 1 deletion gemma/gm/text/_chat_sampler.py
@@ -74,6 +74,8 @@ class ChatSampler:
but exposed for power users to access the logits, cache, ... or initialize
the sampler.
turns: The current conversation.
cache_usage: Property that returns the current cache usage as a
(used, total) token count. Returns `None` if no sampling has been
performed yet or if `cache_length` is not set.
"""
# TODO(epot): Custom repr to avoid displaying the full weights.

@@ -88,7 +90,6 @@ class ChatSampler:
forbidden_tokens: Sequence[str | int] | None = None
stop_tokens: Sequence[str | int] | None = None
# TODO(epot): Support and test rolling cache.
# TODO(epot): Add a property to show how much of the cache is used.
cache_length: int | None = 4096
max_out_length: int = 2048

@@ -127,6 +128,29 @@ def sampler(self) -> _sampler.Sampler:
max_out_length=self.max_out_length,
)

@property
def cache_usage(self) -> tuple[int, int] | None:
"""Returns the current cache usage as (used, total).

Returns:
A tuple of (used_cache_length, total_cache_length) if sampling has been
performed and `cache_length` is set, otherwise `None`. The
`used_cache_length` includes the prompt, previous turns, and generated
tokens so far.

Example:
```python
sampler = ChatSampler(model=model, params=params)
sampler.chat('Hello')
used, total = sampler.cache_usage
print(f'Cache: {used}/{total} tokens ({100*used/total:.1f}%)')
```
"""
if self.last_state is None or self.cache_length is None:
return None
used = int(self.last_state.used_cache_length)
total = self.cache_length
return (used, total)

def chat(
self,
prompt: str,
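A minimal usage sketch for the new `cache_usage` property, following the constructor and `chat()` calls shown in the docstring example above. The `model` and `params` objects, the 90% threshold, and the reset-by-reconstruction strategy are illustrative assumptions, not part of this change:

```python
# Sketch: monitor cache usage across chat turns and start a fresh
# conversation before the KV cache fills up. The 0.9 threshold and the
# reset-by-reconstruction strategy are illustrative assumptions.
sampler = ChatSampler(model=model, params=params)  # cache_length defaults to 4096

for prompt in ['Hello', 'Tell me more']:
    reply = sampler.chat(prompt)
    usage = sampler.cache_usage  # None until the first chat() call
    if usage is not None:
        used, total = usage
        print(f'Cache: {used}/{total} tokens ({100 * used / total:.1f}%)')
        if used / total > 0.9:
            # Hypothetical recovery: drop the conversation so the next
            # turn starts with an empty cache.
            sampler = ChatSampler(model=model, params=params)
```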
30 changes: 30 additions & 0 deletions gemma/gm/text/_sampler.py
@@ -128,6 +128,36 @@ class Sampler:
max_out_length: int = 2048
pad_length: None | int | tuple[int, ...] = (256, 512, 1024)

@staticmethod
def get_cache_usage(
state: _sampler_loop.SamplingState | None,
cache_length: int,
) -> tuple[int, int] | None:
"""Returns the cache usage from a sampling state.

Args:
state: The sampling state from a previous `sample()` call with
`return_state=True`. May be `None`, in which case `None` is returned.
cache_length: The total cache length configured for the sampler.

Returns:
A tuple of (used_cache_length, total_cache_length) if state is provided,
otherwise `None`. The `used_cache_length` includes the prompt and
generated tokens.

Example:
```python
sampler = Sampler(model=model, params=params)
output = sampler.sample('Hello', return_state=True)
used, total = Sampler.get_cache_usage(output.state, sampler.cache_length)
print(f'Cache: {used}/{total} tokens ({100*used/total:.1f}%)')
```
"""
if state is None:
return None
used = int(state.used_cache_length)
return (used, cache_length)

def __post_init__(self):
# If not provided, initialize the tokenizer.
if self.tokenizer is None:
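A small sketch building on the `get_cache_usage` docstring example above. The `sample(..., return_state=True)` call and `output.state` attribute follow that docstring; the `format_cache_usage` helper itself is hypothetical, added only for illustration:

```python
# Sketch: a hypothetical helper that formats the (used, total) tuple
# returned by Sampler.get_cache_usage for logging. The calling
# convention matches the docstring example above.
def format_cache_usage(usage: tuple[int, int] | None) -> str:
    if usage is None:
        return 'cache usage unavailable (no sampling state)'
    used, total = usage
    return f'{used}/{total} tokens ({100 * used / total:.1f}% full)'

sampler = Sampler(model=model, params=params)
output = sampler.sample('Hello', return_state=True)
print(format_cache_usage(
    Sampler.get_cache_usage(output.state, sampler.cache_length)
))
```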