diff --git a/kv_cache_benchmark/kv_cache/cache.py b/kv_cache_benchmark/kv_cache/cache.py index bf55a4ab..10eaa615 100755 --- a/kv_cache_benchmark/kv_cache/cache.py +++ b/kv_cache_benchmark/kv_cache/cache.py @@ -826,7 +826,8 @@ def access_cache(self, key: str, phase: InferencePhase = InferencePhase.DECODE, self.stats['storage_read_host_latencies'].append(timing.host) if self.model_config.kv_cache_size_per_token > 0: - num_tokens = entry_size / self.model_config.kv_cache_size_per_token + sharded_bytes_per_token = kv_cache_size_per_token / max(1, self.tensor_parallel) + num_tokens = entry_size / sharded_bytes_per_token self.stats['storage_tokens_processed'] += num_tokens return location, timing.total