diff --git a/examples/flashattention.jl b/examples/flashattention.jl new file mode 100644 index 000000000..4a8e803a2 --- /dev/null +++ b/examples/flashattention.jl @@ -0,0 +1,356 @@ +# Flash Attention reference implementations on Apple Silicon. +# +# Four ways to spell scaled dot-product attention on Metal, illustrating +# the programming models Metal.jl exposes: +# +# attention_mps(Q, K, V) +# The trivial baseline. Uses standard Julia operators (`*`, +# broadcasting, `maximum`, `sum`, `exp`) on `MtlArray`. The matrix +# multiplies are dispatched to MPSGraph / MPSMatrixMultiplication by +# `src/linalg.jl`; the rest is GPUArrays. Not actually a Flash +# Attention algorithm — the full N×N scores matrix is materialized +# in device memory — but it's the right reference and the fastest +# path to "attention runs on GPU" when you don't need a custom +# kernel. Works on macOS 13+ / M1+. +# +# attention_mpsgraph(Q, K, V) +# The high-level MPS path. Builds a one-node MPSGraph using +# `scaledDotProductAttentionWithQueryTensor` (macOS 14+), which +# fuses Q·Kᵀ → scale → softmax → ·V into a single op. Apple uses +# the same op as the backbone of their own SDPA paths (MLX falls +# back to it; Core ML lowers attention to it), so it's the closest +# thing to "ask Apple for attention" that Metal.jl can give you. +# +# attention_simdgroup(Q, K, V) +# A single-block scaled dot-product attention kernel built from +# `MtlSimdgroupMatrix{Float16, 8, 8}`. One simdgroup of 32 lanes +# does the QKᵀ and PV matrix multiplies via two `simdgroup_matrix` +# ops; the row-wise softmax is done in scalar code through +# threadgroup memory. Limited to N = D = 8, single head, single +# batch — illustrative, not production. See +# https://github.com/philipturner/metal-flash-attention for a +# tuned reference. Works on macOS 13+ / M1+. +# +# attention_tensor(Q, K, V) +# One fused kernel (QKᵀ → softmax → ·V) using the Metal 4 +# `tensor_ops::matmul2d` primitives. 
Single dispatch with grid = +# (H, B), one threadgroup per (head, batch) pair, so all heads +# run in parallel — the kernel reads its own `(h, b)` from +# `threadgroup_position_in_grid`. The kernel builds +# `tensor_inline` views over the `MtlDeviceArray` inputs, so the +# kernel signature stays buffer-shaped — no host-side `MTLTensor` +# / `MTL4ComputeCommandEncoder` wrapping is needed. The matmuls +# lower to externally-defined `__tensorops_impl_matmul2d_op_*` +# symbols (linked from the MetalPerformancePrimitives runtime), +# not `air.*` intrinsics. Scratch for the scores and softmaxed P +# lives in threadgroup memory for the lifetime of the kernel — no +# device-memory round-trip between the two matmuls. Requires +# macOS 26+; on M3/M4 the runtime still lowers to the same +# simdgroup MMA hardware. Limited to N = D = 64 because the +# matmul descriptor is specialized to that single 64×64 tile, and +# the two matmul callsites only avoid Apple's back-end +# "out of stack registers" crash when the `matmul2d` op is built +# through Metal.jl's `TensorOpsMatmul2D{DESC, NSIMD}` wrapper +# (descriptor + simdgroup count as type parameters, mirroring +# MSL's `matmul2d<desc, execution_simdgroups<N>>`). +# `cooperative_tensor` would keep the scores tile in registers for +# true postfix-fusion, but the device-side dynamic-alloca support +# that requires isn't wired up yet. +# +# All implementations take Julia 4-D `(head_dim, seq, num_heads, batch)` +# inputs — MPSGraph sees these reversed as `(batch, num_heads, seq, +# head_dim)`, the layout Apple's SDPA expects. 
+ +using Metal +using Test + +using Metal.MPS: MPSCommandBuffer, commit!, wait_completed +using Metal.MPSGraphs: MPSGraph, MPSGraphTensor, MPSGraphTensorData, + placeholderTensor, scaledDotProductAttentionWithQueryTensor, + encode!, default_exec_desc +using Metal.Foundation: NSDictionary, nil + + +## MPS / GPUArrays path + +function attention_mps(Q::MtlArray{T,4}, K::MtlArray{T,4}, V::MtlArray{T,4}; + scale = inv(sqrt(T(size(Q, 1))))) where {T} + _, _, H, B = size(Q) + O = similar(Q) + for b in 1:B, h in 1:H + Qm, Km, Vm = Q[:, :, h, b], K[:, :, h, b], V[:, :, h, b] + S = (transpose(Qm) * Km) .* scale + S = S .- maximum(S; dims = 2) + P = exp.(S) + P = P ./ sum(P; dims = 2) + O[:, :, h, b] = Vm * transpose(P) + end + return O +end + + +## MPSGraph SDPA path + +function attention_mpsgraph(Q::MtlArray{T,4}, K::MtlArray{T,4}, V::MtlArray{T,4}; + scale = inv(sqrt(T(size(Q, 1))))) where {T} + O = similar(Q) + + graph = MPSGraph() + qph = placeholderTensor(graph, size(Q), T) + kph = placeholderTensor(graph, size(K), T) + vph = placeholderTensor(graph, size(V), T) + out = scaledDotProductAttentionWithQueryTensor(graph, qph, kph, vph, + Float32(scale)) + + feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( + qph => MPSGraphTensorData(Q), + kph => MPSGraphTensorData(K), + vph => MPSGraphTensorData(V), + ) + results = Dict{MPSGraphTensor, MPSGraphTensorData}( + out => MPSGraphTensorData(O), + ) + + cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) + encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(results), nil, + default_exec_desc()) + commit!(cmdbuf) + wait_completed(cmdbuf) + return O +end + + +## Custom kernel with MtlSimdgroupMatrix + +function _fa_kernel!(O::AbstractMatrix{Float16}, + Q::AbstractMatrix{Float16}, + K_t::AbstractMatrix{Float16}, + V::AbstractMatrix{Float16}, + scale::Float32) + Ss = MtlThreadGroupArray(Float32, (8, 8)) # scores, then P + Sh = MtlThreadGroupArray(Float16, (8, 8)) # P cast back to fp16 + + # 1. 
S = Q · K_t (single 8x8 simdgroup_matrix multiply) + Qm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, Q) + Km = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, K_t) + Sm = Qm * Km + + # 2. Spill to threadgroup memory for the row-wise softmax in scalar code. + Sh_tmp = MtlThreadGroupArray(Float16, (8, 8)) + simdgroup_store(Sm, Sh_tmp) + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + tid = Int(thread_index_in_threadgroup()) - 1 # 0..31 + @inbounds for k in 0:1 + idx = tid * 2 + k + r = idx ÷ 8 + 1 + c = idx % 8 + 1 + Ss[r, c] = Float32(Sh_tmp[r, c]) * scale + end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + # 3. Row-wise softmax. 8 of 32 lanes do real work. + if tid < 8 + m = -Inf32 + @inbounds for j in 1:8 + v = Ss[tid + 1, j] + m = v > m ? v : m + end + s = 0.0f0 + @inbounds for j in 1:8 + p = exp(Ss[tid + 1, j] - m) + Ss[tid + 1, j] = p + s += p + end + inv_s = 1.0f0 / s + @inbounds for j in 1:8 + Sh[tid + 1, j] = Float16(Ss[tid + 1, j] * inv_s) + end + end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + # 4. O = P · V (second 8x8 simdgroup_matrix multiply) + Pm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, Sh) + Vm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, V) + Om = Pm * Vm + + simdgroup_store(Om, O) + return +end + +function attention_simdgroup(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, + V::MtlArray{Float16,4}; + scale = inv(sqrt(Float32(size(Q, 1))))) + @assert size(Q) == size(K) == size(V) == (8, 8, 1, 1) "simdgroup kernel only handles (D=8, N=8, H=1, B=1)" + + # Inputs are (D, N, 1, 1). Kernel works with Q and V in (N, D), and K_t in + # (D, N) — the latter is the user-facing K already transposed, so reshape + # of the (D, N) slice gives us K_t for free. 
+ Q2 = permutedims(reshape(Q, 8, 8), (2, 1)) + V2 = permutedims(reshape(V, 8, 8), (2, 1)) + K_t = reshape(K, 8, 8) + O2 = similar(Q2) + + Metal.@sync @metal threads = 32 _fa_kernel!(O2, Q2, K_t, V2, Float32(scale)) + return reshape(permutedims(O2, (2, 1)), 8, 8, 1, 1) +end + + +## Custom kernel with Metal 4 tensor ops (matmul2d, inline tensors) + +# One fused kernel per (head, batch): QKᵀ → softmax → ·V, with scores and +# softmaxed P kept in threadgroup memory. The matmul writes its (M, N) output +# in a layout that Julia reads as KᵀQ (the transpose of QᵀK), so we apply a +# *column*-wise softmax — that's what corresponds to row-wise softmax of the +# implicit QᵀK, and it's the right direction for column-major contiguous +# memory access. +function _fa_tensor!(O::MtlDeviceArray{Float16, 4}, + Q::MtlDeviceArray{Float16, 4}, + K::MtlDeviceArray{Float16, 4}, + V::MtlDeviceArray{Float16, 4}, + scale::Float32, + ::Val{TD}, ::Val{TN}, + ::Val{NSIMD}) where {TD, TN, NSIMD} + # One threadgroup per (head, batch) pair. + tgid = threadgroup_position_in_grid_3d() + h = Int32(tgid.x) - Int32(1) + b = Int32(tgid.y) - Int32(1) + tid = Int32(thread_position_in_threadgroup_3d().x) - Int32(1) + + # Pointer arithmetic for the (h, b) slice of each 4-D buffer. + H = Int32(size(Q, 3)) + DN = Int32(TD) * Int32(TN) + slice_first = (b * H + h) * DN + Int32(1) + Qb = MtlDeviceArray{Float16, 2, Metal.AS.Device}((Int32(TD), Int32(TN)), pointer(Q, slice_first)) + Kb = MtlDeviceArray{Float16, 2, Metal.AS.Device}((Int32(TD), Int32(TN)), pointer(K, slice_first)) + Vb = MtlDeviceArray{Float16, 2, Metal.AS.Device}((Int32(TD), Int32(TN)), pointer(V, slice_first)) + Ob = MtlDeviceArray{Float16, 2, Metal.AS.Device}((Int32(TD), Int32(TN)), pointer(O, slice_first)) + + # Scratch lives in threadgroup memory for the entire kernel: scores tile + # (Float32 for accumulator precision) and the softmaxed P (Float16 for the + # second matmul). 
+ S = MtlThreadGroupArray(Float32, (TN, TN)) + P = MtlThreadGroupArray(Float16, (TN, TN)) + + # Step 1: S = QᵀK (read as KᵀQ in Julia layout, see above). + let tA = MtlInlineTensor(Qb, (Int32(TD), Int32(TN))), + tB = MtlInlineTensor(Kb, (Int32(TD), Int32(TN))), + tC = MtlInlineTensor(S, (Int32(TN), Int32(TN))) + op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TN, TD; + transpose_right = true), + Int32(NSIMD)}() + op(tA, tB, tC) + end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + # Step 2: column-wise softmax. TN of (NSIMD*32) threads do real work; the + # rest wait at the barrier below. + @inbounds if tid < Int32(TN) + col = tid + Int32(1) + m = -Inf32 + for i in Int32(1):Int32(TN) + v = S[i, col] * scale + m = v > m ? v : m + end + s = 0.0f0 + for i in Int32(1):Int32(TN) + p = exp(S[i, col] * scale - m) + S[i, col] = p + s += p + end + inv_s = 1.0f0 / s + for i in Int32(1):Int32(TN) + P[i, col] = Float16(S[i, col] * inv_s) + end + end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + # Step 3: O = V·P (Julia view; equivalent to V·Pᵀ in math notation because + # the softmax output is stored in the transposed layout). + let tA = MtlInlineTensor(P, (Int32(TN), Int32(TN))), + tB = MtlInlineTensor(Vb, (Int32(TD), Int32(TN))), + tC = MtlInlineTensor(Ob, (Int32(TD), Int32(TN))) + op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TD, TN), Int32(NSIMD)}() + op(tA, tB, tC) + end + return +end + +function attention_tensor(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, + V::MtlArray{Float16,4}; + scale = inv(sqrt(Float32(size(Q, 1))))) + @assert size(Q) == size(K) == size(V) + D, N, H, B = size(Q) + # The matmul descriptor below is specialized to (N, N, D); allowing other + # shapes would mean dispatching multiple threadgroups. 
+ @assert D == N "tensor-ops kernel currently expects D == N" + O = similar(Q) + + simdgroup_size = 32 + nsimd = 4 # matches `execution_simdgroups<4>` in the op desc + threads = nsimd * simdgroup_size + + # Single dispatch covering all (head, batch) pairs: one threadgroup each, + # grid = (H, B). The kernel uses `threadgroup_position_in_grid` to pick its + # slice. The matmul descriptors carry (TN, TD) — the static tile shape per + # head. + Metal.@sync @metal threads = threads groups = (H, B, 1) _fa_tensor!( + O, Q, K, V, Float32(scale), + Val(D), Val(N), Val(nsimd)) + return O +end + + +## CPU reference + driver + +function attention_cpu(Q, K, V; scale = inv(sqrt(eltype(Q)(size(Q, 1))))) + D, N_q, H, B = size(Q) + O = similar(Q) + for b in 1:B, h in 1:H + Qm, Km, Vm = Q[:, :, h, b], K[:, :, h, b], V[:, :, h, b] + S = (Qm' * Km) .* scale + S .-= maximum(S; dims = 2) + P = exp.(S) + P ./= sum(P; dims = 2) + O[:, :, h, b] = Vm * P' + end + return O +end + +function main() + T = Float16 # simdgroup path requires fp16 + + # The simdgroup kernel is locked to 8x8 tiles, and the tensor-ops kernel + # uses a 64x64 matmul descriptor. Run each at its natural shape. 
+ let D = N = 8 + Q = MtlArray(randn(T, D, N, 1, 1)) + K = MtlArray(randn(T, D, N, 1, 1)) + V = MtlArray(randn(T, D, N, 1, 1)) + + O_cpu = attention_cpu(Array(Q), Array(K), Array(V)) + O_mps = attention_mps(Q, K, V) + O_mpsgraph = attention_mpsgraph(Q, K, V) + O_simdgroup = attention_simdgroup(Q, K, V) + + @test Array(O_mps) ≈ O_cpu rtol = 1e-2 + @test Array(O_mpsgraph) ≈ O_cpu rtol = 1e-2 + @test Array(O_simdgroup) ≈ O_cpu rtol = 1e-2 + end + + if Metal.macos_version() >= v"26.0.0" + let D = N = 64 + Q = MtlArray(randn(T, D, N, 1, 1)) + K = MtlArray(randn(T, D, N, 1, 1)) + V = MtlArray(randn(T, D, N, 1, 1)) + + O_cpu = attention_cpu(Array(Q), Array(K), Array(V)) + O_mps = attention_mps(Q, K, V) + O_tensor = attention_tensor(Q, K, V) + + @test Array(O_mps) ≈ O_cpu rtol = 1e-2 + @test Array(O_tensor) ≈ O_cpu rtol = 1e-2 + end + end +end + +isinteractive() || main() diff --git a/lib/mpsgraphs/operations.jl b/lib/mpsgraphs/operations.jl index 107c9ae31..96ffe8c52 100644 --- a/lib/mpsgraphs/operations.jl +++ b/lib/mpsgraphs/operations.jl @@ -74,3 +74,28 @@ Dumps the `graph`. This function is undocumented from Apple so it may stop working at any time. """ dump_graph(graph::MPSGraph) = @objc [graph::id{MPSGraph} dump]::Nothing ## COV_EXCL_LINE + +# Scaled dot-product attention: softmax((Q · K^T) * scale [+ mask]) · V, fused. +# Available macOS 14.0+. 
+function scaledDotProductAttentionWithQueryTensor(graph::MPSGraph, Q::MPSGraphTensor, + K::MPSGraphTensor, V::MPSGraphTensor, + scale::Real, name = "sdpa") + obj = @objc [graph::id{MPSGraph} scaledDotProductAttentionWithQueryTensor:Q::id{MPSGraphTensor} + keyTensor:K::id{MPSGraphTensor} + valueTensor:V::id{MPSGraphTensor} + scale:scale::Cfloat + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end +function scaledDotProductAttentionWithQueryTensor(graph::MPSGraph, Q::MPSGraphTensor, + K::MPSGraphTensor, V::MPSGraphTensor, + mask::MPSGraphTensor, scale::Real, + name = "sdpa") + obj = @objc [graph::id{MPSGraph} scaledDotProductAttentionWithQueryTensor:Q::id{MPSGraphTensor} + keyTensor:K::id{MPSGraphTensor} + valueTensor:V::id{MPSGraphTensor} + maskTensor:mask::id{MPSGraphTensor} + scale:scale::Cfloat + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end diff --git a/src/Metal.jl b/src/Metal.jl index ecc403a73..f6ab9af37 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -40,6 +40,7 @@ include("device/intrinsics/math.jl") include("device/intrinsics/synchronization.jl") include("device/intrinsics/memory.jl") include("device/intrinsics/simd.jl") +include("device/intrinsics/tensor.jl") include("device/intrinsics/atomics.jl") include("device/malloc.jl") include("device/random.jl") @@ -70,6 +71,7 @@ include("utilities.jl") include("broadcast.jl") include("mapreduce.jl") include("accumulate.jl") +include("tensor.jl") include("indexing.jl") include("random.jl") include("fft.jl") diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index bbf5150cf..21ce31907 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -10,6 +10,15 @@ GPUCompiler.method_table(::MetalCompilerJob) = method_table GPUCompiler.kernel_state_type(job::MetalCompilerJob) = KernelState +# Metal 4 tensor-ops lower to externally-defined `__tensorops_impl_*` symbols +# resolved at metallib link time against the 
MetalPerformancePrimitives +# runtime — not `air.*` intrinsics, so the upstream `startswith(fn, "air.")` +# check in `GPUCompiler/src/metal.jl` rejects them. +GPUCompiler.isintrinsic(job::MetalCompilerJob, fn::String) = + invoke(GPUCompiler.isintrinsic, + Tuple{CompilerJob{MetalCompilerTarget}, String}, job, fn) || + startswith(fn, "__tensorops_impl_") + function GPUCompiler.finish_module!(@nospecialize(job::MetalCompilerJob), mod::LLVM.Module, entry::LLVM.Function) entry = invoke(GPUCompiler.finish_module!, diff --git a/src/device/intrinsics/simd.jl b/src/device/intrinsics/simd.jl index 368227ac2..30f3ce6ff 100644 --- a/src/device/intrinsics/simd.jl +++ b/src/device/intrinsics/simd.jl @@ -1,4 +1,5 @@ export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate, + MtlSimdgroupMatrix, simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up, simd_shuffle, simd_shuffle_xor, simd_ballot, simd_vote_all, simd_vote_any @@ -85,6 +86,76 @@ Returns `a * b + c`. """ simdgroup_multiply_accumulate +## Typed wrapper + +""" + MtlSimdgroupMatrix{T,R,C} + +Typed wrapper around a SIMD-group matrix fragment. `T` is the element type +(`Float16` or `Float32`); `R` and `C` are the matrix dimensions. Only the +8×8 shape is supported by current Apple GPUs. + +The fragment data is distributed across the 32 lanes of a SIMD-group; the +per-lane element layout is implementation-defined and elements cannot be +accessed directly. To inspect or modify individual entries, store the +matrix to device or threadgroup memory first. + +Construct via [`simdgroup_load`](@ref), [`zero`](@ref) or the explicit +fill constructor `MtlSimdgroupMatrix{T,8,8}(val::T)`. 
+""" +struct MtlSimdgroupMatrix{T,R,C} + data::NTuple{64, VecElement{T}} + + global _unsafe_wrap_simdgroup_matrix(::Type{MtlSimdgroupMatrix{T,R,C}}, + data::NTuple{64, VecElement{T}}) where {T,R,C} = + new{T,R,C}(data) +end + +Base.size(::Type{<:MtlSimdgroupMatrix{<:Any,R,C}}) where {R,C} = (R, C) +Base.size(m::MtlSimdgroupMatrix) = size(typeof(m)) +Base.eltype(::Type{<:MtlSimdgroupMatrix{T}}) where {T} = T +Base.eltype(m::MtlSimdgroupMatrix) = eltype(typeof(m)) + +# Fill constructor: materialize a fragment whose elements are all `val`. +@inline function MtlSimdgroupMatrix{T,8,8}(val::T) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + ntuple(_ -> VecElement{T}(val), Val(64))) +end + +@inline Base.zero(::Type{MtlSimdgroupMatrix{T,8,8}}) where {T} = + MtlSimdgroupMatrix{T,8,8}(zero(T)) + +# Load: build a fragment from a device or threadgroup array tile. +@device_function @inline function simdgroup_load(::Type{MtlSimdgroupMatrix{T,8,8}}, + src::MtlDeviceArray{T}, + matrix_origin::NTuple{2, Int64} = (1, 1)) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + simdgroup_load(src, matrix_origin)) +end + +# Store: write the fragment back to a device or threadgroup array tile. +@device_function @inline function simdgroup_store(m::MtlSimdgroupMatrix{T,8,8}, + dest::MtlDeviceArray{T}, + matrix_origin::NTuple{2, Int64} = (1, 1)) where {T} + return simdgroup_store(m.data, dest, matrix_origin) +end + +# Multiply: D = A * B. +@inline function Base.:(*)(a::MtlSimdgroupMatrix{T,8,8}, + b::MtlSimdgroupMatrix{T,8,8}) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + simdgroup_multiply(a.data, b.data)) +end + +# Fused multiply-add: D = A * B + C. 
+@inline function Base.muladd(a::MtlSimdgroupMatrix{T,8,8}, + b::MtlSimdgroupMatrix{T,8,8}, + c::MtlSimdgroupMatrix{T,8,8}) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + simdgroup_multiply_accumulate(a.data, b.data, c.data)) +end + + ## SIMD Shuffle Up/Down simd_shuffle_map = ((Float32, "f32"), diff --git a/src/device/intrinsics/tensor.jl b/src/device/intrinsics/tensor.jl new file mode 100644 index 000000000..bbd2fcb57 --- /dev/null +++ b/src/device/intrinsics/tensor.jl @@ -0,0 +1,316 @@ +export MtlInlineTensor, matmul2d_descriptor, TensorOpsMatmul2D, + matmul2d_multiply, matmul2d_multiply_accumulate, tensor_matmul! + +using Core: LLVMPtr +using BFloat16s: BFloat16 + +# Wrappers for Metal 4 tensor-ops / `mpp::tensor_ops` device-side APIs. +# +# The host-bound `tensor_handle` form needs opaque kernel arguments and a +# different ABI (see `ISSUE-tensor-ops.md`). The `tensor_inline` form, which +# is what this file targets, constructs the tensor on the kernel stack from +# a buffer pointer and a set of extents/strides — kernel signature stays the +# same as a plain `MtlDeviceArray` kernel. +# +# Each tensor descriptor lives in a per-thread byte buffer held by a +# `Ref{NTuple{N, UInt8}}`. Julia's `llvm-alloc-opt` pass promotes the Ref to +# a stack alloca once everything is inlined into the kernel. + +# Conservative upper bound on the size of an i32-indexed tensor descriptor. +# Apple's compiler emits a dynamic `alloca i8, i64 %sz` where `%sz` comes +# from `air.get_descriptor_size_tensor` (marked deferred-static-alloca-size +# so it's resolved at metallib build time). We use one static size that +# covers the ranks we care about (≤ 8 for i32-indexed tensors); higher +# ranks would need a bigger buffer. +const _TENSOR_DESCRIPTOR_SIZE = 128 + +const _TensorDescriptorStorage = Base.RefValue{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}} + + +## Tensor descriptor primitives (`air.*` intrinsics). 
+ +# Returns the per-thread tensor descriptor size for the given rank/index-size. +@device_function get_descriptor_size_tensor(rank::Int16, index_size::Int16) = + ccall("extern air.get_descriptor_size_tensor", llvmcall, + Int16, (Int16, Int16), rank, index_size) + +# Build an `i32`-indexed strided tensor view over a device-memory buffer. +# Ccall arg types use `Ref{T}` so that NTuple values passed in get auto-boxed +# into temporary Refs (via `cconvert(::Type{Ref{T}}, ::NTuple{N,T})` from +# `base/refpointer.jl`). `llvm-alloc-opt` promotes those temporaries to +# stack allocas. The element type must match what we pass: a mismatched +# `Ref{T}` would force ccall to emit a `jl_f_throw_methoderror` path, and +# the dead heap alloc on that path defeats the promotion. +@device_function init_strided_tensor_device!( + handle::_TensorDescriptorStorage, + rank::Int16, + data::LLVMPtr{UInt8, AS.Device}, + extents::NTuple{<:Any, Int32}, + strides::NTuple{<:Any, Int32}, + contiguous::Int8, +) = ccall("extern air.init_strided_private_tensor.i32.global", llvmcall, + Cvoid, + (Ref{UInt8}, Int16, LLVMPtr{UInt8, AS.Device}, + Ref{Int32}, Ref{Int32}, Int8), + handle, rank, data, extents, strides, contiguous) + +@device_function init_strided_tensor_threadgroup!( + handle::_TensorDescriptorStorage, + rank::Int16, + data::LLVMPtr{UInt8, AS.ThreadGroup}, + extents::NTuple{<:Any, Int32}, + strides::NTuple{<:Any, Int32}, + contiguous::Int8, +) = ccall("extern air.init_strided_private_tensor.i32.local", llvmcall, + Cvoid, + (Ref{UInt8}, Int16, LLVMPtr{UInt8, AS.ThreadGroup}, + Ref{Int32}, Ref{Int32}, Int8), + handle, rank, data, extents, strides, contiguous) + +@device_function get_extent_private_tensor(handle::_TensorDescriptorStorage, + rank::Int16, dim::Int16) = + ccall("extern air.get_extent_private_tensor.i32", llvmcall, + Int32, (Ref{UInt8}, Int16, Int16), + handle, rank, dim) + +@device_function slice_private_tensor!( + dst::_TensorDescriptorStorage, + src::_TensorDescriptorStorage, 
+ rank::Int16, + origin::NTuple{<:Any, Int32}, + extents::NTuple{<:Any, Int32}, +) = ccall("extern air.slice_private_tensor_private_tensor.s.i32", llvmcall, + Cvoid, + (Ref{UInt8}, Ref{UInt8}, Int16, Ref{Int32}, Ref{Int32}), + dst, src, rank, origin, extents) + + +## High-level inline-tensor wrapper. + +""" + MtlInlineTensor{T, R, ASpace} + +Kernel-stack tensor view over an `MtlDeviceArray` or `MtlThreadGroupArray`. +`T` is the element type, `R` the rank, `ASpace` the address space of the +underlying data (`AS.Device` or `AS.ThreadGroup`). Backed by a thread- +private byte buffer (an inline `Ref`) that the runtime initializes at +construction. + +Extents follow the MSL `dextents{e1, e2, ...}` convention +(innermost dimension first), which is the row-major view the matmul kernel +expects. For a Julia column-major `MtlMatrix(M, N)`, pass extents `(M, N)` +if you want to treat columns as the inner dimension. + +`tensor_ops::matmul2d` itself is rank-2 — higher-rank tensors are useful +for slicing, multi-batch lifts, and the future `convolution2d` op. +""" +struct MtlInlineTensor{T, R, ASpace} + storage::_TensorDescriptorStorage +end + +# `tensor_inline` packed-stride strides: stride(0) = 1, stride(k) = prod(extents[1:k]). 
+@inline function _packed_strides(extents::NTuple{R, Int32}) where {R} + ntuple(Val(R)) do k + s = Int32(1) + for j in 1:(k - 1) + s *= extents[j] + end + s + end +end + +# In-kernel constructors (packed strides): +@device_function @inline function MtlInlineTensor{T, R, AS.Device}( + data::MtlDeviceArray{T, <:Any, AS.Device}, + extents::NTuple{R, <:Integer}) where {T, R} + e = Int32.(extents) + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + init_strided_tensor_device!(storage, Int16(R), + reinterpret(LLVMPtr{UInt8, AS.Device}, pointer(data)), + e, _packed_strides(e), Int8(1)) + return MtlInlineTensor{T, R, AS.Device}(storage) +end + +@device_function @inline function MtlInlineTensor{T, R, AS.ThreadGroup}( + data::MtlDeviceArray{T, <:Any, AS.ThreadGroup}, + extents::NTuple{R, <:Integer}) where {T, R} + e = Int32.(extents) + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + init_strided_tensor_threadgroup!(storage, Int16(R), + reinterpret(LLVMPtr{UInt8, AS.ThreadGroup}, pointer(data)), + e, _packed_strides(e), Int8(1)) + return MtlInlineTensor{T, R, AS.ThreadGroup}(storage) +end + +# Explicit-stride variants: +@device_function @inline function MtlInlineTensor{T, R, AS.Device}( + data::MtlDeviceArray{T, <:Any, AS.Device}, + extents::NTuple{R, <:Integer}, + strides::NTuple{R, <:Integer}) where {T, R} + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + init_strided_tensor_device!(storage, Int16(R), + reinterpret(LLVMPtr{UInt8, AS.Device}, pointer(data)), + Int32.(extents), Int32.(strides), Int8(0)) + return MtlInlineTensor{T, R, AS.Device}(storage) +end + +@device_function @inline function MtlInlineTensor{T, R, AS.ThreadGroup}( + data::MtlDeviceArray{T, <:Any, AS.ThreadGroup}, + extents::NTuple{R, <:Integer}, + strides::NTuple{R, <:Integer}) where {T, R} + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + init_strided_tensor_threadgroup!(storage, Int16(R), + reinterpret(LLVMPtr{UInt8, AS.ThreadGroup}, pointer(data)), + Int32.(extents), 
Int32.(strides), Int8(0)) + return MtlInlineTensor{T, R, AS.ThreadGroup}(storage) +end + +# Convenience: infer rank and address space from the inputs. +@inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, A}, + extents::NTuple{R, <:Integer}) where {T, R, A} = + MtlInlineTensor{T, R, A}(data, extents) + +@inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, A}, + extents::NTuple{R, <:Integer}, + strides::NTuple{R, <:Integer}) where {T, R, A} = + MtlInlineTensor{T, R, A}(data, extents, strides) + +Base.eltype(::Type{<:MtlInlineTensor{T}}) where {T} = T +Base.eltype(::MtlInlineTensor{T}) where {T} = T + +# Slice. `origin` is 1-based to match Julia indexing (the underlying +# `air.slice_private_tensor` intrinsic uses 0-based MSL coordinates). +@device_function @inline function Base.view( + t::MtlInlineTensor{T, R, A}, + origin::NTuple{R, <:Integer}, + extents::NTuple{R, <:Integer}) where {T, R, A} + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + slice_private_tensor!(storage, t.storage, Int16(R), + Int32.(origin) .- Int32(1), Int32.(extents)) + return MtlInlineTensor{T, R, A}(storage) +end + + +## matmul2d descriptor (mirrors `mpp::tensor_ops::matmul2d_descriptor`). + +@enum Matmul2DMode::Int32 begin + matmul2d_multiply = 0 + matmul2d_multiply_accumulate = 1 +end + +""" + matmul2d_descriptor(m, n, [k]; transpose_left=false, transpose_right=false, + relaxed_precision=false, mode=matmul2d_multiply) + +Configuration descriptor for a `tensor_ops::matmul2d` operation. `k` +defaults to `-1` (dynamic — inferred from the input tensors at runtime). +Layout matches the 20-byte `mpp::tensor_ops::matmul2d_descriptor` POD. + +For an outer `K`-loop where each iteration accumulates a partial product +into the destination, set `mode = matmul2d_multiply_accumulate` and zero +the destination before the loop. 
A typical pattern: + +```julia +op = TensorOpsMatmul2D{matmul2d_descriptor(M, N, TileK; + mode = matmul2d_multiply_accumulate), + Int32(NSIMD)}() +for s in 0:(nslices - 1) + sA = view(tA, (Int32(s) * Int32(TileK) + Int32(1), Int32(1)), (Int32(TileK), Int32(M))) + sB = view(tB, (Int32(1), Int32(s) * Int32(TileK) + Int32(1)), (Int32(N), Int32(TileK))) + op(sA, sB, tC) +end +``` + +Keep the loop's trip count dynamic — a compile-time-known trip count +that fully unrolls into multiple op call sites currently crashes Apple's +back-end (see `ISSUE-tensor-ops.md`). +""" +struct matmul2d_descriptor + m::Int32 + n::Int32 + k::Int32 + transpose_left::Int8 + transpose_right::Int8 + relaxed_precision::Int8 + matmul_mode::Matmul2DMode +end + +matmul2d_descriptor(m::Integer, n::Integer, k::Integer = -1; + transpose_left::Bool = false, + transpose_right::Bool = false, + relaxed_precision::Bool = false, + mode::Matmul2DMode = matmul2d_multiply) = + matmul2d_descriptor(Int32(m), Int32(n), Int32(k), + Int8(transpose_left), Int8(transpose_right), + Int8(relaxed_precision), mode) + + +## matmul2d run (inline-tensor → inline-tensor variant). + +const _TENSOR_DESC_INLINE = Int32(2) # `__tensor_ops_tensor_descriptor_type::tensor_inline` + +# Element-type suffix for `__tensorops_impl_matmul2d_op_run_*` symbols. +# The 4-bit integer formats (`i4`, `ui4`) aren't exposed yet — Julia has no +# native 4-bit integer type. `int32` is only valid as the destination of +# an `i8`/`ui8` × `i4`/`ui4` matmul. +_tensorops_suffix(::Type{Float16}) = "f16" +_tensorops_suffix(::Type{Float32}) = "f32" +_tensorops_suffix(::Type{BFloat16}) = "b16" +_tensorops_suffix(::Type{Int8}) = "i8" +_tensorops_suffix(::Type{UInt8}) = "ui8" +_tensorops_suffix(::Type{Int32}) = "i32" + +# Address-space prefix for the run helpers (`dv` for device, `tg` for threadgroup). 
+_tensorops_aspace_prefix(::Val{AS.Device}) = "dv" +_tensorops_aspace_prefix(::Val{AS.ThreadGroup}) = "tg" + +""" + TensorOpsMatmul2D{DESC, NSIMD} + +Configured `tensor_ops::matmul2d` op. Mirrors Apple's MSL +`matmul2d<desc, execution_simdgroups<N>>` template: `DESC` is the +[`matmul2d_descriptor`](@ref) value, `NSIMD` is the simdgroup count +(`execution_simdgroups<N>`). Both are encoded as type parameters so the +generated AIR carries them as compile-time constants — `NSIMD` +specifically: the AGX register allocator runs out of stack registers +when the simdgroup count is a runtime value and two matmul calls live +in the same kernel. + +Construct with [`TensorOpsMatmul2D(desc, Val(N))`](@ref) and invoke +like a function: + +```julia +op = TensorOpsMatmul2D(matmul2d_descriptor(64, 32, -1), Val(4)) +op(left, right, dest) # run; mirrors `op.run(...)` in MSL +``` +""" +struct TensorOpsMatmul2D{DESC, NSIMD} end + +TensorOpsMatmul2D(desc::matmul2d_descriptor, ::Val{NSIMD}) where {NSIMD} = + TensorOpsMatmul2D{desc, Int32(NSIMD)}() + +@device_function @inline @generated function (::TensorOpsMatmul2D{DESC, NSIMD})( + left::MtlInlineTensor{TL, 2, AL}, + right::MtlInlineTensor{TR, 2, AR}, + dest::MtlInlineTensor{TD, 2, AD}) where {DESC, NSIMD, TL, TR, TD, AL, AR, AD} + sym = "__tensorops_impl_matmul2d_op_run" * + "_$(_tensorops_aspace_prefix(Val(AL)))_$(_tensorops_suffix(TL))" * + "_$(_tensorops_aspace_prefix(Val(AR)))_$(_tensorops_suffix(TR))" * + "_$(_tensorops_aspace_prefix(Val(AD)))_$(_tensorops_suffix(TD))" + quote + threads = Int32(NSIMD) * Int32(threads_per_simdgroup()) + ccall($"extern $sym", llvmcall, Cvoid, + (Ref{matmul2d_descriptor}, + Ref{UInt8}, Int32, + Ref{UInt8}, Int32, + Ref{UInt8}, Int32, + Int32), + $(QuoteNode(DESC)), + left.storage, $_TENSOR_DESC_INLINE, + right.storage, $_TENSOR_DESC_INLINE, + dest.storage, $_TENSOR_DESC_INLINE, + threads) + return nothing + end +end + diff --git a/src/tensor.jl b/src/tensor.jl new file mode 100644 index 000000000..2d2e68a17 --- /dev/null +++
b/src/tensor.jl @@ -0,0 +1,87 @@ +# High-level tile-decomposed matmul on top of `tensor_ops::matmul2d`. + +# Specialized on the tile shape (TM, TN, TK) and simdgroup count (NSIMD) so the +# matmul descriptor and execution width are compile-time constants — see +# `TensorOpsMatmul2D`. +function _tensor_matmul_kernel!(C::MtlDeviceArray, A::MtlDeviceArray, B::MtlDeviceArray, + M::UInt32, N::UInt32, K::UInt32, + ::Val{TM}, ::Val{TN}, ::Val{TK}, + ::Val{NSIMD}) where {TM, TN, TK, NSIMD} + tgid = threadgroup_position_in_grid_3d() + n_tile = Int32(tgid.x) - Int32(1) + m_tile = Int32(tgid.y) - Int32(1) + n_off = n_tile * Int32(TN) + Int32(1) + m_off = m_tile * Int32(TM) + Int32(1) + + tA = MtlInlineTensor(B, (K, M)) + tB = MtlInlineTensor(A, (N, K)) + tC = MtlInlineTensor(C, (N, M)) + + mC = view(tC, (n_off, m_off), (Int32(TN), Int32(TM))) + + op = TensorOpsMatmul2D{matmul2d_descriptor(TM, TN, TK; + mode = matmul2d_multiply_accumulate), + Int32(NSIMD)}() + nslices = Int32(K ÷ UInt32(TK)) + for s in Int32(0):(nslices - Int32(1)) + k_off = s * Int32(TK) + Int32(1) + mA = view(tA, (k_off, m_off), (Int32(TK), Int32(TM))) + mB = view(tB, (n_off, k_off), (Int32(TN), Int32(TK))) + op(mA, mB, mC) + end + return +end + +""" + tensor_matmul!(C, A, B; tile_m=64, tile_n=64, tile_k=32) + +Compute `C = A * B` (Julia matrix-product semantics) using +`tensor_ops::matmul2d` with a tile decomposition. `A` is `(m, k)`, `B` is +`(k, n)`, `C` is `(m, n)`, all column-major `MtlMatrix`. `C` is zeroed +before the matmul (the kernel accumulates). + +`tile_m`, `tile_n`, `tile_k` set the per-threadgroup tile shape in matmul2d's +swapped operand convention: `tile_m` must divide `n`, `tile_n` must divide `m` +(`n % tile_m == 0`, `m % tile_n == 0`, `k % tile_k == 0`). Each +threadgroup uses 4 SIMD-groups (128 threads on the M1/M2 hardware) and +covers one `tile_n`×`tile_m` block of `C`; the K dimension is looped over +inside the kernel via `multiply_accumulate`. 
+ +The implementation maps Julia's natural `C = A * B` to the matmul2d +operand convention by swapping `A`/`B` at the kernel level: matmul2d's +`apple_A` slot receives Julia's `B`, its `apple_B` slot receives Julia's +`A`. This lets every operand use packed strides without an explicit +transpose flag — matmul2d's storage order for the (M, N) output happens +to be the transpose of how Julia reads the same buffer column-major, so +two swaps cancel out and the user gets the answer they expect. + +Requires macOS 26+. +""" +function tensor_matmul!(C::MtlMatrix{TC}, A::MtlMatrix{TA}, B::MtlMatrix{TB}; + tile_m::Integer = 64, tile_n::Integer = 64, + tile_k::Integer = 32) where {TA, TB, TC} + m, k = size(A) + k2, n = size(B) + k == k2 || throw(DimensionMismatch( + "A is ($m, $k), B is ($k2, $n) — inner dims must match")) + size(C) == (m, n) || throw(DimensionMismatch( + "C is $(size(C)), expected ($m, $n)")) + + # Apple-side dims (see above for the swap derivation). + aM, aN, aK = n, m, k + aM % tile_m == 0 || throw(ArgumentError( + "tile_m=$tile_m must divide n=$n")) + aN % tile_n == 0 || throw(ArgumentError( + "tile_n=$tile_n must divide m=$m")) + aK % tile_k == 0 || throw(ArgumentError( + "tile_k=$tile_k must divide k=$k")) + + fill!(C, zero(TC)) + groups = (aN ÷ tile_n, aM ÷ tile_m, 1) + nsimd = 4 + @metal threads = nsimd * 32 groups = groups _tensor_matmul_kernel!( + C, A, B, + UInt32(aM), UInt32(aN), UInt32(aK), + Val(tile_m), Val(tile_n), Val(tile_k), Val(nsimd)) + return C +end diff --git a/test/device/intrinsics/simd.jl b/test/device/intrinsics/simd.jl index 684bb04bf..eb8d55fb0 100644 --- a/test/device/intrinsics/simd.jl +++ b/test/device/intrinsics/simd.jl @@ -288,4 +288,140 @@ end @metal threads=(8, 8) kernel(a, b, c, d) @test Array(a) * Array(b) + Array(c) ≈ Array(d) end + + @testset "MtlSimdgroupMatrix type" begin + @testset for T in (Float16, Float32) + @test eltype(MtlSimdgroupMatrix{T,8,8}) === T + @test size(MtlSimdgroupMatrix{T,8,8}) === (8, 8) + end + 
end
+
+    # Scalar-fill constructor: every element of the stored tile must equal `val`.
+    @testset "MtlSimdgroupMatrix fill($T)" for T in (Float16, Float32)
+        function kernel(out::MtlDeviceMatrix{T}, val::T) where {T}
+            m = MtlSimdgroupMatrix{T,8,8}(val)
+            simdgroup_store(m, out)
+            return
+        end
+
+        out = MtlArray(zeros(T, 8, 8))
+        Metal.@sync @metal threads=(8, 8) kernel(out, T(3.5))
+        @test all(Array(out) .== T(3.5))
+    end
+
+    # `zero(...)`: the output starts as all ones so a no-op store would be caught.
+    @testset "MtlSimdgroupMatrix zero($T)" for T in (Float16, Float32)
+        function kernel(out::MtlDeviceMatrix{T}) where {T}
+            m = zero(MtlSimdgroupMatrix{T,8,8})
+            simdgroup_store(m, out)
+            return
+        end
+
+        out = MtlArray(ones(T, 8, 8))
+        Metal.@sync @metal threads=(8, 8) kernel(out)
+        @test all(Array(out) .== zero(T))
+    end
+
+    # Load/store round trip: `b` must end up bit-identical to `a`.
+    @testset "MtlSimdgroupMatrix load_store($T)" for T in (Float16, Float32)
+        function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}) where {T}
+            m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a)
+            simdgroup_store(m, b)
+            return
+        end
+
+        a = MtlArray(rand(T, 8, 8))
+        b = MtlArray(zeros(T, 8, 8))
+        Metal.@sync @metal threads=(8, 8) kernel(a, b)
+        @test Array(a) == Array(b)
+    end
+
+    # Round trip with explicit origins into larger, differently shaped matrices:
+    # the tile at (4, 2) of the 20×15 `a` is expected at (3, 5) of the 15×20 `b`.
+    @testset "MtlSimdgroupMatrix load_store with origin($T)" for T in (Float16, Float32)
+        function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T},
+                        origin_a::NTuple{2,Int64}, origin_b::NTuple{2,Int64}) where {T}
+            m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a, origin_a)
+            simdgroup_store(m, b, origin_b)
+            return
+        end
+
+        a = MtlArray(rand(T, 20, 15))
+        b = MtlArray(zeros(T, 15, 20))
+        Metal.@sync @metal threads=(8, 8) kernel(a, b, (4, 2), (3, 5))
+        @test Array(a)[4:11, 2:9] == Array(b)[3:10, 5:12]
+    end
+
+    # `*` on two fragments, checked against the host matrix product.
+    @testset "MtlSimdgroupMatrix multiply($T)" for T in (Float16, Float32)
+        function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, c::MtlDeviceMatrix{T}) where {T}
+            ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a)
+            mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, b)
+            simdgroup_store(ma * mb, c)
+            return
+        end
+
+        a = MtlArray(rand(T, 8, 8))
+        b = MtlArray(rand(T, 8, 8))
+        c = MtlArray(zeros(T, 8, 8))
+        Metal.@sync @metal threads=(8, 8) kernel(a, b, c)
+        @test Array(a) * Array(b) ≈ Array(c)
+    end
+
+    # `muladd(ma, mb, mc)`, checked against the host `a * b + c`.
+    @testset "MtlSimdgroupMatrix muladd($T)" for T in (Float16, Float32)
+        function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T},
+                        c::MtlDeviceMatrix{T}, d::MtlDeviceMatrix{T}) where {T}
+            ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a)
+            mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, b)
+            mc = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, c)
+            simdgroup_store(muladd(ma, mb, mc), d)
+            return
+        end
+
+        a = MtlArray(rand(T, 8, 8))
+        b = MtlArray(rand(T, 8, 8))
+        c = MtlArray(rand(T, 8, 8))
+        d = MtlArray(zeros(T, 8, 8))
+        Metal.@sync @metal threads=(8, 8) kernel(a, b, c, d)
+        @test Array(a) * Array(b) + Array(c) ≈ Array(d)
+    end
+
+    # Composed K-loop GEMM: C(8×8) = A(8×K) * B(K×8) with K=32, accumulating
+    # four 8×8×8 fragment MMAs.
+    @testset "MtlSimdgroupMatrix K-loop GEMM($T)" for T in (Float16, Float32)
+        function kernel(A::MtlDeviceMatrix{T}, B::MtlDeviceMatrix{T}, C::MtlDeviceMatrix{T}) where {T}
+            acc = zero(MtlSimdgroupMatrix{T,8,8})
+            for k in 0:3
+                # Slice k: columns 1+8k.. of A against rows 1+8k.. of B.
+                ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, A, (1, 1 + k*8))
+                mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, B, (1 + k*8, 1))
+                acc = muladd(ma, mb, acc)
+            end
+            simdgroup_store(acc, C)
+            return
+        end
+
+        A = MtlArray(rand(T, 8, 32))
+        B = MtlArray(rand(T, 32, 8))
+        C = MtlArray(zeros(T, 8, 8))
+        Metal.@sync @metal threads=(8, 8) kernel(A, B, C)
+        @test Array(A) * Array(B) ≈ Array(C) rtol=sqrt(eps(T))
+    end
+
+    # Threadgroup-memory variant: stage tiles through threadgroup memory, then
+    # load fragments from there. Mirrors how Flash Attention stages K/V tiles.
+    @testset "MtlSimdgroupMatrix threadgroup load($T)" for T in (Float16, Float32)
+        function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}) where {T}
+            pos = thread_position_in_threadgroup()
+            # Stage `a` into threadgroup memory, one element per thread.
+            tg = MtlThreadGroupArray(T, (8, 8))
+            tg[pos.x, pos.y] = a[pos.x, pos.y]
+            # All staged elements must be written before the fragment load.
+            threadgroup_barrier(Metal.MemoryFlagThreadGroup)
+
+            m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, tg)
+            tg2 = MtlThreadGroupArray(T, (8, 8))
+            simdgroup_store(m, tg2)
+            # ...and the fragment store must land before threads read `tg2` back.
+            threadgroup_barrier(Metal.MemoryFlagThreadGroup)
+
+            b[pos.x, pos.y] = tg2[pos.x, pos.y]
+            return
+        end
+
+        a = MtlArray(rand(T, 8, 8))
+        b = MtlArray(zeros(T, 8, 8))
+        Metal.@sync @metal threads=(8, 8) kernel(a, b)
+        @test Array(a) == Array(b)
+    end
 end # End Matrix Functions
diff --git a/test/mpsgraphs/sdpa.jl b/test/mpsgraphs/sdpa.jl
new file mode 100644
index 000000000..ab0231d39
--- /dev/null
+++ b/test/mpsgraphs/sdpa.jl
@@ -0,0 +1,85 @@
+# The whole file is gated: it only runs when MPS is supported on the current
+# device and the OS is macOS 14 or newer.
+if MPS.is_supported(device()) && Metal.macos_version() >= v"14"
+
+using .MPS: MPSCommandBuffer, commit!, wait_completed
+using .MPSGraphs: MPSGraph, MPSGraphTensor, MPSGraphTensorData,
+                  placeholderTensor, scaledDotProductAttentionWithQueryTensor,
+                  encode!, default_exec_desc
+using ObjectiveC.Foundation: NSDictionary, nil
+
+# Reference attention in 4-D `(head_dim, seq, num_heads, batch)` Julia layout.
+# Mask (if provided) is `(N_kv, N_q, num_heads, batch)` — that's MPS's natural
+# `(B, H, N_q, N_kv)` layout reversed for Julia col-major.
+# Softmax is taken over the key axis of each `(head, batch)` slice; the output
+# has the shape and element type of `Q` (it is allocated with `similar(Q)`).
+function _ref_attention(Q, K, V, scale, mask = nothing)
+    # Destructure four dims so a `Q` with fewer than four still fails loudly here.
+    _, _, H, B = size(Q)
+    O = similar(Q)
+    for b in 1:B, h in 1:H
+        Qm, Km, Vm = Q[:, :, h, b], K[:, :, h, b], V[:, :, h, b]
+        S = (Qm' * Km) .* scale                 # (N_q, N_kv) scores
+        if mask !== nothing
+            S .+= transpose(mask[:, :, h, b])   # mask slice is (N_kv, N_q)
+        end
+        S .-= maximum(S; dims = 2)              # row max off before `exp`
+        P = exp.(S)
+        P ./= sum(P; dims = 2)
+        O[:, :, h, b] = Vm * P'
+    end
+    return O
+end
+
+# Build a one-op MPSGraph around `scaledDotProductAttentionWithQueryTensor`,
+# feed it `Q`, `K`, `V` (plus `mask` as a fourth input when given), run it on the
+# global queue and block until the command buffer completes. Returns a freshly
+# allocated output shaped like `Q`.
+function _run_sdpa(Q, K, V, scale; mask = nothing)
+    O = similar(Q)
+    graph = MPSGraph()
+    qph = placeholderTensor(graph, size(Q), eltype(Q))
+    kph = placeholderTensor(graph, size(K), eltype(K))
+    vph = placeholderTensor(graph, size(V), eltype(V))
+    # Keep the mask placeholder handle: it is needed again as a feed key, and
+    # looking it up by position in `graph.placeholderTensors` would silently
+    # depend on the creation order above.
+    mph = mask === nothing ? nothing :
+          placeholderTensor(graph, size(mask), eltype(mask))
+    out = if mph === nothing
+        scaledDotProductAttentionWithQueryTensor(graph, qph, kph, vph, Float32(scale))
+    else
+        scaledDotProductAttentionWithQueryTensor(graph, qph, kph, vph, mph,
+                                                 Float32(scale))
+    end
+
+    feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
+        qph => MPSGraphTensorData(Q),
+        kph => MPSGraphTensorData(K),
+        vph => MPSGraphTensorData(V),
+    )
+    if mph !== nothing
+        feeds[mph] = MPSGraphTensorData(mask)
+    end
+    results = Dict{MPSGraphTensor, MPSGraphTensorData}(
+        out => MPSGraphTensorData(O),
+    )
+
+    cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
+    encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(results), nil,
+            default_exec_desc())
+    commit!(cmdbuf)
+    wait_completed(cmdbuf)
+    return O
+end
+
+# N_q != N_kv on purpose, so a transposed scores/mask layout cannot pass.
+@testset "scaled dot-product attention ($T)" for T in (Float16, Float32)
+    D, N_q, N_kv, H, B = 8, 12, 16, 2, 1
+    Q = MtlArray(randn(T, D, N_q, H, B))
+    K = MtlArray(randn(T, D, N_kv, H, B))
+    V = MtlArray(randn(T, D, N_kv, H, B))
+    scale = inv(sqrt(T(D)))
+
+    @testset "no mask" begin
+        O = _run_sdpa(Q, K, V, scale)
+        O_ref = _ref_attention(Array(Q), Array(K), Array(V), scale)
+        @test Array(O) ≈ O_ref rtol = (T === Float16 ? 1e-2 : sqrt(eps(T)))
+    end
+
+    @testset "with mask" begin
+        mask = MtlArray(randn(T, N_kv, N_q, H, B))
+        O = _run_sdpa(Q, K, V, scale; mask)
+        O_ref = _ref_attention(Array(Q), Array(K), Array(V), scale, Array(mask))
+        @test Array(O) ≈ O_ref rtol = (T === Float16 ? 1e-2 : sqrt(eps(T)))
+    end
+end
+
+end # MPS.is_supported(device()) && macOS 14+