Add flash attention primitives and examples#780

Draft
maleadt wants to merge 20 commits into main from tb/flashattention

@maleadt maleadt commented May 24, 2026

Somebody asked if Metal.jl supports the necessary primitives to implement flash attention, so I had Claude take a stab at it to see what's involved. Mostly vibe-coded, needs to be cleaned up and tested obviously.


This PR introduces a flash attention example, implemented in three ways:

  • using basic MPS ops (matmul + softmax via broadcasts)
  • using MPSGraph's fused scaledDotProductAttention op
  • using a custom kernel built on simdgroup_matrix intrinsics

That needed a typed Julia wrapper around the existing simdgroup_matrix intrinsics (MtlSimdgroupMatrix{T,8,8}), plus a binding for MPSGraph.scaledDotProductAttentionWithQueryTensor (macOS 14+).
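
For readers who want the arithmetic behind the three variants, here is a CPU-side sketch in plain Python (not the Metal.jl code in this PR). It compares attention with the full score matrix materialized against a tiled version that walks K/V in blocks while keeping a running max and running sum, which is the usual flash attention formulation. The function names and the `block` parameter are made up for illustration, and the sketch handles a single head with no masking.

```python
import math

def naive_attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V, with every score row fully materialized."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        denom = sum(p)
        out.append([sum(pj * vj[c] for pj, vj in zip(p, v)) / denom
                    for c in range(len(v[0]))])
    return out

def tiled_attention(q, k, v, block=2):
    """Same result, but K/V are visited block by block (online softmax)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for qi in q:
        m, l, acc = -math.inf, 0.0, [0.0] * dv   # running max, sum, output
        for start in range(0, len(k), block):
            kb, vb = k[start:start + block], v[start:start + block]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in kb]
            m_new = max(m, max(scores))
            corr = math.exp(m - m_new)           # rescales earlier partials
            p = [math.exp(s - m_new) for s in scores]
            l = l * corr + sum(p)
            acc = [a * corr + sum(pj * vj[c] for pj, vj in zip(p, vb))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out
```

Both functions should agree to floating-point rounding for any block size; the tiled one never holds more than `block` scores per query row at a time.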

A fourth way would be to use the Metal 4 tensor ops (cooperative tensors + tensor_ops::matmul2d, with the softmax epilogue fused into the matmul via postfix fusion), but those require macOS 26, which I don't have yet. So to be continued...

codecov Bot commented May 24, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 80.64%. Comparing base (706b87f) to head (696d968).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #780      +/-   ##
==========================================
- Coverage   80.77%   80.64%   -0.14%     
==========================================
  Files          61       61              
  Lines        2866     2872       +6     
==========================================
+ Hits         2315     2316       +1     
- Misses        551      556       +5     

maleadt and others added 13 commits May 26, 2026 15:57

Wraps the tensor_inline form of Metal 4 tensor_ops. Kernel args stay
buffer-shaped MtlDeviceArrays — no host-side MTLTensor or MTL4 command
encoder wrapping is needed. The descriptors and matmul2d run helper
lower to the externally-defined __tensorops_impl_* family in the
MetalPerformancePrimitives runtime; AIR construction uses the
air.*_private_tensor intrinsics.

Per-thread descriptor storage is a Ref{NTuple{N, UInt8}} that
allocopt promotes to a stack alloca (every constructor is @inline'd
into the kernel, and the gc-managed object only escapes via
pointer_from_objref). Reference IR for the three kernel shapes
(handle / cooperative / inline) is in bin/{simple,coop,inline}_matmul.*.

Fourth path alongside the MPS / MPSGraph / simdgroup_matrix
implementations. Builds tensor_inline views over the MtlDeviceArray
inputs and dispatches the two matmuls through tensor_ops::matmul2d.
The kernel stays buffer-shaped so the existing kernel ABI is unchanged.

The forward pass is split across two dispatches (QK+softmax, then PV)
to work around a Metal back-end crash on two
__tensorops_impl_matmul2d_op_run_* calls in a single kernel. The
scores tile is therefore materialized in device memory rather than
fused into a cooperative_tensor.

Limited to D == N == 64 and a single (head, batch) block; requires
macOS 26+.

Adds bf16 (Core.BFloat16), i8 (Int8), ui8 (UInt8), and i32 (Int32) to
the suffix dispatch table, covering the common dense-precision and
quantized matmul combinations from MPPTensorOpsMatMul2d.h. Verified
against a CPU reference for {f16, f16, f16}, {f16, f16, f32},
{f32, f32, f32}, {bf16, bf16, bf16}, {bf16, bf16, f32}, {f16, i8, f16},
and {i8, i8, i32}.

The 4-bit formats (i4/ui4) need a custom packed type and are skipped.

MtlInlineTensor gains an ASpace type parameter; constructors dispatch
between air.init_strided_private_tensor.i32.global (device data) and
the .local flavor (threadgroup data). tensor_ops_matmul2d! picks the
dv/tg prefix per operand to name the
__tensorops_impl_matmul2d_op_run_{aspace}_{type}_..._* symbol. That
lets us stage tiles in threadgroup memory between matmuls (e.g.
between QK and PV in attention) — verified against a CPU reference for
half × half → float with the left operand staged to threadgroup.

Apple's static_slice<> only works on tensor_handle, not tensor_inline,
and building an inline tensor with static extents emits the same AIR
as one with dynamic extents (same air.init_strided_private_tensor +
runtime extents tuple). Encoding static extents in MtlInlineTensor's
type would only shave a few bytes off the extents alloca without
enabling any optimization, so we keep extents dynamic.

Exports the Matmul2DMode constants and adds a docstring example
showing the K-loop pattern: zero C, then loop with
mode = matmul2d_multiply_accumulate, slicing the K dimension. Keeps
the loop trip count dynamic to avoid full unrolling into multiple
tensor_ops_matmul2d! call sites — that hits Apple's back-end crash
(see ISSUE-tensor-ops.md).
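
The K-loop pattern described above can be sketched on the CPU in plain Python. This is not the docstring example from the commit: `matmul2d`, `MULTIPLY`, and `MULTIPLY_ACCUMULATE` below are hypothetical stand-ins for the real call and the exported Matmul2DMode constants, and the matrices are lists of rows.

```python
MULTIPLY = "multiply"                        # hypothetical stand-ins for the
MULTIPLY_ACCUMULATE = "multiply_accumulate"  # exported Matmul2DMode constants

def matmul2d(c, a, b, mode):
    """C = A*B (multiply) or C += A*B (multiply_accumulate), in place."""
    for i in range(len(c)):
        for j in range(len(c[0])):
            dot = sum(a[i][p] * b[p][j] for p in range(len(b)))
            c[i][j] = dot if mode == MULTIPLY else c[i][j] + dot

def matmul_k_loop(a, b, tile_k):
    """Zero C, then accumulate one K slice per loop iteration."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for k0 in range(0, k, tile_k):           # trip count depends on runtime k
        a_slice = [row[k0:k0 + tile_k] for row in a]
        b_slice = b[k0:k0 + tile_k]
        matmul2d(c, a_slice, b_slice, MULTIPLY_ACCUMULATE)
    return c
```

Note that the loop body contains a single `matmul2d` call site no matter how many K slices there are, which is the property the commit relies on.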

Apple-compiled MSL with two matmul2d calls builds a working pipeline
state, so the crash is triggered specifically by our IR pattern: the
matmul2d_descriptor ends up as a series of per-field stores rather
than the memcpy-from-constant-global pattern Apple's compiler emits.
Local reproducer + diff is in bugs/two_matmul_crash/ (gitignored).

C = A * B with natural Julia matrix-product semantics, dispatched as
(M/tile_m, N/tile_n) threadgroups with a K-loop inside via
multiply_accumulate. Each tile is one matmul2d call site, so the
two-matmul-per-kernel back-end crash never fires.

Tested against a CPU reference at (64,64,64), (64,128,64), (128,64,64),
(128,128,64), (64,64,128) and (256,192,128) shapes — max relative
error ~3e-7 (Float16 inputs, Float32 accumulator).

The wrapper hides the matmul2d operand-swap from the user: matmul2d's
output buffer is laid out as Julia's transpose of (M, N), so we put
Julia's B in matmul2d's left slot and Julia's A in the right slot.
Two swaps cancel and every operand uses packed strides.
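
One way to see why the two swaps cancel, sketched in plain Python under the assumption that matmul2d can be viewed as a row-major product: a column-major buffer of A, read row-major, is A^T, and B^T A^T = (A B)^T, whose row-major buffer is exactly the column-major buffer of A B. The helper names are illustrative only.

```python
def col_major(mat):
    """Flatten a list-of-rows matrix the way Julia stores it (columns contiguous)."""
    rows, cols = len(mat), len(mat[0])
    return [mat[i][j] for j in range(cols) for i in range(rows)]

def row_major_matmul(left, right, m, k, n):
    """(m x k) times (k x n) on flat row-major buffers; returns flat m x n."""
    return [sum(left[i * k + p] * right[p * n + j] for p in range(k))
            for i in range(m) for j in range(n)]

def julia_style_product(a_buf, b_buf, m, k, n):
    """A (m x k) times B (k x n), all buffers column-major, via swapped operands."""
    # Read row-major, b_buf is B^T (n x k) and a_buf is A^T (k x m).
    # Their product is (A B)^T as an n x m row-major buffer, which is
    # byte-for-byte the column-major buffer of A B.
    return row_major_matmul(b_buf, a_buf, n, k, m)
```

No operand needs an explicit transpose or a non-packed stride; only the slot each buffer goes into changes.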

Constructors now take an NTuple{R, Integer} of extents (and optionally
strides); view() takes NTuple{R} for origin and extents. Strides
default to packed (computed via the prefix product of extents). The
air.* intrinsics already take rank as an Int16, so the Julia wrappers
just forward through.
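
As a rough illustration of the packed-stride default, here is a plain Python sketch. It assumes the first dimension is the fastest-varying one, as in Julia arrays; the commit message does not spell out the ordering, and the function names are made up.

```python
def packed_strides(extents):
    """Strides of a densely packed tensor, first dimension fastest-varying."""
    strides, running = [], 1
    for extent in extents:
        strides.append(running)   # product of all earlier extents
        running *= extent
    return tuple(strides)

def linear_offset(index, strides):
    """Element offset of a zero-based multi-index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))
```

For example, extents `(4, 3, 2)` give strides `(1, 4, 12)`, so the zero-based index `(1, 2, 1)` lands at element offset 21.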

matmul2d itself is rank-2 only — but rank-3/4 tensors are useful for
slicing into matmul operands and for the future convolution2d op.

The per-thread descriptor buffer is bumped to 128 bytes to cover ranks
up to ~8 with i32 indices, since the dynamic-alloca pattern Apple
uses (`alloca i8, i64 %sz` keyed on get_descriptor_size_tensor) can't
be expressed via Julia's static typing.

Tested at rank 3 and rank 4: construct, slice, and read back extents
via air.get_extent_private_tensor; values round-trip exactly.

Mirrors Apple's MSL `matmul2d<desc, execution_simdgroups<N>>` template:
both the descriptor and the simdgroup count are encoded as type
parameters so the generated AIR carries them as compile-time constants.
The simdgroup count matters in particular: the AGX register allocator
runs out of stack registers ("LLVM ERROR: The shader is out of stack
registers space") when the simdgroup count is a runtime value and two
matmul calls live in the same kernel, which our old
`tensor_ops_matmul2d!` signature forced. Removing the runtime-threads
entry point closes the footgun since MSL has no equivalent.

The callable needs all three of `@device_function`, `@inline`, and
`@generated`: without `@device_function` the call site falls off the
GPUCompiler overlay table and downstream MtlInlineTensor/view calls
lose their stack-alloca lowering.

`_tensor_matmul_kernel!` now takes Val{TM}, Val{TN}, Val{TK}, Val{NSIMD};
both flashattention kernels likewise lift their tile shape into Val
parameters before dispatch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

…PV pass.

Now that two matmul callsites in one kernel compile cleanly (via the
TensorOpsMatmul2D wrapper), the split that worked around the back-end
crash is no longer needed. The scores and softmaxed P tiles move from
device MtlArrays to threadgroup memory, so there's no host-side scratch
allocation and no device-memory round-trip between the two matmuls.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

…loop.

Previously the wrapper iterated over `(b, h)` on the host and submitted
H*B separate kernel launches. Now a single dispatch with grid = (H, B)
covers them all — each threadgroup reads its own `(h, b)` from
`threadgroup_position_in_grid` and slices the 4-D buffers via pointer
arithmetic. Heads run in parallel where the hardware allows, with no
per-launch encoder overhead.
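
The per-threadgroup slicing can be sketched in plain Python. The layout is an assumption: a packed 4-D buffer of shape (D, N, H, B) with the first dimension fastest, which the commit message does not confirm. `block_offset` and `grid_offsets` are hypothetical names, with the nested loop standing in for a single dispatch with grid = (H, B).

```python
def block_offset(h, b, d, n, heads):
    """Element offset of block (h, b) in a packed (d, n, heads, batch) buffer."""
    stride_h = d * n              # one (d, n) matrix per head
    stride_b = stride_h * heads   # one stack of heads per batch entry
    return h * stride_h + b * stride_b

def grid_offsets(d, n, heads, batch):
    """Offsets every threadgroup of a (heads, batch) grid would compute."""
    return {(h, b): block_offset(h, b, d, n, heads)
            for b in range(batch) for h in range(heads)}
```

Under that layout each (h, b) pair maps to a distinct, non-overlapping run of `d * n` elements, so one launch can cover all of them without host-side slicing.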

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@maleadt maleadt force-pushed the tb/flashattention branch from 312e4ba to cf01652 Compare May 26, 2026 17:57