Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions benchmarks/bench_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,39 @@ def bench_fn(fn, warmup, iters, repeats):
return mean, mn, mx


def run_case(seq_lens, H, D, warmup, iters, repeats):
def format_seq_lens(seq_lens):
    """Return a compact, human-readable rendering of *seq_lens*.

    Sequences with more than four entries are abbreviated when possible:

    * all entries equal             -> ``"[v] * n"``
    * all entries after the first
      equal (first one differs)     -> ``"[a] + [b] * (n - 1)"``

    Anything else (including every sequence of four or fewer entries) is
    rendered with ``str(seq_lens)``.
    """
    count = len(seq_lens)
    if count <= 4:
        # Short lists are cheap to print in full.
        return str(seq_lens)

    head, *tail = seq_lens
    tail_values = set(tail)
    if len(tail_values) != 1:
        # The tail is not uniform, so no compact form applies.
        return str(seq_lens)

    repeated = tail[0]
    if head in tail_values:
        # The first entry matches the uniform tail: the whole list is uniform.
        return f"[{head}] * {count}"
    return f"[{head}] + [{repeated}] * {count - 1}"


def run_case(seq_lens, H, D, warmup, iters, repeats, varlen_metadata_label,
use_varlen_metadata):
device = torch.device("cuda")
LOWER_BOUND = -5.0
scale_float = 1.0 / math.sqrt(D)

varlen = len(seq_lens) > 1
T_total = sum(seq_lens)
N = len(seq_lens)
cu_seqlens = None

if varlen:
cu_seqlens = torch.tensor(
[0] + list(torch.cumsum(torch.tensor(seq_lens), dim=0).tolist()),
dtype=torch.long, device=device,
)
print(f"varlen shape=[{T_total},{H},{D}] seq_lens={seq_lens} warmup={warmup} iters={iters} repeats={repeats}")
extra = {"cu_seqlens": cu_seqlens}
print(
f"varlen shape=[{T_total},{H},{D}] seq_lens={format_seq_lens(seq_lens)} "
f"use_varlen_metadata={varlen_metadata_label} warmup={warmup} "
f"iters={iters} repeats={repeats}"
)
else:
print(f"shape=[{T_total},{H},{D}] warmup={warmup} iters={iters} repeats={repeats}")
extra = {}

varlen_kwargs = {"cu_seqlens": cu_seqlens} if varlen else {}

q = F.normalize(torch.randn((1, T_total, H, D), dtype=torch.float32, device=device), p=2, dim=-1).to(torch.bfloat16)
k = F.normalize(torch.randn((1, T_total, H, D), dtype=torch.float32, device=device), p=2, dim=-1).to(torch.bfloat16)
Expand All @@ -67,15 +81,19 @@ def run_case(seq_lens, H, D, warmup, iters, repeats):
def run_flash_kda():
flash_kda.fwd(q, k, v, g, beta, scale, out,
A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND,
initial_state=initial_state, final_state=final_state, **extra)
initial_state=initial_state, final_state=final_state,
cu_seqlens=cu_seqlens,
use_varlen_metadata=use_varlen_metadata)

mean, mn, mx = bench_fn(run_flash_kda, warmup, iters, repeats)
print(f" flash_kda (bf16 state) : mean={mean:.4f} ms, min={mn:.4f} ms, max={mx:.4f} ms")

# --- flash_kda: no state ---
def run_flash_kda_no_state():
flash_kda.fwd(q, k, v, g, beta, scale, out,
A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, **extra)
A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND,
cu_seqlens=cu_seqlens,
use_varlen_metadata=use_varlen_metadata)

mean, mn, mx = bench_fn(run_flash_kda_no_state, warmup, iters, repeats)
print(f" flash_kda (no state) : mean={mean:.4f} ms, min={mn:.4f} ms, max={mx:.4f} ms")
Expand All @@ -87,7 +105,9 @@ def run_flash_kda_no_state():
def run_flash_kda_fp32():
flash_kda.fwd(q, k, v, g, beta, scale, out,
A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND,
initial_state=initial_state_fp32, final_state=final_state_fp32, **extra)
initial_state=initial_state_fp32, final_state=final_state_fp32,
cu_seqlens=cu_seqlens,
use_varlen_metadata=use_varlen_metadata)

mean, mn, mx = bench_fn(run_flash_kda_fp32, warmup, iters, repeats)
print(f" flash_kda (fp32 state) : mean={mean:.4f} ms, min={mn:.4f} ms, max={mx:.4f} ms")
Expand All @@ -107,7 +127,7 @@ def run_chunk_kda():
A_log=A_log, dt_bias=dt_bias,
lower_bound=LOWER_BOUND,
transpose_state_layout=True,
**extra,
**varlen_kwargs,
)

mean, mn, mx = bench_fn(run_chunk_kda, warmup, iters, repeats)
Expand All @@ -125,7 +145,7 @@ def run_chunk_gated_delta_rule():
output_final_state=True,
use_qk_l2norm_in_kernel=True,
transpose_state_layout=True,
**extra,
**varlen_kwargs,
)

mean, mn, mx = bench_fn(run_chunk_gated_delta_rule, warmup, iters, repeats)
Expand All @@ -139,8 +159,20 @@ def run_chunk_gated_delta_rule():
# Variable-length benchmark cases. Every case packs 8192 tokens in total;
# only the way they are split into sequences differs.
VARLEN_CASES = [
    # Irregular mix of sequence lengths.
    [1300, 547, 2048, 963, 271, 3063],
    # Uniform splits, from a few long sequences to many short ones.
    *(
        [seg_len] * count
        for seg_len, count in (
            (1024, 8),
            (512, 16),
            (256, 32),
            (64, 128),
            (32, 256),
            (16, 512),
        )
    ),
    # One long sequence followed by many tiny ones.
    [4096] + [8] * 512,
]

# Maps the ``--use-varlen-metadata`` CLI choice to the value passed to
# ``flash_kda.fwd(use_varlen_metadata=...)``. ``None`` ("default") leaves the
# decision to the extension instead of forcing the metadata path on or off.
VARLEN_METADATA_OPTIONS = dict(default=None, on=True, off=False)


def main():
import argparse
Expand All @@ -151,6 +183,7 @@ def main():
p.add_argument("--mode", choices=["fixed", "varlen", "all"], default="all")
p.add_argument("--H", type=int, default=96)
p.add_argument("--D", type=int, default=128)
p.add_argument("--use-varlen-metadata", choices=["default", "on", "off"], default="default")
args = p.parse_args()

cases = []
Expand All @@ -160,7 +193,9 @@ def main():
cases.extend(VARLEN_CASES)

for seq_lens in cases:
run_case(seq_lens, args.H, args.D, args.warmup, args.iters, args.repeats)
run_case(seq_lens, args.H, args.D, args.warmup, args.iters, args.repeats,
args.use_varlen_metadata,
VARLEN_METADATA_OPTIONS[args.use_varlen_metadata])


if __name__ == "__main__":
Expand Down
85 changes: 64 additions & 21 deletions benchmarks/generate_benchmark_md.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
Run ``bench_fwd.py`` twice (default ``H`` and ``--H 64``), parse stdout, and write
a benchmark markdown report.
Run ``bench_fwd.py`` for each requested varlen metadata mode at default ``H`` and
``--H 64``, parse stdout, and write a benchmark markdown report.

Reports mean latency for ``flash_kda (fp32 state)`` and ``fla_chunk_kda`` (FLA
``chunk_kda``), plus speedup ``chunk_mean / flash_mean``. Generated date is UTC,
Expand Down Expand Up @@ -38,7 +38,8 @@
r"^shape=\[(\d+),(\d+),(\d+)\] warmup=(\d+) iters=(\d+) repeats=(\d+)\s*$"
)
RE_HEADER_VARLEN = re.compile(
r"^varlen shape=\[(\d+),(\d+),(\d+)\] seq_lens=(\[[^\]]+\]) "
r"^varlen shape=\[(\d+),(\d+),(\d+)\] seq_lens=(.+?) "
r"(?:use_varlen_metadata=(\w+) )?"
r"warmup=(\d+) iters=(\d+) repeats=(\d+)\s*$"
)
RE_RESULT = re.compile(
Expand Down Expand Up @@ -102,7 +103,7 @@ def new_case_base(
if m:
if current is not None:
cases.append(current)
t, h, d, seq_lens, w, it, rep = m.groups()
t, h, d, seq_lens, varlen_metadata, w, it, rep = m.groups()
current = new_case_base(
"varlen",
T=int(t),
Expand All @@ -113,6 +114,7 @@ def new_case_base(
repeats=int(rep),
seq_lens=seq_lens,
)
current["varlen_metadata"] = varlen_metadata or "default"
continue

m = RE_HEADER_FIXED.match(line)
Expand Down Expand Up @@ -150,6 +152,12 @@ def new_case_base(

def _fmt_seq_lens(seq_lens_str: str) -> str:
"""Uniform segment lengths become ``1024 x 8``; mixed lists keep the bracket form."""
m = re.fullmatch(r"\[(\d+)\]\s*\+\s*\[(\d+)\]\s*\*\s*(\d+)", seq_lens_str)
if m:
return f"{m.group(1)} + {m.group(2)} x {m.group(3)}"
m = re.fullmatch(r"\[(\d+)\]\s*\*\s*(\d+)", seq_lens_str)
if m:
return f"{m.group(1)} x {m.group(2)}"
try:
xs = ast.literal_eval(seq_lens_str)
except (ValueError, SyntaxError):
Expand Down Expand Up @@ -202,6 +210,24 @@ def _argv_with_h(argv: list[str], h: int) -> list[str]:
return out


def _argv_with_varlen_metadata(argv: list[str], mode: str) -> list[str]:
"""Drop any ``--use-varlen-metadata`` from *argv*, then append *mode*."""
out: list[str] = []
i = 0
while i < len(argv):
a = argv[i]
if a == "--use-varlen-metadata" and i + 1 < len(argv):
i += 2
continue
if a.startswith("--use-varlen-metadata="):
i += 1
continue
out.append(a)
i += 1
out.extend(["--use-varlen-metadata", mode])
return out


def _complete_cases(raw: list[dict]) -> list[dict]:
return [
c
Expand Down Expand Up @@ -236,13 +262,13 @@ def _render_table_block(cases: list[dict]) -> list[str]:


def render_markdown(
sections: list[list[dict]],
sections: list[tuple[str, list[dict]]],
generated_at: str,
generator_cmd: str,
device_label: str,
) -> str:
"""
*sections*: one ``cases`` list per table (default ``H``, then ``H=64``).
*sections*: ``(label, cases)`` pairs, one per table.
*generator_cmd*: command that reproduces this report.
*device_label*: device/platform label printed in the report title.
"""
Expand All @@ -264,7 +290,7 @@ def render_markdown(
lines.append(f"- Command: `{generator_cmd}`")
lines.append("")

first_cases = next((c for c in sections if c), None)
first_cases = next((cases for _label, cases in sections if cases), None)
c0 = first_cases[0] if first_cases else None
if c0 is not None:
lines.append(
Expand All @@ -276,11 +302,14 @@ def render_markdown(
lines.append(FLA_CHUNK_GDN_OPTIONS_MD)
lines.append("")

for cases in sections:
for label, cases in sections:
if not cases:
continue
c0 = cases[0]
lines.append(f"### `T={c0['T']}`, `H={c0['H']}`, `D={c0['D']}`")
title = f"`T={c0['T']}`, `H={c0['H']}`, `D={c0['D']}`"
if label:
title += f", `{label}`"
lines.append(f"### {title}")
lines.append("")
lines.extend(_render_table_block(cases))

Expand All @@ -303,30 +332,44 @@ def main() -> None:
default=DEFAULT_DEVICE_LABEL,
help=f"Device/platform label for the report title (default: {DEFAULT_DEVICE_LABEL!r})",
)
p.add_argument(
"--varlen-metadata-modes",
default="default,off,on",
help=(
"Comma-separated --use-varlen-metadata modes to benchmark "
"(default: default,off,on)."
),
)
args, bench_extra = p.parse_known_args()
metadata_modes = [m.strip() for m in args.varlen_metadata_modes.split(",") if m.strip()]
invalid_modes = sorted(set(metadata_modes) - {"default", "on", "off"})
if invalid_modes:
p.error(f"invalid --varlen-metadata-modes value(s): {', '.join(invalid_modes)}")

def _fmt_generator_cmd(extra: list[str]) -> str:
    """Return the shell command line that reproduces this report.

    Only options whose value differs from the default are included, followed
    by the pass-through benchmark arguments in *extra*. Each token is
    shell-quoted, so values containing spaces or other shell metacharacters
    (for example a multi-word device label) still give a command that can be
    pasted back into a shell unchanged. Tokens made only of shell-safe
    characters are emitted as they are.
    """
    import shlex

    parts = ["python", "benchmarks/generate_benchmark_md.py"]
    if args.output != DEFAULT_OUT:
        parts += ["-o", str(args.output)]
    if args.device_label != DEFAULT_DEVICE_LABEL:
        parts += ["--device-label", str(args.device_label)]
    if args.varlen_metadata_modes != "default,off,on":
        parts += ["--varlen-metadata-modes", str(args.varlen_metadata_modes)]
    parts.extend(extra)
    return shlex.join(parts)

argv_default = list(bench_extra)
argv_h64 = _argv_with_h(bench_extra, 64)

stdout_a = run_bench(argv_default)
stdout_b = run_bench(argv_h64)
cases_a = _complete_cases(parse_stdout(stdout_a))
cases_b = _complete_cases(parse_stdout(stdout_b))

sections: list[list[dict]] = [cases_a, cases_b]

if not cases_a or not cases_b:
sections: list[tuple[str, list[dict]]] = []
for mode in metadata_modes:
argv_mode = _argv_with_varlen_metadata(bench_extra, mode)
for h in (None, 64):
argv = list(argv_mode) if h is None else _argv_with_h(argv_mode, h)
stdout = run_bench(argv)
cases = _complete_cases(parse_stdout(stdout))
label = f"use_varlen_metadata={mode}"
sections.append((label, cases))

if any(not cases for _label, cases in sections):
sys.stderr.write(
"Warning: missing complete benchmark rows for one or both runs "
"Warning: missing complete benchmark rows for one or more runs "
"(need fp32 state, fla_chunk_kda, and fla_chunk_gated_delta_rule "
"for each).\n"
)
Expand Down
21 changes: 19 additions & 2 deletions csrc/flash_kda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ void fwd(
double lower_bound,
std::optional<torch::Tensor> initial_state = std::nullopt,
std::optional<torch::Tensor> final_state = std::nullopt,
std::optional<torch::Tensor> cu_seqlens = std::nullopt
std::optional<torch::Tensor> cu_seqlens = std::nullopt,
std::optional<bool> use_varlen_metadata = std::nullopt
) {
TORCH_CHECK(q.is_cuda() && k.is_cuda() && v.is_cuda() && g.is_cuda() && beta.is_cuda() && out.is_cuda() && workspace.is_cuda(),
"all tensors must be on CUDA");
Expand Down Expand Up @@ -177,13 +178,28 @@ void fwd(
total_tiles = int(N_val * ((T_seq + CHUNK - 1) / CHUNK)); // exact for batched
}

torch::Tensor chunk_indices_t;
torch::Tensor chunk_offsets_t;
VarlenMetadata varlen_metadata;
bool build_varlen_metadata = is_varlen && use_varlen_metadata.value_or(
N_val >= kVarlenMetadataAutoMinSequences);

if (build_varlen_metadata) {
auto meta_options = q.options().dtype(torch::kInt32);
chunk_indices_t = torch::empty({total_tiles, 2}, meta_options);
chunk_offsets_t = torch::empty({N_val + 1}, meta_options);
varlen_metadata.chunk_indices = reinterpret_cast<int2*>(chunk_indices_t.data_ptr<int32_t>());
varlen_metadata.chunk_offsets = chunk_offsets_t.data_ptr<int32_t>();
}

// Dispatch based on state configuration and varlen
#define LAUNCH(HI, HO, FP32, VL) \
launch_fwd<128, HI, HO, FP32, VL>( \
q_ptr, k_ptr, v_ptr, g_ptr, beta_t_ptr, \
initial_state_raw, scale_f, final_state_raw, out_ptr, \
workspace_ptr, total_tiles, \
int(T_total), int(H), int(N_val), cu_seqlens_dev, \
varlen_metadata, \
A_log_ptr, dt_bias_ptr, gate_scale, stream)

#define DISPATCH_STATE(VL) \
Expand Down Expand Up @@ -220,7 +236,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("workspace"),
py::arg("A_log"), py::arg("dt_bias"), py::arg("lower_bound"),
py::arg("initial_state") = py::none(), py::arg("final_state") = py::none(),
py::arg("cu_seqlens") = py::none());
py::arg("cu_seqlens") = py::none(),
py::arg("use_varlen_metadata") = py::none());
m.def("get_workspace_size",
static_cast<int64_t(*)(int64_t, int64_t, int64_t)>(&get_workspace_size),
"Get workspace size in bytes",
Expand Down
14 changes: 14 additions & 0 deletions csrc/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@

#include <cutlass/bfloat16.h>

// Optional pair of buffers passed by value to launch_fwd for varlen inputs.
// Both pointers default to null, which means "no metadata supplied".
struct VarlenMetadata {
    // In fwd() this points at an int32 tensor of shape {total_tiles, 2},
    // viewed as int2 elements.
    int2* chunk_indices = nullptr;
    // In fwd() this points at an int32 tensor of shape {N + 1}.
    int32_t* chunk_offsets = nullptr;

    // True only when both buffers have been provided.
    __host__ __device__ bool enabled() const {
        const bool has_indices = chunk_indices != nullptr;
        const bool has_offsets = chunk_offsets != nullptr;
        return has_indices && has_offsets;
    }
};

// NOTE(review): the warp-size and thread-count constants presumably configure
// the kernel that fills VarlenMetadata; their use is not visible here — confirm.
constexpr int kVarlenMetadataWarpSize = 32;
constexpr int kVarlenMetadataThreads = 256;
// Auto-enable threshold: when the caller leaves use_varlen_metadata unset,
// fwd() builds the metadata only for varlen inputs with at least this many
// sequences.
constexpr int kVarlenMetadataAutoMinSequences = 32;

template <int D, bool HasStateIn = true, bool HasStateOut = true, bool StateFP32 = false, bool IsVarlen = true>
void launch_fwd(
cutlass::bfloat16_t const* q_ptr,
Expand All @@ -20,6 +33,7 @@ void launch_fwd(
int H,
int N,
int64_t const* cu_seqlens_ptr,
VarlenMetadata varlen_metadata,
float const* A_log_ptr,
float const* dt_bias_ptr,
float gate_scale,
Expand Down
Loading