Add flash attention primitives and examples #780
Draft
maleadt wants to merge 20 commits into
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## main #780 +/- ##
==========================================
- Coverage 80.77% 80.64% -0.14%
==========================================
Files 61 61
Lines 2866 2872 +6
==========================================
+ Hits 2315 2316 +1
- Misses 551 556 +5
Wraps the tensor_inline form of Metal 4 tensor_ops. Kernel args stay
buffer-shaped MtlDeviceArrays — no host-side MTLTensor or MTL4 command
encoder wrapping is needed. The descriptors and matmul2d run helper
lower to the externally-defined __tensorops_impl_* family in the
MetalPerformancePrimitives runtime; AIR construction uses the
air.*_private_tensor intrinsics.
Per-thread descriptor storage is a Ref{NTuple{N, UInt8}} that
allocopt promotes to a stack alloca (every constructor is @inline'd
into the kernel, and the gc-managed object only escapes via
pointer_from_objref). Reference IR for the three kernel shapes
(handle / cooperative / inline) is in bin/{simple,coop,inline}_matmul.*.
Fourth path alongside the MPS / MPSGraph / simdgroup_matrix implementations. Builds tensor_inline views over the MtlDeviceArray inputs and dispatches the two matmuls through tensor_ops::matmul2d. The kernel stays buffer-shaped so the existing kernel ABI is unchanged.

The forward pass is split across two dispatches (QK+softmax, then PV) to work around a Metal back-end crash that occurs when a single kernel contains two __tensorops_impl_matmul2d_op_run_* calls. The scores tile is therefore materialized in device memory rather than fused into a cooperative_tensor.

Limited to D == N == 64 and a single (head, batch) block; requires macOS 26+.
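For reference, the two-pass split can be mirrored on the CPU. This is a minimal sketch, not the kernel from this PR: the function names `scores_softmax` and `apply_values`, the `(N, D)` row layout, and the `1/sqrt(D)` scale are assumptions made for illustration.

```julia
# CPU reference for one (head, batch) block, split into the same two passes
# as the GPU path. Names, layout, and scale factor are assumptions.

# Pass 1: P = softmax(scale * Q * K'), row-wise and numerically stabilized.
function scores_softmax(Q::Matrix{Float32}, K::Matrix{Float32})
    scale = 1f0 / sqrt(Float32(size(Q, 2)))   # 1/sqrt(D)
    S = scale .* (Q * K')                      # (N, N) scores tile
    S .-= maximum(S; dims=2)                   # subtract the row max before exp
    P = exp.(S)
    P ./= sum(P; dims=2)                       # normalize each row to sum to 1
    return P
end

# Pass 2: O = P * V.
apply_values(P::Matrix{Float32}, V::Matrix{Float32}) = P * V

N, D = 64, 64
Q, K, V = rand(Float32, N, D), rand(Float32, N, D), rand(Float32, N, D)
P = scores_softmax(Q, K)    # materialized between the two passes
O = apply_values(P, V)
```

The intermediate `P` plays the role of the scores tile that the split forces into device memory.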
Adds bf16 (Core.BFloat16), i8 (Int8), ui8 (UInt8), and i32 (Int32) to
the suffix dispatch table, covering the common dense-precision and
quantized matmul combinations from MPPTensorOpsMatMul2d.h. Verified
against a CPU reference for {f16, f16, f16}, {f16, f16, f32},
{f32, f32, f32}, {bf16, bf16, bf16}, {bf16, bf16, f32}, {f16, i8, f16},
and {i8, i8, i32}.
The 4-bit formats (i4/ui4) need a custom packed type and are skipped.
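A CPU reference for these type combinations could look roughly like the following. It is a sketch under assumptions: `reference_matmul` and `max_rel_error` are hypothetical helper names, not functions from this PR, and only two of the listed combinations are exercised.

```julia
# CPU reference: widen both operands to the accumulator type, then multiply.
reference_matmul(A, B, ::Type{Acc}) where {Acc} = Acc.(A) * Acc.(B)

# Largest elementwise error, relative to the largest reference magnitude.
# Intended for floating-point results only.
function max_rel_error(C, Cref)
    denom = max(maximum(abs, Cref), eps(eltype(Cref)))
    return maximum(abs.(C .- Cref)) / denom
end

# {f16, f16, f32}: half inputs, single-precision accumulator.
A = Float16.(rand(Float32, 64, 64))
B = Float16.(rand(Float32, 64, 64))
Cref = reference_matmul(A, B, Float32)

# {i8, i8, i32}: exact integer arithmetic, no overflow at K = 64.
Ai = rand(Int8, 64, 64)
Bi = rand(Int8, 64, 64)
Ci = reference_matmul(Ai, Bi, Int32)
```

A GPU result `C` would then be compared with `max_rel_error(C, Cref)` for the floating-point cases and with `==` for the integer case.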
MtlInlineTensor gains an ASpace type parameter; constructors dispatch
between air.init_strided_private_tensor.i32.global (device data) and
the .local flavor (threadgroup data). tensor_ops_matmul2d! picks the
dv/tg prefix per operand to name the
__tensorops_impl_matmul2d_op_run_{aspace}_{type}_..._* symbol. That
lets us stage tiles in threadgroup memory between matmuls (e.g.
between QK and PV in attention) — verified against a CPU reference for
half × half → float with the left operand staged to threadgroup.
Apple's static_slice<> only works on tensor_handle, not tensor_inline, and building an inline tensor with static extents emits the same AIR as one with dynamic extents (same air.init_strided_private_tensor + runtime extents tuple). Encoding static extents in MtlInlineTensor's type would only shave a few bytes off the extents alloca without enabling any optimization, so we keep extents dynamic.
Exports the Matmul2DMode constants and adds a docstring example showing the K-loop pattern: zero C, then loop with mode = matmul2d_multiply_accumulate, slicing the K dimension. Keeps the loop trip count dynamic to avoid full unrolling into multiple tensor_ops_matmul2d! call sites — that hits Apple's back-end crash (see ISSUE-tensor-ops.md).
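The K-loop pattern can be sketched on the CPU as below. This is not the docstring example from the commit: `kloop_matmul!` is a hypothetical name, and the in-place `.+=` stands in for one multiply-accumulate call on a K-slice.

```julia
# CPU sketch of the K-loop: zero C, then accumulate one K-slice at a time.
function kloop_matmul!(C::Matrix{Float32}, A::Matrix{Float16},
                       B::Matrix{Float16}, tile_k::Int)
    Kdim = size(A, 2)
    @assert Kdim % tile_k == 0 "K must be a multiple of tile_k in this sketch"
    fill!(C, 0f0)                       # zero the accumulator first
    for k0 in 1:tile_k:Kdim             # trip count depends on the runtime K
        ks = k0:(k0 + tile_k - 1)
        # Stand-in for a single multiply-accumulate call on this K-slice.
        C .+= Float32.(view(A, :, ks)) * Float32.(view(B, ks, :))
    end
    return C
end

A = Float16.(rand(Float32, 64, 128))
B = Float16.(rand(Float32, 128, 64))
C = kloop_matmul!(zeros(Float32, 64, 64), A, B, 32)
```

The loop body contains a single accumulate site regardless of how many K-slices there are, which is the property the dynamic trip count is meant to preserve.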
Apple-compiled MSL with two matmul2d calls builds a working pipeline state, so the crash is triggered specifically by our IR pattern: the matmul2d_descriptor ends up as a series of per-field stores rather than the memcpy-from-constant-global pattern Apple's compiler emits. Local reproducer + diff is in bugs/two_matmul_crash/ (gitignored).
C = A * B with natural Julia matrix-product semantics, dispatched as (M/tile_m, N/tile_n) threadgroups with a K-loop inside via multiply_accumulate. Each tile is one matmul2d call site, so the two-matmul-per-kernel back-end crash never fires. Tested against a CPU reference at (64,64,64), (64,128,64), (128,64,64), (128,128,64), (64,64,128) and (256,192,128) shapes — max relative error ~3e-7 (Float16 inputs, Float32 accumulator). The wrapper hides the matmul2d operand-swap from the user: matmul2d's output buffer is laid out as Julia's transpose of (M, N), so we put Julia's B in matmul2d's left slot and Julia's A in the right slot. Two swaps cancel and every operand uses packed strides.
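The operand swap rests on the identity (A * B)' == B' * A'. A small CPU sketch of why the two swaps cancel, assuming that a packed column-major Julia buffer read in the opposite dimension order is simply the transpose (modeled here with `permutedims`, which copies; the GPU path reinterprets the same memory):

```julia
M, K, N = 4, 3, 5
A = rand(Float32, M, K)
B = rand(Float32, K, N)

# What a consumer with the opposite dimension order sees in each buffer.
Bt = permutedims(B)        # (N, K): Julia's B, placed in the left slot
At = permutedims(A)        # (K, M): Julia's A, placed in the right slot

Ct = Bt * At               # (N, M): the product in the consumer's layout

# Reading that output buffer back in Julia's order gives (M, N) = A * B.
C = permutedims(Ct)
```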
Constructors now take an NTuple{R, Integer} of extents (and optionally
strides); view() takes NTuple{R} for origin and extents. Strides
default to packed (computed via the prefix product of extents). The
air.* intrinsics already take rank as an Int16, so the Julia wrappers
just forward through.
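
A minimal sketch of the packed-stride default, assuming the first dimension is the fastest-varying one (as in Julia's column-major arrays). `packed_strides` is a hypothetical name, not the helper from this PR.

```julia
# Packed strides as an exclusive prefix product of the extents:
# stride[i] = extents[1] * ... * extents[i-1], with stride[1] = 1.
function packed_strides(extents::NTuple{R, Integer}) where {R}
    return ntuple(i -> prod(extents[1:i-1]; init=1), R)
end

s2 = packed_strides((64, 64))        # rank 2
s3 = packed_strides((64, 32, 4))     # rank 3
```

For a packed Julia array these match `strides(array)`, which gives a quick sanity check.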
matmul2d itself is rank-2 only — but rank-3/4 tensors are useful for
slicing into matmul operands and for the future convolution2d op.
The per-thread descriptor buffer is bumped to 128 bytes to cover ranks
up to ~8 with i32 indices, since the dynamic-alloca pattern Apple
uses (`alloca i8, i64 %sz` keyed on get_descriptor_size_tensor) can't
be expressed via Julia's static typing.
Tested at rank 3 and rank 4: construct, slice, and read back extents
via air.get_extent_private_tensor — values round-trip exactly.
Mirrors Apple's MSL `matmul2d<desc, execution_simdgroups<N>>` template:
both the descriptor and the simdgroup count are encoded as type
parameters so the generated AIR carries them as compile-time constants.
The simdgroup count matters in particular: the AGX register allocator
runs out of stack registers ("LLVM ERROR: The shader is out of stack
registers space") when the simdgroup count is a runtime value and two
matmul calls live in the same kernel, which our old
`tensor_ops_matmul2d!` signature forced. Removing the runtime-threads
entry point closes that footgun, and MSL has no equivalent for it anyway.
The callable needs all three of `@device_function`, `@inline`, and
`@generated`: without `@device_function` the call site falls off the
GPUCompiler overlay table and downstream MtlInlineTensor/view calls
lose their stack-alloca lowering.
`_tensor_matmul_kernel!` now takes Val{TM}, Val{TN}, Val{TK}, Val{NSIMD};
both flashattention kernels likewise lift their tile shape into Val
parameters before dispatch.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
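Lifting tile shapes into `Val` parameters follows the usual Julia pattern, sketched here on the CPU. `tile_count` and `launch_grid` are hypothetical names and the 64×64 default tile is an assumption; the actual kernel signature is the one described in the commit message above.

```julia
# Tile shape as type parameters: TM and TN are compile-time constants
# inside the method body, so each tile shape gets its own specialization.
@inline function tile_count(::Val{TM}, ::Val{TN}, M::Int, N::Int) where {TM, TN}
    return (cld(M, TM), cld(N, TN))    # threadgroups along each axis
end

# Lift runtime values into the type domain once, at the dispatch boundary.
launch_grid(M, N; tile_m=64, tile_n=64) =
    tile_count(Val(tile_m), Val(tile_n), M, N)

grid = launch_grid(256, 192)
```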
…PV pass. Now that two matmul callsites in one kernel compile cleanly (via the TensorOpsMatmul2D wrapper), the split that worked around the back-end crash is no longer needed. The scores and softmaxed P tiles move from device MtlArrays to threadgroup memory, so there's no host-side scratch allocation and no device-memory round-trip between the two matmuls. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…loop. Previously the wrapper iterated over `(b, h)` on the host and submitted H*B separate kernel launches. Now a single dispatch with grid = (H, B) covers them all — each threadgroup reads its own `(h, b)` from `threadgroup_position_in_grid` and slices the 4-D buffers via pointer arithmetic. Heads run in parallel where the hardware can, no per-launch encoder overhead. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
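The per-threadgroup slicing amounts to an offset computation, sketched here on the CPU. The packed column-major `(D, N, H, B)` layout and the name `block_offset` are assumptions; on the GPU, `(h, b)` would come from `threadgroup_position_in_grid` instead of being set by hand.

```julia
# Linear element offset of the (h, b) block in a packed column-major
# (D, N, H, B) array. h and b are 1-based; the layout is an assumption.
block_offset(D, N, H, h, b) = ((b - 1) * H + (h - 1)) * D * N

D, N, H, B = 4, 3, 2, 5
X = reshape(collect(Float32, 1:D*N*H*B), D, N, H, B)

h, b = 2, 4
off = block_offset(D, N, H, h, b)

# A (D, N) view over the same memory as X[:, :, h, b].
block = reshape(view(vec(X), off+1:off+D*N), D, N)
```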
312e4ba to cf01652
Somebody asked if Metal.jl supports the necessary primitives to implement flash attention, so I had Claude take a stab at it to see what's involved. Mostly vibe-coded, needs to be cleaned up and tested obviously.
This PR introduces a flash attention example, implemented in three ways:
- MPS
- MPSGraph's `scaledDotProductAttention` op
- `simdgroup_matrix` intrinsics

That needed a typed Julia wrapper around the existing simdgroup_matrix intrinsics (`MtlSimdgroupMatrix{T,8,8}`), plus a binding for `MPSGraph.scaledDotProductAttentionWithQueryTensor` (macOS 14+).

A fourth way would be to use the Metal 4 tensor ops (cooperative tensors + `tensor_ops::matmul2d`, with the softmax epilogue fused into the matmul via postfix fusion), but those require macOS 26, which I don't have yet. So to be continued...