Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
415 changes: 415 additions & 0 deletions .pylintrc

Large diffs are not rendered by default.

37 changes: 15 additions & 22 deletions Ironwood/src/benchmark_attention.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
"""A script to benchmark tokamax splash attention implementation.
"""A script to benchmark tokamax splash attention implementation."""

"""

import os

# pylint: disable=g-importing-member,g-bad-import-order
import dataclasses
from functools import partial
import logging
import os
from typing import Any, Callable, Dict, Tuple
import dataclasses

from benchmark_utils import timeit_from_trace, MetricsStatistics
from benchmark_utils import MetricsStatistics
from benchmark_utils import timeit_from_trace
import jax
import logging
from tokamax._src.ops.experimental.tpu.splash_attention import (
splash_attention_kernel as splash,
)
from tokamax._src.ops.experimental.tpu.splash_attention import (
splash_attention_mask as mask_lib,
)
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as splash
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as mask_lib
import tune_jax

tune_jax.tune_logger.setLevel(logging.ERROR)

# pylint: disable=g-importing-member,g-bad-import-order
os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_dvfs_p_state=7"

os.environ["LIBTPU_INIT_ARGS"] = (
"--xla_tpu_dvfs_p_state=7"
)

def generate_qkv_separate_dims(
batch_size: int,
Expand All @@ -41,7 +32,9 @@ def generate_qkv_separate_dims(
key = jax.random.PRNGKey(seed)
key_q, key_k, key_v = jax.random.split(key, 3)
q = jax.random.normal(key_q, (batch_size, q_heads, q_seq_len, qk_head_dim))
k = jax.random.normal(key_k, (batch_size, kv_heads, kv_seq_len, qk_head_dim))
k = jax.random.normal(
key_k, (batch_size, kv_heads, kv_seq_len, qk_head_dim)
)
v = jax.random.normal(key_v, (batch_size, kv_heads, kv_seq_len, v_head_dim))
return q, k, v

Expand Down Expand Up @@ -141,7 +134,7 @@ def tokamax_splash_attention_benchmark(
# Attention mask
mask = mask_lib.FullMask(_shape=(q_seq_len, kv_seq_len))
if causal:
# Pick offset for causal masks for a "representative" slice of the causal
# Pick offset for causal masks for a representative slice of the causal mask.
offset = v.shape[-2] - q.shape[-2]
mask = mask_lib.CausalMask(shape=(q_seq_len, kv_seq_len), offset=offset)

Expand Down Expand Up @@ -263,7 +256,7 @@ def attention_fn(
trace_dir=trace_dir,
event_name_str_list=[
f"{event_filter_regex}_no_residuals.1",
]
],
)
return {"time_ms_list": time_ms_list, "output": output}

Expand Down
Loading