From 21152594f039e09311a61963850e8ee47f89fdd9 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Sun, 24 May 2026 11:59:56 +0200 Subject: [PATCH 01/24] Add MtlSimdgroupMatrix wrapper for SIMD-group 8x8 fragments. --- src/Metal.jl | 1 + src/device/intrinsics/simdgroup_matrix.jl | 68 +++++++++++ test/device/intrinsics/simdgroup_matrix.jl | 135 +++++++++++++++++++++ 3 files changed, 204 insertions(+) create mode 100644 src/device/intrinsics/simdgroup_matrix.jl create mode 100644 test/device/intrinsics/simdgroup_matrix.jl diff --git a/src/Metal.jl b/src/Metal.jl index ecc403a73..7ed8567cb 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -40,6 +40,7 @@ include("device/intrinsics/math.jl") include("device/intrinsics/synchronization.jl") include("device/intrinsics/memory.jl") include("device/intrinsics/simd.jl") +include("device/intrinsics/simdgroup_matrix.jl") include("device/intrinsics/atomics.jl") include("device/malloc.jl") include("device/random.jl") diff --git a/src/device/intrinsics/simdgroup_matrix.jl b/src/device/intrinsics/simdgroup_matrix.jl new file mode 100644 index 000000000..2762797a5 --- /dev/null +++ b/src/device/intrinsics/simdgroup_matrix.jl @@ -0,0 +1,68 @@ +export MtlSimdgroupMatrix + +""" + MtlSimdgroupMatrix{T,R,C} + +Typed wrapper around a SIMD-group matrix fragment. `T` is the element type +(`Float16` or `Float32`); `R` and `C` are the matrix dimensions. Only the +8×8 shape is supported by current Apple GPUs. + +The fragment data is distributed across the 32 lanes of a SIMD-group; the +per-lane element layout is implementation-defined and elements cannot be +accessed directly. To inspect or modify individual entries, store the +matrix to device or threadgroup memory first. + +Construct via [`simdgroup_load`](@ref), [`zero`](@ref) or the explicit +fill constructor `MtlSimdgroupMatrix{T,8,8}(val::T)`. +""" +struct MtlSimdgroupMatrix{T,R,C} + data::NTuple{64, VecElement{T}} + + global _unsafe_wrap_simdgroup_matrix(::Type{MtlSimdgroupMatrix{T,R,C}}, + data::NTuple{64, VecElement{T}}) where {T,R,C} = + new{T,R,C}(data) +end + +Base.size(::Type{<:MtlSimdgroupMatrix{<:Any,R,C}}) where {R,C} = (R, C) +Base.size(m::MtlSimdgroupMatrix) = size(typeof(m)) +Base.eltype(::Type{<:MtlSimdgroupMatrix{T}}) where {T} = T +Base.eltype(m::MtlSimdgroupMatrix) = eltype(typeof(m)) + +# Fill constructor: materialize a fragment whose elements are all `val`. +@inline function MtlSimdgroupMatrix{T,8,8}(val::T) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + ntuple(_ -> VecElement{T}(val), Val(64))) +end + +@inline Base.zero(::Type{MtlSimdgroupMatrix{T,8,8}}) where {T} = + MtlSimdgroupMatrix{T,8,8}(zero(T)) + +# Load: build a fragment from a device or threadgroup array tile. +@device_function @inline function simdgroup_load(::Type{MtlSimdgroupMatrix{T,8,8}}, + src::MtlDeviceArray{T}, + matrix_origin::NTuple{2, Int64} = (1, 1)) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + simdgroup_load(src, matrix_origin)) +end + +# Store: write the fragment back to a device or threadgroup array tile. +@device_function @inline function simdgroup_store(m::MtlSimdgroupMatrix{T,8,8}, + dest::MtlDeviceArray{T}, + matrix_origin::NTuple{2, Int64} = (1, 1)) where {T} + return simdgroup_store(m.data, dest, matrix_origin) +end + +# Multiply: D = A * B. +@inline function Base.:(*)(a::MtlSimdgroupMatrix{T,8,8}, + b::MtlSimdgroupMatrix{T,8,8}) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + simdgroup_multiply(a.data, b.data)) +end + +# Fused multiply-add: D = A * B + C. +@inline function Base.muladd(a::MtlSimdgroupMatrix{T,8,8}, + b::MtlSimdgroupMatrix{T,8,8}, + c::MtlSimdgroupMatrix{T,8,8}) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + simdgroup_multiply_accumulate(a.data, b.data, c.data)) +end diff --git a/test/device/intrinsics/simdgroup_matrix.jl b/test/device/intrinsics/simdgroup_matrix.jl new file mode 100644 index 000000000..16922211e --- /dev/null +++ b/test/device/intrinsics/simdgroup_matrix.jl @@ -0,0 +1,135 @@ +@testset "type" begin + @testset for T in (Float16, Float32) + @test eltype(MtlSimdgroupMatrix{T,8,8}) === T + @test size(MtlSimdgroupMatrix{T,8,8}) === (8, 8) + end +end + +@testset "fill($T)" for T in (Float16, Float32) + function kernel(out::MtlDeviceMatrix{T}, val::T) where {T} + m = MtlSimdgroupMatrix{T,8,8}(val) + simdgroup_store(m, out) + return + end + + out = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(out, T(3.5)) + @test all(Array(out) .== T(3.5)) +end + +@testset "zero($T)" for T in (Float16, Float32) + function kernel(out::MtlDeviceMatrix{T}) where {T} + m = zero(MtlSimdgroupMatrix{T,8,8}) + simdgroup_store(m, out) + return + end + + out = MtlArray(ones(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(out) + @test all(Array(out) .== zero(T)) +end + +@testset "load_store($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}) where {T} + m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a) + simdgroup_store(m, b) + return + end + + a = MtlArray(rand(T, 8, 8)) + b = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(a, b) + @test Array(a) == Array(b) +end + +@testset "load_store with origin($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, + origin_a::NTuple{2,Int64}, origin_b::NTuple{2,Int64}) where {T} + m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a, origin_a) + simdgroup_store(m, b, origin_b) + return + end + + a = MtlArray(rand(T, 20, 15)) + b = MtlArray(zeros(T, 15, 20)) + Metal.@sync @metal threads=(8, 8) kernel(a, b, (4, 2), (3, 5)) + @test Array(a)[4:11, 2:9] == Array(b)[3:10, 5:12] +end + +@testset "multiply($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, c::MtlDeviceMatrix{T}) where {T} + ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a) + mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, b) + simdgroup_store(ma * mb, c) + return + end + + a = MtlArray(rand(T, 8, 8)) + b = MtlArray(rand(T, 8, 8)) + c = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(a, b, c) + @test Array(a) * Array(b) ≈ Array(c) +end + +@testset "muladd($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, + c::MtlDeviceMatrix{T}, d::MtlDeviceMatrix{T}) where {T} + ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a) + mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, b) + mc = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, c) + simdgroup_store(muladd(ma, mb, mc), d) + return + end + + a = MtlArray(rand(T, 8, 8)) + b = MtlArray(rand(T, 8, 8)) + c = MtlArray(rand(T, 8, 8)) + d = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(a, b, c, d) + @test Array(a) * Array(b) + Array(c) ≈ Array(d) +end + +# Composed K-loop GEMM: C(8×8) = A(8×K) * B(K×8) with K=32, accumulating +# four 8×8×8 fragment MMAs. +@testset "K-loop GEMM($T)" for T in (Float16, Float32) + function kernel(A::MtlDeviceMatrix{T}, B::MtlDeviceMatrix{T}, C::MtlDeviceMatrix{T}) where {T} + acc = zero(MtlSimdgroupMatrix{T,8,8}) + for k in 0:3 + ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, A, (1, 1 + k*8)) + mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, B, (1 + k*8, 1)) + acc = muladd(ma, mb, acc) + end + simdgroup_store(acc, C) + return + end + + A = MtlArray(rand(T, 8, 32)) + B = MtlArray(rand(T, 32, 8)) + C = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(A, B, C) + @test Array(A) * Array(B) ≈ Array(C) rtol=sqrt(eps(T)) +end + +# Threadgroup-memory variant: stage tiles through threadgroup memory, then +# load fragments from there. Mirrors how Flash Attention stages K/V tiles. +@testset "threadgroup load($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}) where {T} + pos = thread_position_in_threadgroup() + tg = MtlThreadGroupArray(T, (8, 8)) + tg[pos.x, pos.y] = a[pos.x, pos.y] + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, tg) + tg2 = MtlThreadGroupArray(T, (8, 8)) + simdgroup_store(m, tg2) + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + b[pos.x, pos.y] = tg2[pos.x, pos.y] + return + end + + a = MtlArray(rand(T, 8, 8)) + b = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(a, b) + @test Array(a) == Array(b) +end From 3273ae4c48f87cf626484d973cced4b6de32104b Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Sun, 24 May 2026 12:15:55 +0200 Subject: [PATCH 02/24] Add Flash Attention examples (MPS + simdgroup_matrix kernel). --- examples/flash_attention/README.md | 75 +++++++++++++++ examples/flash_attention/fa_mps.jl | 44 +++++++++ examples/flash_attention/fa_simdgroup.jl | 115 +++++++++++++++++++++++ 3 files changed, 234 insertions(+) create mode 100644 examples/flash_attention/README.md create mode 100644 examples/flash_attention/fa_mps.jl create mode 100644 examples/flash_attention/fa_simdgroup.jl diff --git a/examples/flash_attention/README.md b/examples/flash_attention/README.md new file mode 100644 index 000000000..4d26513e7 --- /dev/null +++ b/examples/flash_attention/README.md @@ -0,0 +1,75 @@ +# Flash Attention examples + +Two reference implementations of scaled dot-product attention on Apple +Silicon GPUs from Julia, illustrating different programming models that +Metal.jl exposes. + +## `fa_mps.jl` + +The trivial baseline. Uses standard Julia operators (`*`, broadcasting, +`maximum`, `sum`, `exp`) on `MtlArray`. The matrix multiplications are +dispatched to **MPSGraph / MPSMatrixMultiplication** by +[`src/linalg.jl`](../../src/linalg.jl); the rest is GPUArrays. + +Not actually a Flash Attention algorithm — the full N×N scores matrix is +materialized in device memory — but it is the right reference to verify a +custom kernel against, and the fastest path to "attention runs on GPU" +when you don't need a custom kernel. + +Works on macOS 13+ / M1+. + +## `fa_simdgroup.jl` + +A single-block scaled dot-product attention kernel built from +`MtlSimdgroupMatrix{Float16, 8, 8}` (see `src/device/intrinsics/`). One +simdgroup of 32 lanes does the QKᵀ and PV matrix multiplies via two +`simdgroup_matrix` ops; the row-wise softmax is done in scalar code +through threadgroup memory. + +The example is intentionally minimal — Q, K, V are fixed at 8×8 — so the +control flow stays readable. A production implementation would: + + - sweep KV in blocks with online-softmax state (`m`, `l` per query row), + - tile D across multiple simdgroups, + - overlap loads with compute via `simdgroup_async_copy`, + - and split the backward pass into separate dQ / dKV kernels to avoid + FP32 atomics on Apple GPUs. + +See [philipturner/metal-flash-attention](https://github.com/philipturner/metal-flash-attention) +for a tuned reference (Swift + MSL, ~83 % ALU on M1 Max). + +K is host-transposed to `K_t` before launch — Metal.jl's `simdgroup_load` +issues a transposed-from-MSL load to compensate for Julia's column-major +storage, so `Q · K_t` in the kernel equals mathematical `Q · K^T`. +Exposing a `transpose=false` variant of `simdgroup_load` would let the +host transpose drop; that's a small follow-up to +`src/device/intrinsics/simd.jl`. + +Works on macOS 13+ / M1+. + +## Not included: `fa_metal4.jl` + +A third path would use the Metal 4 `cooperative_tensor` / +`tensor_ops::matmul2d` primitives with postfix-fusion of the softmax +epilogue. Apple positions this as the "preferred programming model for +ML applications" — on M5 hardware it can issue Neural-Accelerator MMAs +and skip threadgroup memory entirely. + +That path is not yet wired up in Metal.jl. The ABI is documented in +[`docs/src/devel/air_intrinsics.md`](../../docs/src/devel/air_intrinsics.md); +the externally-defined `__tensorops_impl_matmul2d_op_run_*` symbols are +fully captured there, and the matmul2d descriptor layout is known. What +remains is a Julia-side `MtlCooperativeTensor` wrapper plus a host-side +`MTLTensor` / `MTL4ComputeCommandEncoder` binding (the Objective-C +classes are already generated in `lib/mtl/libmtl.jl`, gated on +`macos(v"26.0.0")`). Both validation steps require macOS 26 + Xcode 26; +M5 hardware is required to see the Neural-Accelerator speedup. + +## References + + - Apple — *Discover Metal 4* (WWDC25 Session 205) and + *Combine Metal 4 machine learning and graphics* (WWDC25 Session 262). + - Apple — *Metal Performance Primitives Programming Guide* (PDF, 2025). + - philipturner — [metal-flash-attention](https://github.com/philipturner/metal-flash-attention). + - llama.cpp — [Metal 4 cooperative-tensor FA backend (PR #16634)](https://github.com/ggml-org/llama.cpp/pull/16634). + - liuliu — [example_matmul_metal4](https://github.com/liuliu/example_matmul_metal4) (minimal MSL probe for the Metal 4 host API). diff --git a/examples/flash_attention/fa_mps.jl b/examples/flash_attention/fa_mps.jl new file mode 100644 index 000000000..860e7c094 --- /dev/null +++ b/examples/flash_attention/fa_mps.jl @@ -0,0 +1,44 @@ +# Scaled-dot-product attention via the existing MPS / MPSGraph dispatch. +# +# This is the trivial baseline: rely on Metal.jl's automatic dispatch of `*` +# to MPSMatrixMultiplication / MPSGraph and on the GPUArrays broadcast for +# the softmax. It is NOT a Flash Attention algorithm — the full N×N scores +# matrix is materialized in device memory — but it is the right reference +# implementation to verify a custom kernel against. + +using Metal +using Test + +function attention_reference(Q, K, V) + d = size(Q, 2) + scale = inv(sqrt(eltype(Q)(d))) + + S = (Q * K') .* scale # N_q × N_kv, dispatched to MPS + S = S .- maximum(S; dims = 2) # row-wise max for numerical stability + P = exp.(S) + P = P ./ sum(P; dims = 2) # row-wise softmax + + return P * V # N_q × D, dispatched to MPS +end + +let + N_q, N_kv, D = 64, 64, 32 + T = Float32 + + Q = MtlArray(randn(T, N_q, D)) + K = MtlArray(randn(T, N_kv, D)) + V = MtlArray(randn(T, N_kv, D)) + + O_gpu = attention_reference(Q, K, V) + + # CPU reference + Qh, Kh, Vh = Array(Q), Array(K), Array(V) + scale = inv(sqrt(T(D))) + S = (Qh * Kh') .* scale + S .-= maximum(S; dims = 2) + P = exp.(S) + P ./= sum(P; dims = 2) + O_cpu = P * Vh + + @test Array(O_gpu) ≈ O_cpu rtol = sqrt(eps(T)) +end diff --git a/examples/flash_attention/fa_simdgroup.jl b/examples/flash_attention/fa_simdgroup.jl new file mode 100644 index 000000000..ab3958e42 --- /dev/null +++ b/examples/flash_attention/fa_simdgroup.jl @@ -0,0 +1,115 @@ +# Single-block scaled dot-product attention kernel built from +# `MtlSimdgroupMatrix`. This is the smallest readable example of composing +# the SIMD-group matrix primitives into an attention-style kernel. +# +# Scope: +# - one simdgroup (32 lanes) per threadgroup, one threadgroup total +# - Q, K_t, V are fixed at 8×8 (Br = Bc = D = 8) +# - K is host-transposed to K_t so the in-kernel matmul Q · K_t equals +# mathematical Q · K^T (Metal.jl's `simdgroup_load` always issues a +# transposed-from-MSL load to read Julia's column-major data, so +# getting K^T from a col-major K without an extra binding is awkward) +# - softmax is done in scalar code through threadgroup memory +# +# A "real" FA kernel adds a KV-block loop with online-softmax state +# (`m`, `l` per row) and tiles D across multiple simdgroups; see +# philipturner/metal-flash-attention for a production reference. + +using Metal +using Test + +const Br = 8 +const Bc = 8 +const D = 8 + +function fa_kernel!(O::AbstractMatrix{Float16}, + Q::AbstractMatrix{Float16}, + K_t::AbstractMatrix{Float16}, + V::AbstractMatrix{Float16}, + scale::Float32) + # Stage scratch. + Ss = MtlThreadGroupArray(Float32, (Br, Bc)) # scores, then P + Sh = MtlThreadGroupArray(Float16, (Br, Bc)) # P cast back to fp16 + + # 1. S = Q · K_t (single 8x8 simdgroup_matrix multiply) + Qm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, Q) + Km = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, K_t) + Sm = Qm * Km + + # 2. Spill to threadgroup memory for the row-wise softmax in scalar code. + Sh_tmp = MtlThreadGroupArray(Float16, (Br, Bc)) + simdgroup_store(Sm, Sh_tmp) + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + # Cast to Float32 and scale; 32 threads cover the 64 elements at 2 per lane. + tid = Int(thread_index_in_threadgroup()) - 1 # 0..31 + @inbounds for k in 0:1 + idx = tid * 2 + k + r = idx ÷ Bc + 1 + c = idx % Bc + 1 + Ss[r, c] = Float32(Sh_tmp[r, c]) * scale + end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + # 3. Row-wise softmax. 8 of 32 lanes do real work; the rest idle. + if tid < Br + m = -Inf32 + @inbounds for j in 1:Bc + v = Ss[tid + 1, j] + m = v > m ? v : m + end + s = 0.0f0 + @inbounds for j in 1:Bc + p = exp(Ss[tid + 1, j] - m) + Ss[tid + 1, j] = p + s += p + end + inv_s = 1.0f0 / s + @inbounds for j in 1:Bc + Sh[tid + 1, j] = Float16(Ss[tid + 1, j] * inv_s) + end + end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + # 4. O = P · V (second 8x8 simdgroup_matrix multiply) + Pm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, Sh) + Vm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, V) + Om = Pm * Vm + + simdgroup_store(Om, O) + return +end + +function flash_attention(Q::MtlMatrix{Float16}, K::MtlMatrix{Float16}, + V::MtlMatrix{Float16}) + @assert size(Q) == (Br, D) "Q must be ($Br, $D)" + @assert size(K) == (Bc, D) "K must be ($Bc, $D)" + @assert size(V) == (Bc, D) "V must be ($Bc, $D)" + + K_t = MtlMatrix(collect(transpose(Array(K)))) + O = similar(Q) + scale = inv(sqrt(Float32(D))) + + Metal.@sync @metal threads = 32 fa_kernel!(O, Q, K_t, V, scale) + return O +end + +let + T = Float16 + Q = MtlArray(rand(T, Br, D)) + K = MtlArray(rand(T, Bc, D)) + V = MtlArray(rand(T, Bc, D)) + + O = flash_attention(Q, K, V) + + # Reference attention computed in Float32 on the CPU. + Qh, Kh, Vh = Float32.(Array(Q)), Float32.(Array(K)), Float32.(Array(V)) + scale = inv(sqrt(Float32(D))) + S = (Qh * Kh') .* scale + S .-= maximum(S; dims = 2) + P = exp.(S) + P ./= sum(P; dims = 2) + O_ref = P * Vh + + @test Float32.(Array(O)) ≈ O_ref rtol = 1e-2 +end From 5cb34979170b688c3c4d431d0d13f2f29e79a08a Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Sun, 24 May 2026 12:50:00 +0200 Subject: [PATCH 03/24] Note Metal 4 tensor ops use externally-defined symbols. --- examples/flash_attention/README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/flash_attention/README.md b/examples/flash_attention/README.md index 4d26513e7..43a293b72 100644 --- a/examples/flash_attention/README.md +++ b/examples/flash_attention/README.md @@ -55,15 +55,15 @@ epilogue. Apple positions this as the "preferred programming model for ML applications" — on M5 hardware it can issue Neural-Accelerator MMAs and skip threadgroup memory entirely. -That path is not yet wired up in Metal.jl. The ABI is documented in -[`docs/src/devel/air_intrinsics.md`](../../docs/src/devel/air_intrinsics.md); -the externally-defined `__tensorops_impl_matmul2d_op_run_*` symbols are -fully captured there, and the matmul2d descriptor layout is known. What -remains is a Julia-side `MtlCooperativeTensor` wrapper plus a host-side -`MTLTensor` / `MTL4ComputeCommandEncoder` binding (the Objective-C -classes are already generated in `lib/mtl/libmtl.jl`, gated on -`macos(v"26.0.0")`). Both validation steps require macOS 26 + Xcode 26; -M5 hardware is required to see the Neural-Accelerator speedup. +That path is not yet wired up in Metal.jl. The Objective-C classes are +already generated in `lib/mtl/libmtl.jl` (gated on `macos(v"26.0.0")`); +what remains is a Julia-side `MtlCooperativeTensor` wrapper plus a +host-side `MTLTensor` / `MTL4ComputeCommandEncoder` binding. Note that +the device-side ops lower to externally-defined +`__tensorops_impl_matmul2d_op_*` symbols rather than `air.*` intrinsics, +so the binding pattern differs from the SIMD-group case. Validation +needs macOS 26 + Xcode 26; M5 hardware is required to see the +Neural-Accelerator speedup. ## References From 9b559d7a3c60ddceb562ffa945ef17aab8913cb8 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Sun, 24 May 2026 13:18:42 +0200 Subject: [PATCH 04/24] Fold MtlSimdgroupMatrix wrapper into simd.jl. --- src/Metal.jl | 1 - src/device/intrinsics/simd.jl | 71 +++++++++++ src/device/intrinsics/simdgroup_matrix.jl | 68 ----------- test/device/intrinsics/simd.jl | 136 +++++++++++++++++++++ test/device/intrinsics/simdgroup_matrix.jl | 135 -------------------- 5 files changed, 207 insertions(+), 204 deletions(-) delete mode 100644 src/device/intrinsics/simdgroup_matrix.jl delete mode 100644 test/device/intrinsics/simdgroup_matrix.jl diff --git a/src/Metal.jl b/src/Metal.jl index 7ed8567cb..ecc403a73 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -40,7 +40,6 @@ include("device/intrinsics/math.jl") include("device/intrinsics/synchronization.jl") include("device/intrinsics/memory.jl") include("device/intrinsics/simd.jl") -include("device/intrinsics/simdgroup_matrix.jl") include("device/intrinsics/atomics.jl") include("device/malloc.jl") include("device/random.jl") diff --git a/src/device/intrinsics/simd.jl b/src/device/intrinsics/simd.jl index 368227ac2..30f3ce6ff 100644 --- a/src/device/intrinsics/simd.jl +++ b/src/device/intrinsics/simd.jl @@ -1,4 +1,5 @@ export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate, + MtlSimdgroupMatrix, simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up, simd_shuffle, simd_shuffle_xor, simd_ballot, simd_vote_all, simd_vote_any @@ -85,6 +86,76 @@ Returns `a * b + c`. """ simdgroup_multiply_accumulate +## Typed wrapper + +""" + MtlSimdgroupMatrix{T,R,C} + +Typed wrapper around a SIMD-group matrix fragment. `T` is the element type +(`Float16` or `Float32`); `R` and `C` are the matrix dimensions. Only the +8×8 shape is supported by current Apple GPUs. + +The fragment data is distributed across the 32 lanes of a SIMD-group; the +per-lane element layout is implementation-defined and elements cannot be +accessed directly. To inspect or modify individual entries, store the +matrix to device or threadgroup memory first. + +Construct via [`simdgroup_load`](@ref), [`zero`](@ref) or the explicit +fill constructor `MtlSimdgroupMatrix{T,8,8}(val::T)`. +""" +struct MtlSimdgroupMatrix{T,R,C} + data::NTuple{64, VecElement{T}} + + global _unsafe_wrap_simdgroup_matrix(::Type{MtlSimdgroupMatrix{T,R,C}}, + data::NTuple{64, VecElement{T}}) where {T,R,C} = + new{T,R,C}(data) +end + +Base.size(::Type{<:MtlSimdgroupMatrix{<:Any,R,C}}) where {R,C} = (R, C) +Base.size(m::MtlSimdgroupMatrix) = size(typeof(m)) +Base.eltype(::Type{<:MtlSimdgroupMatrix{T}}) where {T} = T +Base.eltype(m::MtlSimdgroupMatrix) = eltype(typeof(m)) + +# Fill constructor: materialize a fragment whose elements are all `val`. +@inline function MtlSimdgroupMatrix{T,8,8}(val::T) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + ntuple(_ -> VecElement{T}(val), Val(64))) +end + +@inline Base.zero(::Type{MtlSimdgroupMatrix{T,8,8}}) where {T} = + MtlSimdgroupMatrix{T,8,8}(zero(T)) + +# Load: build a fragment from a device or threadgroup array tile. +@device_function @inline function simdgroup_load(::Type{MtlSimdgroupMatrix{T,8,8}}, + src::MtlDeviceArray{T}, + matrix_origin::NTuple{2, Int64} = (1, 1)) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + simdgroup_load(src, matrix_origin)) +end + +# Store: write the fragment back to a device or threadgroup array tile. +@device_function @inline function simdgroup_store(m::MtlSimdgroupMatrix{T,8,8}, + dest::MtlDeviceArray{T}, + matrix_origin::NTuple{2, Int64} = (1, 1)) where {T} + return simdgroup_store(m.data, dest, matrix_origin) +end + +# Multiply: D = A * B. +@inline function Base.:(*)(a::MtlSimdgroupMatrix{T,8,8}, + b::MtlSimdgroupMatrix{T,8,8}) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + simdgroup_multiply(a.data, b.data)) +end + +# Fused multiply-add: D = A * B + C. +@inline function Base.muladd(a::MtlSimdgroupMatrix{T,8,8}, + b::MtlSimdgroupMatrix{T,8,8}, + c::MtlSimdgroupMatrix{T,8,8}) where {T} + return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, + simdgroup_multiply_accumulate(a.data, b.data, c.data)) +end + + ## SIMD Shuffle Up/Down simd_shuffle_map = ((Float32, "f32"), diff --git a/src/device/intrinsics/simdgroup_matrix.jl b/src/device/intrinsics/simdgroup_matrix.jl deleted file mode 100644 index 2762797a5..000000000 --- a/src/device/intrinsics/simdgroup_matrix.jl +++ /dev/null @@ -1,68 +0,0 @@ -export MtlSimdgroupMatrix - -""" - MtlSimdgroupMatrix{T,R,C} - -Typed wrapper around a SIMD-group matrix fragment. `T` is the element type -(`Float16` or `Float32`); `R` and `C` are the matrix dimensions. Only the -8×8 shape is supported by current Apple GPUs. - -The fragment data is distributed across the 32 lanes of a SIMD-group; the -per-lane element layout is implementation-defined and elements cannot be -accessed directly. To inspect or modify individual entries, store the -matrix to device or threadgroup memory first. - -Construct via [`simdgroup_load`](@ref), [`zero`](@ref) or the explicit -fill constructor `MtlSimdgroupMatrix{T,8,8}(val::T)`. -""" -struct MtlSimdgroupMatrix{T,R,C} - data::NTuple{64, VecElement{T}} - - global _unsafe_wrap_simdgroup_matrix(::Type{MtlSimdgroupMatrix{T,R,C}}, - data::NTuple{64, VecElement{T}}) where {T,R,C} = - new{T,R,C}(data) -end - -Base.size(::Type{<:MtlSimdgroupMatrix{<:Any,R,C}}) where {R,C} = (R, C) -Base.size(m::MtlSimdgroupMatrix) = size(typeof(m)) -Base.eltype(::Type{<:MtlSimdgroupMatrix{T}}) where {T} = T -Base.eltype(m::MtlSimdgroupMatrix) = eltype(typeof(m)) - -# Fill constructor: materialize a fragment whose elements are all `val`. -@inline function MtlSimdgroupMatrix{T,8,8}(val::T) where {T} - return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, - ntuple(_ -> VecElement{T}(val), Val(64))) -end - -@inline Base.zero(::Type{MtlSimdgroupMatrix{T,8,8}}) where {T} = - MtlSimdgroupMatrix{T,8,8}(zero(T)) - -# Load: build a fragment from a device or threadgroup array tile. -@device_function @inline function simdgroup_load(::Type{MtlSimdgroupMatrix{T,8,8}}, - src::MtlDeviceArray{T}, - matrix_origin::NTuple{2, Int64} = (1, 1)) where {T} - return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, - simdgroup_load(src, matrix_origin)) -end - -# Store: write the fragment back to a device or threadgroup array tile. -@device_function @inline function simdgroup_store(m::MtlSimdgroupMatrix{T,8,8}, - dest::MtlDeviceArray{T}, - matrix_origin::NTuple{2, Int64} = (1, 1)) where {T} - return simdgroup_store(m.data, dest, matrix_origin) -end - -# Multiply: D = A * B. -@inline function Base.:(*)(a::MtlSimdgroupMatrix{T,8,8}, - b::MtlSimdgroupMatrix{T,8,8}) where {T} - return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, - simdgroup_multiply(a.data, b.data)) -end - -# Fused multiply-add: D = A * B + C. -@inline function Base.muladd(a::MtlSimdgroupMatrix{T,8,8}, - b::MtlSimdgroupMatrix{T,8,8}, - c::MtlSimdgroupMatrix{T,8,8}) where {T} - return _unsafe_wrap_simdgroup_matrix(MtlSimdgroupMatrix{T,8,8}, - simdgroup_multiply_accumulate(a.data, b.data, c.data)) -end diff --git a/test/device/intrinsics/simd.jl b/test/device/intrinsics/simd.jl index 684bb04bf..eb8d55fb0 100644 --- a/test/device/intrinsics/simd.jl +++ b/test/device/intrinsics/simd.jl @@ -288,4 +288,140 @@ end @metal threads=(8, 8) kernel(a, b, c, d) @test Array(a) * Array(b) + Array(c) ≈ Array(d) end + + @testset "MtlSimdgroupMatrix type" begin + @testset for T in (Float16, Float32) + @test eltype(MtlSimdgroupMatrix{T,8,8}) === T + @test size(MtlSimdgroupMatrix{T,8,8}) === (8, 8) + end + end + + @testset "MtlSimdgroupMatrix fill($T)" for T in (Float16, Float32) + function kernel(out::MtlDeviceMatrix{T}, val::T) where {T} + m = MtlSimdgroupMatrix{T,8,8}(val) + simdgroup_store(m, out) + return + end + + out = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(out, T(3.5)) + @test all(Array(out) .== T(3.5)) + end + + @testset "MtlSimdgroupMatrix zero($T)" for T in (Float16, Float32) + function kernel(out::MtlDeviceMatrix{T}) where {T} + m = zero(MtlSimdgroupMatrix{T,8,8}) + simdgroup_store(m, out) + return + end + + out = MtlArray(ones(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(out) + @test all(Array(out) .== zero(T)) + end + + @testset "MtlSimdgroupMatrix load_store($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}) where {T} + m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a) + simdgroup_store(m, b) + return + end + + a = MtlArray(rand(T, 8, 8)) + b = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(a, b) + @test Array(a) == Array(b) + end + + @testset "MtlSimdgroupMatrix load_store with origin($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, + origin_a::NTuple{2,Int64}, origin_b::NTuple{2,Int64}) where {T} + m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a, origin_a) + simdgroup_store(m, b, origin_b) + return + end + + a = MtlArray(rand(T, 20, 15)) + b = MtlArray(zeros(T, 15, 20)) + Metal.@sync @metal threads=(8, 8) kernel(a, b, (4, 2), (3, 5)) + @test Array(a)[4:11, 2:9] == Array(b)[3:10, 5:12] + end + + @testset "MtlSimdgroupMatrix multiply($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, c::MtlDeviceMatrix{T}) where {T} + ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a) + mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, b) + simdgroup_store(ma * mb, c) + return + end + + a = MtlArray(rand(T, 8, 8)) + b = MtlArray(rand(T, 8, 8)) + c = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(a, b, c) + @test Array(a) * Array(b) ≈ Array(c) + end + + @testset "MtlSimdgroupMatrix muladd($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, + c::MtlDeviceMatrix{T}, d::MtlDeviceMatrix{T}) where {T} + ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a) + mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, b) + mc = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, c) + simdgroup_store(muladd(ma, mb, mc), d) + return + end + + a = MtlArray(rand(T, 8, 8)) + b = MtlArray(rand(T, 8, 8)) + c = MtlArray(rand(T, 8, 8)) + d = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(a, b, c, d) + @test Array(a) * Array(b) + Array(c) ≈ Array(d) + end + + # Composed K-loop GEMM: C(8×8) = A(8×K) * B(K×8) with K=32, accumulating + # four 8×8×8 fragment MMAs. + @testset "MtlSimdgroupMatrix K-loop GEMM($T)" for T in (Float16, Float32) + function kernel(A::MtlDeviceMatrix{T}, B::MtlDeviceMatrix{T}, C::MtlDeviceMatrix{T}) where {T} + acc = zero(MtlSimdgroupMatrix{T,8,8}) + for k in 0:3 + ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, A, (1, 1 + k*8)) + mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, B, (1 + k*8, 1)) + acc = muladd(ma, mb, acc) + end + simdgroup_store(acc, C) + return + end + + A = MtlArray(rand(T, 8, 32)) + B = MtlArray(rand(T, 32, 8)) + C = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(A, B, C) + @test Array(A) * Array(B) ≈ Array(C) rtol=sqrt(eps(T)) + end + + # Threadgroup-memory variant: stage tiles through threadgroup memory, then + # load fragments from there. Mirrors how Flash Attention stages K/V tiles. + @testset "MtlSimdgroupMatrix threadgroup load($T)" for T in (Float16, Float32) + function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}) where {T} + pos = thread_position_in_threadgroup() + tg = MtlThreadGroupArray(T, (8, 8)) + tg[pos.x, pos.y] = a[pos.x, pos.y] + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, tg) + tg2 = MtlThreadGroupArray(T, (8, 8)) + simdgroup_store(m, tg2) + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + b[pos.x, pos.y] = tg2[pos.x, pos.y] + return + end + + a = MtlArray(rand(T, 8, 8)) + b = MtlArray(zeros(T, 8, 8)) + Metal.@sync @metal threads=(8, 8) kernel(a, b) + @test Array(a) == Array(b) + end end # End Matrix Functions diff --git a/test/device/intrinsics/simdgroup_matrix.jl b/test/device/intrinsics/simdgroup_matrix.jl deleted file mode 100644 index 16922211e..000000000 --- a/test/device/intrinsics/simdgroup_matrix.jl +++ /dev/null @@ -1,135 +0,0 @@ -@testset "type" begin - @testset for T in (Float16, Float32) - @test eltype(MtlSimdgroupMatrix{T,8,8}) === T - @test size(MtlSimdgroupMatrix{T,8,8}) === (8, 8) - end -end - -@testset "fill($T)" for T in (Float16, Float32) - function kernel(out::MtlDeviceMatrix{T}, val::T) where {T} - m = MtlSimdgroupMatrix{T,8,8}(val) - simdgroup_store(m, out) - return - end - - out = MtlArray(zeros(T, 8, 8)) - Metal.@sync @metal threads=(8, 8) kernel(out, T(3.5)) - @test all(Array(out) .== T(3.5)) -end - -@testset "zero($T)" for T in (Float16, Float32) - function kernel(out::MtlDeviceMatrix{T}) where {T} - m = zero(MtlSimdgroupMatrix{T,8,8}) - simdgroup_store(m, out) - return - end - - out = MtlArray(ones(T, 8, 8)) - Metal.@sync @metal threads=(8, 8) kernel(out) - @test all(Array(out) .== zero(T)) -end - -@testset "load_store($T)" for T in (Float16, Float32) - function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}) where {T} - m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a) - simdgroup_store(m, b) - return - end - - a = MtlArray(rand(T, 8, 8)) - b = MtlArray(zeros(T, 8, 8)) - Metal.@sync @metal threads=(8, 8) kernel(a, b) - @test Array(a) == Array(b) -end - -@testset "load_store with origin($T)" for T in (Float16, Float32) - function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, - origin_a::NTuple{2,Int64}, origin_b::NTuple{2,Int64}) where {T} - m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a, origin_a) - simdgroup_store(m, b, origin_b) - return - end - - a = MtlArray(rand(T, 20, 15)) - b = MtlArray(zeros(T, 15, 20)) - Metal.@sync @metal threads=(8, 8) kernel(a, b, (4, 2), (3, 5)) - @test Array(a)[4:11, 2:9] == Array(b)[3:10, 5:12] -end - -@testset "multiply($T)" for T in (Float16, Float32) - function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, c::MtlDeviceMatrix{T}) where {T} - ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a) - mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, b) - simdgroup_store(ma * mb, c) - return - end - - a = MtlArray(rand(T, 8, 8)) - b = MtlArray(rand(T, 8, 8)) - c = MtlArray(zeros(T, 8, 8)) - Metal.@sync @metal threads=(8, 8) kernel(a, b, c) - @test Array(a) * Array(b) ≈ Array(c) -end - -@testset "muladd($T)" for T in (Float16, Float32) - function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}, - c::MtlDeviceMatrix{T}, d::MtlDeviceMatrix{T}) where {T} - ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, a) - mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, b) - mc = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, c) - simdgroup_store(muladd(ma, mb, mc), d) - return - end - - a = MtlArray(rand(T, 8, 8)) - b = MtlArray(rand(T, 8, 8)) - c = MtlArray(rand(T, 8, 8)) - d = MtlArray(zeros(T, 8, 8)) - Metal.@sync @metal threads=(8, 8) kernel(a, b, c, d) - @test Array(a) * Array(b) + Array(c) ≈ Array(d) -end - -# Composed K-loop GEMM: C(8×8) = A(8×K) * B(K×8) with K=32, accumulating -# four 8×8×8 fragment MMAs. -@testset "K-loop GEMM($T)" for T in (Float16, Float32) - function kernel(A::MtlDeviceMatrix{T}, B::MtlDeviceMatrix{T}, C::MtlDeviceMatrix{T}) where {T} - acc = zero(MtlSimdgroupMatrix{T,8,8}) - for k in 0:3 - ma = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, A, (1, 1 + k*8)) - mb = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, B, (1 + k*8, 1)) - acc = muladd(ma, mb, acc) - end - simdgroup_store(acc, C) - return - end - - A = MtlArray(rand(T, 8, 32)) - B = MtlArray(rand(T, 32, 8)) - C = MtlArray(zeros(T, 8, 8)) - Metal.@sync @metal threads=(8, 8) kernel(A, B, C) - @test Array(A) * Array(B) ≈ Array(C) rtol=sqrt(eps(T)) -end - -# Threadgroup-memory variant: stage tiles through threadgroup memory, then -# load fragments from there. Mirrors how Flash Attention stages K/V tiles. -@testset "threadgroup load($T)" for T in (Float16, Float32) - function kernel(a::MtlDeviceMatrix{T}, b::MtlDeviceMatrix{T}) where {T} - pos = thread_position_in_threadgroup() - tg = MtlThreadGroupArray(T, (8, 8)) - tg[pos.x, pos.y] = a[pos.x, pos.y] - threadgroup_barrier(Metal.MemoryFlagThreadGroup) - - m = simdgroup_load(MtlSimdgroupMatrix{T,8,8}, tg) - tg2 = MtlThreadGroupArray(T, (8, 8)) - simdgroup_store(m, tg2) - threadgroup_barrier(Metal.MemoryFlagThreadGroup) - - b[pos.x, pos.y] = tg2[pos.x, pos.y] - return - end - - a = MtlArray(rand(T, 8, 8)) - b = MtlArray(zeros(T, 8, 8)) - Metal.@sync @metal threads=(8, 8) kernel(a, b) - @test Array(a) == Array(b) -end From 4bec8c548ef5e04cd9ea6bb2cffb74fd95785eeb Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Sun, 24 May 2026 13:31:05 +0200 Subject: [PATCH 05/24] Wrap MPSGraph scaledDotProductAttention. --- examples/flash_attention/README.md | 15 +++++ examples/flash_attention/fa_mpsgraph.jl | 83 ++++++++++++++++++++++++ lib/mpsgraphs/operations.jl | 25 ++++++++ test/mpsgraphs/sdpa.jl | 85 +++++++++++++++++++++++++ 4 files changed, 208 insertions(+) create mode 100644 examples/flash_attention/fa_mpsgraph.jl create mode 100644 test/mpsgraphs/sdpa.jl diff --git a/examples/flash_attention/README.md b/examples/flash_attention/README.md index 43a293b72..0e0582030 100644 --- a/examples/flash_attention/README.md +++ b/examples/flash_attention/README.md @@ -18,6 +18,21 @@ when you don't need a custom kernel. Works on macOS 13+ / M1+. +## `fa_mpsgraph.jl` + +The high-level MPS path. Builds a one-node MPSGraph using +`scaledDotProductAttentionWithQueryTensor` (macOS 14+), which fuses +Q·Kᵀ → scale → softmax → ·V into a single op. Apple uses the same op as +the backbone of their own SDPA paths (MLX falls back to it; Core ML +lowers attention to it), so it's the closest thing to "ask Apple for +attention" that Metal.jl can give you. + +Inputs are 4-D `(head_dim, seq, num_heads, batch)` in Julia — MPSGraph +sees these reversed as `(batch, num_heads, seq, head_dim)`, the layout +Apple's SDPA expects. + +Works on macOS 14+ / M1+. + ## `fa_simdgroup.jl` A single-block scaled dot-product attention kernel built from diff --git a/examples/flash_attention/fa_mpsgraph.jl b/examples/flash_attention/fa_mpsgraph.jl new file mode 100644 index 000000000..d0f904aab --- /dev/null +++ b/examples/flash_attention/fa_mpsgraph.jl @@ -0,0 +1,83 @@ +# Fused scaled-dot-product attention via MPSGraph's +# `scaledDotProductAttentionWithQueryTensor` op (macOS 14+). One graph node +# does Q·Kᵀ → scale → softmax → ·V end-to-end; MPSGraph picks the kernel. +# +# Higher-level than `fa_mps.jl`: there's no host-side composition of +# matmul + broadcast + softmax + matmul. Whether MPSGraph internally +# materializes the N×N scores matrix is a black box, but on macOS 14+ +# Apple uses this same op as their backbone SDPA implementation +# (MLX falls back to it; Core ML lowers attention to it). +# +# This example uses the 4-D `(head_dim, seq, num_heads, batch)` Julia +# layout, which MPSGraph sees reversed as `(batch, num_heads, seq, head_dim)` +# — Apple's expected SDPA layout. + +using Metal +using Metal.MPS: MPSCommandBuffer, commit!, wait_completed +using Metal.MPSGraphs: MPSGraph, MPSGraphTensor, MPSGraphTensorData, + placeholderTensor, scaledDotProductAttentionWithQueryTensor, + encode!, default_exec_desc +using Metal.Foundation: NSDictionary, nil +using Test + +function mpsgraph_attention(Q::MtlArray{T,4}, K::MtlArray{T,4}, V::MtlArray{T,4}; + scale = inv(sqrt(T(size(Q, 1))))) where {T} + @assert size(Q, 1) == size(K, 1) == size(V, 1) "head dim mismatch" + @assert size(K, 2) == size(V, 2) "K/V seq length mismatch" + @assert size(Q)[3:4] == size(K)[3:4] == size(V)[3:4] "(heads, batch) mismatch" + + O = similar(Q) + + graph = MPSGraph() + qph = placeholderTensor(graph, size(Q), T) + kph = placeholderTensor(graph, size(K), T) + vph = placeholderTensor(graph, size(V), T) + out = scaledDotProductAttentionWithQueryTensor(graph, qph, kph, vph, + Float32(scale)) + + feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( + qph => MPSGraphTensorData(Q), + kph => MPSGraphTensorData(K), + vph => MPSGraphTensorData(V), + ) + results = Dict{MPSGraphTensor, MPSGraphTensorData}( + out => MPSGraphTensorData(O), + ) + + cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) + encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(results), + nil, default_exec_desc()) + commit!(cmdbuf) + wait_completed(cmdbuf) + return O +end + +# CPU reference attention, 4-D (head_dim, seq, num_heads, batch) Julia layout. +function cpu_attention(Q, K, V; scale = inv(sqrt(eltype(Q)(size(Q, 1))))) + D, N_q, H, B = size(Q) + N_kv = size(K, 2) + O = similar(Q) + for b in 1:B, h in 1:H + Qm, Km, Vm = Q[:, :, h, b], K[:, :, h, b], V[:, :, h, b] # each is (D, N) + S = (Qm' * Km) .* scale # (N_q, N_kv) + S .-= maximum(S; dims = 2) + P = exp.(S) + P ./= sum(P; dims = 2) + O[:, :, h, b] = Vm * P' # (D, N_q) + end + return O +end + +let + T = Float32 + D, N_q, N_kv, H, B = 16, 24, 32, 2, 1 + + Q = MtlArray(randn(T, D, N_q, H, B)) + K = MtlArray(randn(T, D, N_kv, H, B)) + V = MtlArray(randn(T, D, N_kv, H, B)) + + O_gpu = mpsgraph_attention(Q, K, V) + O_cpu = cpu_attention(Array(Q), Array(K), Array(V)) + + @test Array(O_gpu) ≈ O_cpu rtol = sqrt(eps(T)) +end diff --git a/lib/mpsgraphs/operations.jl b/lib/mpsgraphs/operations.jl index 107c9ae31..96ffe8c52 100644 --- a/lib/mpsgraphs/operations.jl +++ b/lib/mpsgraphs/operations.jl @@ -74,3 +74,28 @@ Dumps the `graph`. This function is undocumented from Apple so it may stop working at any time. """ dump_graph(graph::MPSGraph) = @objc [graph::id{MPSGraph} dump]::Nothing ## COV_EXCL_LINE + +# Scaled dot-product attention: softmax((Q · K^T) * scale [+ mask]) · V, fused. +# Available macOS 14.0+. +function scaledDotProductAttentionWithQueryTensor(graph::MPSGraph, Q::MPSGraphTensor, + K::MPSGraphTensor, V::MPSGraphTensor, + scale::Real, name = "sdpa") + obj = @objc [graph::id{MPSGraph} scaledDotProductAttentionWithQueryTensor:Q::id{MPSGraphTensor} + keyTensor:K::id{MPSGraphTensor} + valueTensor:V::id{MPSGraphTensor} + scale:scale::Cfloat + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end +function scaledDotProductAttentionWithQueryTensor(graph::MPSGraph, Q::MPSGraphTensor, + K::MPSGraphTensor, V::MPSGraphTensor, + mask::MPSGraphTensor, scale::Real, + name = "sdpa") + obj = @objc [graph::id{MPSGraph} scaledDotProductAttentionWithQueryTensor:Q::id{MPSGraphTensor} + keyTensor:K::id{MPSGraphTensor} + valueTensor:V::id{MPSGraphTensor} + maskTensor:mask::id{MPSGraphTensor} + scale:scale::Cfloat + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end diff --git a/test/mpsgraphs/sdpa.jl b/test/mpsgraphs/sdpa.jl new file mode 100644 index 000000000..ab0231d39 --- /dev/null +++ b/test/mpsgraphs/sdpa.jl @@ -0,0 +1,85 @@ +if MPS.is_supported(device()) && Metal.macos_version() >= v"14" + +using .MPS: MPSCommandBuffer, commit!, wait_completed +using .MPSGraphs: MPSGraph, MPSGraphTensor, MPSGraphTensorData, + placeholderTensor, scaledDotProductAttentionWithQueryTensor, + encode!, default_exec_desc +using ObjectiveC.Foundation: NSDictionary, nil + +# Reference attention in 4-D `(head_dim, seq, num_heads, batch)` Julia layout. +# Mask (if provided) is `(N_kv, N_q, num_heads, batch)` — that's MPS's natural +# `(B, H, N_q, N_kv)` layout reversed for Julia col-major. +function _ref_attention(Q, K, V, scale, mask = nothing) + D, N_q, H, B = size(Q) + N_kv = size(K, 2) + O = similar(Q) + for b in 1:B, h in 1:H + Qm, Km, Vm = Q[:, :, h, b], K[:, :, h, b], V[:, :, h, b] + S = (Qm' * Km) .* scale + if mask !== nothing + S .+= transpose(mask[:, :, h, b]) + end + S .-= maximum(S; dims = 2) + P = exp.(S) + P ./= sum(P; dims = 2) + O[:, :, h, b] = Vm * P' + end + return O +end + +function _run_sdpa(Q, K, V, scale; mask = nothing) + O = similar(Q) + graph = MPSGraph() + qph = placeholderTensor(graph, size(Q), eltype(Q)) + kph = placeholderTensor(graph, size(K), eltype(K)) + vph = placeholderTensor(graph, size(V), eltype(V)) + out = if mask === nothing + scaledDotProductAttentionWithQueryTensor(graph, qph, kph, vph, Float32(scale)) + else + mph = placeholderTensor(graph, size(mask), eltype(mask)) + scaledDotProductAttentionWithQueryTensor(graph, qph, kph, vph, mph, + Float32(scale)) + end + + feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( + qph => MPSGraphTensorData(Q), + kph => MPSGraphTensorData(K), + vph => MPSGraphTensorData(V), + ) + if mask !== nothing + feeds[graph.placeholderTensors[end]] = MPSGraphTensorData(mask) + end + results = Dict{MPSGraphTensor, MPSGraphTensorData}( + out => MPSGraphTensorData(O), + ) + + cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) + encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(results), nil, + default_exec_desc()) + commit!(cmdbuf) + wait_completed(cmdbuf) + return O +end + +@testset "scaled dot-product attention ($T)" for T in (Float16, Float32) + D, N_q, N_kv, H, B = 8, 12, 16, 2, 1 + Q = MtlArray(randn(T, D, N_q, H, B)) + K = MtlArray(randn(T, D, N_kv, H, B)) + V = MtlArray(randn(T, D, N_kv, H, B)) + scale = inv(sqrt(T(D))) + + @testset "no mask" begin + O = _run_sdpa(Q, K, V, scale) + O_ref = _ref_attention(Array(Q), Array(K), Array(V), scale) + @test Array(O) ≈ O_ref rtol = (T === Float16 ? 1e-2 : sqrt(eps(T))) + end + + @testset "with mask" begin + mask = MtlArray(randn(T, N_kv, N_q, H, B)) + O = _run_sdpa(Q, K, V, scale; mask) + O_ref = _ref_attention(Array(Q), Array(K), Array(V), scale, Array(mask)) + @test Array(O) ≈ O_ref rtol = (T === Float16 ? 1e-2 : sqrt(eps(T))) + end +end + +end # MPS.is_supported(device()) && macOS 14+ From 132c91c61ca5ba33e74aad341854431a52e4aae9 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Sun, 24 May 2026 13:40:35 +0200 Subject: [PATCH 06/24] Merge Flash Attention examples into a single file with benchmarks. --- examples/flash_attention/README.md | 90 --------- examples/flash_attention/fa_mps.jl | 44 ----- examples/flash_attention/fa_mpsgraph.jl | 83 -------- examples/flash_attention/fa_simdgroup.jl | 115 ------------ examples/flashattention.jl | 230 +++++++++++++++++++++++ 5 files changed, 230 insertions(+), 332 deletions(-) delete mode 100644 examples/flash_attention/README.md delete mode 100644 examples/flash_attention/fa_mps.jl delete mode 100644 examples/flash_attention/fa_mpsgraph.jl delete mode 100644 examples/flash_attention/fa_simdgroup.jl create mode 100644 examples/flashattention.jl diff --git a/examples/flash_attention/README.md b/examples/flash_attention/README.md deleted file mode 100644 index 0e0582030..000000000 --- a/examples/flash_attention/README.md +++ /dev/null @@ -1,90 +0,0 @@ -# Flash Attention examples - -Two reference implementations of scaled dot-product attention on Apple -Silicon GPUs from Julia, illustrating different programming models that -Metal.jl exposes. - -## `fa_mps.jl` - -The trivial baseline. Uses standard Julia operators (`*`, broadcasting, -`maximum`, `sum`, `exp`) on `MtlArray`. The matrix multiplications are -dispatched to **MPSGraph / MPSMatrixMultiplication** by -[`src/linalg.jl`](../../src/linalg.jl); the rest is GPUArrays. - -Not actually a Flash Attention algorithm — the full N×N scores matrix is -materialized in device memory — but it is the right reference to verify a -custom kernel against, and the fastest path to "attention runs on GPU" -when you don't need a custom kernel. - -Works on macOS 13+ / M1+. - -## `fa_mpsgraph.jl` - -The high-level MPS path. Builds a one-node MPSGraph using -`scaledDotProductAttentionWithQueryTensor` (macOS 14+), which fuses -Q·Kᵀ → scale → softmax → ·V into a single op. Apple uses the same op as -the backbone of their own SDPA paths (MLX falls back to it; Core ML -lowers attention to it), so it's the closest thing to "ask Apple for -attention" that Metal.jl can give you. - -Inputs are 4-D `(head_dim, seq, num_heads, batch)` in Julia — MPSGraph -sees these reversed as `(batch, num_heads, seq, head_dim)`, the layout -Apple's SDPA expects. - -Works on macOS 14+ / M1+. - -## `fa_simdgroup.jl` - -A single-block scaled dot-product attention kernel built from -`MtlSimdgroupMatrix{Float16, 8, 8}` (see `src/device/intrinsics/`). One -simdgroup of 32 lanes does the QKᵀ and PV matrix multiplies via two -`simdgroup_matrix` ops; the row-wise softmax is done in scalar code -through threadgroup memory. - -The example is intentionally minimal — Q, K, V are fixed at 8×8 — so the -control flow stays readable. A production implementation would: - - - sweep KV in blocks with online-softmax state (`m`, `l` per query row), - - tile D across multiple simdgroups, - - overlap loads with compute via `simdgroup_async_copy`, - - and split the backward pass into separate dQ / dKV kernels to avoid - FP32 atomics on Apple GPUs. - -See [philipturner/metal-flash-attention](https://github.com/philipturner/metal-flash-attention) -for a tuned reference (Swift + MSL, ~83 % ALU on M1 Max). - -K is host-transposed to `K_t` before launch — Metal.jl's `simdgroup_load` -issues a transposed-from-MSL load to compensate for Julia's column-major -storage, so `Q · K_t` in the kernel equals mathematical `Q · K^T`. -Exposing a `transpose=false` variant of `simdgroup_load` would let the -host transpose drop; that's a small follow-up to -`src/device/intrinsics/simd.jl`. - -Works on macOS 13+ / M1+. - -## Not included: `fa_metal4.jl` - -A third path would use the Metal 4 `cooperative_tensor` / -`tensor_ops::matmul2d` primitives with postfix-fusion of the softmax -epilogue. Apple positions this as the "preferred programming model for -ML applications" — on M5 hardware it can issue Neural-Accelerator MMAs -and skip threadgroup memory entirely. - -That path is not yet wired up in Metal.jl. The Objective-C classes are -already generated in `lib/mtl/libmtl.jl` (gated on `macos(v"26.0.0")`); -what remains is a Julia-side `MtlCooperativeTensor` wrapper plus a -host-side `MTLTensor` / `MTL4ComputeCommandEncoder` binding. Note that -the device-side ops lower to externally-defined -`__tensorops_impl_matmul2d_op_*` symbols rather than `air.*` intrinsics, -so the binding pattern differs from the SIMD-group case. Validation -needs macOS 26 + Xcode 26; M5 hardware is required to see the -Neural-Accelerator speedup. - -## References - - - Apple — *Discover Metal 4* (WWDC25 Session 205) and - *Combine Metal 4 machine learning and graphics* (WWDC25 Session 262). - - Apple — *Metal Performance Primitives Programming Guide* (PDF, 2025). - - philipturner — [metal-flash-attention](https://github.com/philipturner/metal-flash-attention). - - llama.cpp — [Metal 4 cooperative-tensor FA backend (PR #16634)](https://github.com/ggml-org/llama.cpp/pull/16634). - - liuliu — [example_matmul_metal4](https://github.com/liuliu/example_matmul_metal4) (minimal MSL probe for the Metal 4 host API). diff --git a/examples/flash_attention/fa_mps.jl b/examples/flash_attention/fa_mps.jl deleted file mode 100644 index 860e7c094..000000000 --- a/examples/flash_attention/fa_mps.jl +++ /dev/null @@ -1,44 +0,0 @@ -# Scaled-dot-product attention via the existing MPS / MPSGraph dispatch. -# -# This is the trivial baseline: rely on Metal.jl's automatic dispatch of `*` -# to MPSMatrixMultiplication / MPSGraph and on the GPUArrays broadcast for -# the softmax. It is NOT a Flash Attention algorithm — the full N×N scores -# matrix is materialized in device memory — but it is the right reference -# implementation to verify a custom kernel against. - -using Metal -using Test - -function attention_reference(Q, K, V) - d = size(Q, 2) - scale = inv(sqrt(eltype(Q)(d))) - - S = (Q * K') .* scale # N_q × N_kv, dispatched to MPS - S = S .- maximum(S; dims = 2) # row-wise max for numerical stability - P = exp.(S) - P = P ./ sum(P; dims = 2) # row-wise softmax - - return P * V # N_q × D, dispatched to MPS -end - -let - N_q, N_kv, D = 64, 64, 32 - T = Float32 - - Q = MtlArray(randn(T, N_q, D)) - K = MtlArray(randn(T, N_kv, D)) - V = MtlArray(randn(T, N_kv, D)) - - O_gpu = attention_reference(Q, K, V) - - # CPU reference - Qh, Kh, Vh = Array(Q), Array(K), Array(V) - scale = inv(sqrt(T(D))) - S = (Qh * Kh') .* scale - S .-= maximum(S; dims = 2) - P = exp.(S) - P ./= sum(P; dims = 2) - O_cpu = P * Vh - - @test Array(O_gpu) ≈ O_cpu rtol = sqrt(eps(T)) -end diff --git a/examples/flash_attention/fa_mpsgraph.jl b/examples/flash_attention/fa_mpsgraph.jl deleted file mode 100644 index d0f904aab..000000000 --- a/examples/flash_attention/fa_mpsgraph.jl +++ /dev/null @@ -1,83 +0,0 @@ -# Fused scaled-dot-product attention via MPSGraph's -# `scaledDotProductAttentionWithQueryTensor` op (macOS 14+). One graph node -# does Q·Kᵀ → scale → softmax → ·V end-to-end; MPSGraph picks the kernel. -# -# Higher-level than `fa_mps.jl`: there's no host-side composition of -# matmul + broadcast + softmax + matmul. Whether MPSGraph internally -# materializes the N×N scores matrix is a black box, but on macOS 14+ -# Apple uses this same op as their backbone SDPA implementation -# (MLX falls back to it; Core ML lowers attention to it). -# -# This example uses the 4-D `(head_dim, seq, num_heads, batch)` Julia -# layout, which MPSGraph sees reversed as `(batch, num_heads, seq, head_dim)` -# — Apple's expected SDPA layout. - -using Metal -using Metal.MPS: MPSCommandBuffer, commit!, wait_completed -using Metal.MPSGraphs: MPSGraph, MPSGraphTensor, MPSGraphTensorData, - placeholderTensor, scaledDotProductAttentionWithQueryTensor, - encode!, default_exec_desc -using Metal.Foundation: NSDictionary, nil -using Test - -function mpsgraph_attention(Q::MtlArray{T,4}, K::MtlArray{T,4}, V::MtlArray{T,4}; - scale = inv(sqrt(T(size(Q, 1))))) where {T} - @assert size(Q, 1) == size(K, 1) == size(V, 1) "head dim mismatch" - @assert size(K, 2) == size(V, 2) "K/V seq length mismatch" - @assert size(Q)[3:4] == size(K)[3:4] == size(V)[3:4] "(heads, batch) mismatch" - - O = similar(Q) - - graph = MPSGraph() - qph = placeholderTensor(graph, size(Q), T) - kph = placeholderTensor(graph, size(K), T) - vph = placeholderTensor(graph, size(V), T) - out = scaledDotProductAttentionWithQueryTensor(graph, qph, kph, vph, - Float32(scale)) - - feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( - qph => MPSGraphTensorData(Q), - kph => MPSGraphTensorData(K), - vph => MPSGraphTensorData(V), - ) - results = Dict{MPSGraphTensor, MPSGraphTensorData}( - out => MPSGraphTensorData(O), - ) - - cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) - encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(results), - nil, default_exec_desc()) - commit!(cmdbuf) - wait_completed(cmdbuf) - return O -end - -# CPU reference attention, 4-D (head_dim, seq, num_heads, batch) Julia layout. -function cpu_attention(Q, K, V; scale = inv(sqrt(eltype(Q)(size(Q, 1))))) - D, N_q, H, B = size(Q) - N_kv = size(K, 2) - O = similar(Q) - for b in 1:B, h in 1:H - Qm, Km, Vm = Q[:, :, h, b], K[:, :, h, b], V[:, :, h, b] # each is (D, N) - S = (Qm' * Km) .* scale # (N_q, N_kv) - S .-= maximum(S; dims = 2) - P = exp.(S) - P ./= sum(P; dims = 2) - O[:, :, h, b] = Vm * P' # (D, N_q) - end - return O -end - -let - T = Float32 - D, N_q, N_kv, H, B = 16, 24, 32, 2, 1 - - Q = MtlArray(randn(T, D, N_q, H, B)) - K = MtlArray(randn(T, D, N_kv, H, B)) - V = MtlArray(randn(T, D, N_kv, H, B)) - - O_gpu = mpsgraph_attention(Q, K, V) - O_cpu = cpu_attention(Array(Q), Array(K), Array(V)) - - @test Array(O_gpu) ≈ O_cpu rtol = sqrt(eps(T)) -end diff --git a/examples/flash_attention/fa_simdgroup.jl b/examples/flash_attention/fa_simdgroup.jl deleted file mode 100644 index ab3958e42..000000000 --- a/examples/flash_attention/fa_simdgroup.jl +++ /dev/null @@ -1,115 +0,0 @@ -# Single-block scaled dot-product attention kernel built from -# `MtlSimdgroupMatrix`. This is the smallest readable example of composing -# the SIMD-group matrix primitives into an attention-style kernel. -# -# Scope: -# - one simdgroup (32 lanes) per threadgroup, one threadgroup total -# - Q, K_t, V are fixed at 8×8 (Br = Bc = D = 8) -# - K is host-transposed to K_t so the in-kernel matmul Q · K_t equals -# mathematical Q · K^T (Metal.jl's `simdgroup_load` always issues a -# transposed-from-MSL load to read Julia's column-major data, so -# getting K^T from a col-major K without an extra binding is awkward) -# - softmax is done in scalar code through threadgroup memory -# -# A "real" FA kernel adds a KV-block loop with online-softmax state -# (`m`, `l` per row) and tiles D across multiple simdgroups; see -# philipturner/metal-flash-attention for a production reference. - -using Metal -using Test - -const Br = 8 -const Bc = 8 -const D = 8 - -function fa_kernel!(O::AbstractMatrix{Float16}, - Q::AbstractMatrix{Float16}, - K_t::AbstractMatrix{Float16}, - V::AbstractMatrix{Float16}, - scale::Float32) - # Stage scratch. - Ss = MtlThreadGroupArray(Float32, (Br, Bc)) # scores, then P - Sh = MtlThreadGroupArray(Float16, (Br, Bc)) # P cast back to fp16 - - # 1. S = Q · K_t (single 8x8 simdgroup_matrix multiply) - Qm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, Q) - Km = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, K_t) - Sm = Qm * Km - - # 2. Spill to threadgroup memory for the row-wise softmax in scalar code. - Sh_tmp = MtlThreadGroupArray(Float16, (Br, Bc)) - simdgroup_store(Sm, Sh_tmp) - threadgroup_barrier(Metal.MemoryFlagThreadGroup) - - # Cast to Float32 and scale; 32 threads cover the 64 elements at 2 per lane. - tid = Int(thread_index_in_threadgroup()) - 1 # 0..31 - @inbounds for k in 0:1 - idx = tid * 2 + k - r = idx ÷ Bc + 1 - c = idx % Bc + 1 - Ss[r, c] = Float32(Sh_tmp[r, c]) * scale - end - threadgroup_barrier(Metal.MemoryFlagThreadGroup) - - # 3. Row-wise softmax. 8 of 32 lanes do real work; the rest idle. - if tid < Br - m = -Inf32 - @inbounds for j in 1:Bc - v = Ss[tid + 1, j] - m = v > m ? v : m - end - s = 0.0f0 - @inbounds for j in 1:Bc - p = exp(Ss[tid + 1, j] - m) - Ss[tid + 1, j] = p - s += p - end - inv_s = 1.0f0 / s - @inbounds for j in 1:Bc - Sh[tid + 1, j] = Float16(Ss[tid + 1, j] * inv_s) - end - end - threadgroup_barrier(Metal.MemoryFlagThreadGroup) - - # 4. O = P · V (second 8x8 simdgroup_matrix multiply) - Pm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, Sh) - Vm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, V) - Om = Pm * Vm - - simdgroup_store(Om, O) - return -end - -function flash_attention(Q::MtlMatrix{Float16}, K::MtlMatrix{Float16}, - V::MtlMatrix{Float16}) - @assert size(Q) == (Br, D) "Q must be ($Br, $D)" - @assert size(K) == (Bc, D) "K must be ($Bc, $D)" - @assert size(V) == (Bc, D) "V must be ($Bc, $D)" - - K_t = MtlMatrix(collect(transpose(Array(K)))) - O = similar(Q) - scale = inv(sqrt(Float32(D))) - - Metal.@sync @metal threads = 32 fa_kernel!(O, Q, K_t, V, scale) - return O -end - -let - T = Float16 - Q = MtlArray(rand(T, Br, D)) - K = MtlArray(rand(T, Bc, D)) - V = MtlArray(rand(T, Bc, D)) - - O = flash_attention(Q, K, V) - - # Reference attention computed in Float32 on the CPU. - Qh, Kh, Vh = Float32.(Array(Q)), Float32.(Array(K)), Float32.(Array(V)) - scale = inv(sqrt(Float32(D))) - S = (Qh * Kh') .* scale - S .-= maximum(S; dims = 2) - P = exp.(S) - P ./= sum(P; dims = 2) - O_ref = P * Vh - - @test Float32.(Array(O)) ≈ O_ref rtol = 1e-2 -end diff --git a/examples/flashattention.jl b/examples/flashattention.jl new file mode 100644 index 000000000..bd98f5520 --- /dev/null +++ b/examples/flashattention.jl @@ -0,0 +1,230 @@ +# Flash Attention reference implementations on Apple Silicon. +# +# Three ways to spell scaled dot-product attention on Metal, illustrating +# the programming models Metal.jl exposes: +# +# attention_mps(Q, K, V) +# The trivial baseline. Uses standard Julia operators (`*`, +# broadcasting, `maximum`, `sum`, `exp`) on `MtlArray`. The matrix +# multiplies are dispatched to MPSGraph / MPSMatrixMultiplication by +# `src/linalg.jl`; the rest is GPUArrays. Not actually a Flash +# Attention algorithm — the full N×N scores matrix is materialized +# in device memory — but it's the right reference and the fastest +# path to "attention runs on GPU" when you don't need a custom +# kernel. Works on macOS 13+ / M1+. +# +# attention_mpsgraph(Q, K, V) +# The high-level MPS path. Builds a one-node MPSGraph using +# `scaledDotProductAttentionWithQueryTensor` (macOS 14+), which +# fuses Q·Kᵀ → scale → softmax → ·V into a single op. Apple uses +# the same op as the backbone of their own SDPA paths (MLX falls +# back to it; Core ML lowers attention to it), so it's the closest +# thing to "ask Apple for attention" that Metal.jl can give you. +# +# attention_simdgroup(Q, K, V) +# A single-block scaled dot-product attention kernel built from +# `MtlSimdgroupMatrix{Float16, 8, 8}`. One simdgroup of 32 lanes +# does the QKᵀ and PV matrix multiplies via two `simdgroup_matrix` +# ops; the row-wise softmax is done in scalar code through +# threadgroup memory. Limited to N = D = 8, single head, single +# batch — illustrative, not production. See +# https://github.com/philipturner/metal-flash-attention for a +# tuned reference. Works on macOS 13+ / M1+. +# +# A fourth path would use the Metal 4 `cooperative_tensor` / +# `tensor_ops::matmul2d` primitives with postfix-fusion of the softmax +# epilogue. Apple positions this as the preferred programming model for +# ML on M5; on M3/M4 it lowers to the same simdgroup MMA hardware the +# `attention_simdgroup` path already drives. That path isn't yet wired up +# in Metal.jl — the ObjC classes are generated in `lib/mtl/libmtl.jl` +# (gated on `macos(v"26.0.0")`), but the host-side `MTLTensor` / +# `MTL4ComputeCommandEncoder` wrappers and the device-side +# `MtlCooperativeTensor` are not. Note that the device-side ops lower to +# externally-defined `__tensorops_impl_matmul2d_op_*` symbols rather than +# `air.*` intrinsics, so the binding pattern differs from the simdgroup +# case. +# +# All three implementations take Julia 4-D `(head_dim, seq, num_heads, +# batch)` inputs — MPSGraph sees these reversed as `(batch, num_heads, +# seq, head_dim)`, the layout Apple's SDPA expects. + +using Metal +using Test +using BenchmarkTools + +using Metal.MPS: MPSCommandBuffer, commit!, wait_completed +using Metal.MPSGraphs: MPSGraph, MPSGraphTensor, MPSGraphTensorData, + placeholderTensor, scaledDotProductAttentionWithQueryTensor, + encode!, default_exec_desc +using Metal.Foundation: NSDictionary, nil + + +## MPS / GPUArrays path + +function attention_mps(Q::MtlArray{T,4}, K::MtlArray{T,4}, V::MtlArray{T,4}; + scale = inv(sqrt(T(size(Q, 1))))) where {T} + _, _, H, B = size(Q) + O = similar(Q) + for b in 1:B, h in 1:H + Qm, Km, Vm = Q[:, :, h, b], K[:, :, h, b], V[:, :, h, b] + S = (transpose(Qm) * Km) .* scale + S = S .- maximum(S; dims = 2) + P = exp.(S) + P = P ./ sum(P; dims = 2) + O[:, :, h, b] = Vm * transpose(P) + end + return O +end + + +## MPSGraph SDPA path + +function attention_mpsgraph(Q::MtlArray{T,4}, K::MtlArray{T,4}, V::MtlArray{T,4}; + scale = inv(sqrt(T(size(Q, 1))))) where {T} + O = similar(Q) + + graph = MPSGraph() + qph = placeholderTensor(graph, size(Q), T) + kph = placeholderTensor(graph, size(K), T) + vph = placeholderTensor(graph, size(V), T) + out = scaledDotProductAttentionWithQueryTensor(graph, qph, kph, vph, + Float32(scale)) + + feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( + qph => MPSGraphTensorData(Q), + kph => MPSGraphTensorData(K), + vph => MPSGraphTensorData(V), + ) + results = Dict{MPSGraphTensor, MPSGraphTensorData}( + out => MPSGraphTensorData(O), + ) + + cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) + encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(results), nil, + default_exec_desc()) + commit!(cmdbuf) + wait_completed(cmdbuf) + return O +end + + +## Custom kernel with MtlSimdgroupMatrix + +function _fa_kernel!(O::AbstractMatrix{Float16}, + Q::AbstractMatrix{Float16}, + K_t::AbstractMatrix{Float16}, + V::AbstractMatrix{Float16}, + scale::Float32) + Ss = MtlThreadGroupArray(Float32, (8, 8)) # scores, then P + Sh = MtlThreadGroupArray(Float16, (8, 8)) # P cast back to fp16 + + # 1. S = Q · K_t (single 8x8 simdgroup_matrix multiply) + Qm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, Q) + Km = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, K_t) + Sm = Qm * Km + + # 2. Spill to threadgroup memory for the row-wise softmax in scalar code. + Sh_tmp = MtlThreadGroupArray(Float16, (8, 8)) + simdgroup_store(Sm, Sh_tmp) + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + tid = Int(thread_index_in_threadgroup()) - 1 # 0..31 + @inbounds for k in 0:1 + idx = tid * 2 + k + r = idx ÷ 8 + 1 + c = idx % 8 + 1 + Ss[r, c] = Float32(Sh_tmp[r, c]) * scale + end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + # 3. Row-wise softmax. 8 of 32 lanes do real work. + if tid < 8 + m = -Inf32 + @inbounds for j in 1:8 + v = Ss[tid + 1, j] + m = v > m ? v : m + end + s = 0.0f0 + @inbounds for j in 1:8 + p = exp(Ss[tid + 1, j] - m) + Ss[tid + 1, j] = p + s += p + end + inv_s = 1.0f0 / s + @inbounds for j in 1:8 + Sh[tid + 1, j] = Float16(Ss[tid + 1, j] * inv_s) + end + end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + + # 4. O = P · V (second 8x8 simdgroup_matrix multiply) + Pm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, Sh) + Vm = simdgroup_load(MtlSimdgroupMatrix{Float16, 8, 8}, V) + Om = Pm * Vm + + simdgroup_store(Om, O) + return +end + +function attention_simdgroup(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, + V::MtlArray{Float16,4}; + scale = inv(sqrt(Float32(size(Q, 1))))) + @assert size(Q) == size(K) == size(V) == (8, 8, 1, 1) "simdgroup kernel only handles (D=8, N=8, H=1, B=1)" + + # Inputs are (D, N, 1, 1). Kernel works with Q and V in (N, D), and K_t in + # (D, N) — the latter is the user-facing K already transposed, so reshape + # of the (D, N) slice gives us K_t for free. + Q2 = permutedims(reshape(Q, 8, 8), (2, 1)) + V2 = permutedims(reshape(V, 8, 8), (2, 1)) + K_t = reshape(K, 8, 8) + O2 = similar(Q2) + + Metal.@sync @metal threads = 32 _fa_kernel!(O2, Q2, K_t, V2, Float32(scale)) + return reshape(permutedims(O2, (2, 1)), 8, 8, 1, 1) +end + + +## CPU reference + driver + +function attention_cpu(Q, K, V; scale = inv(sqrt(eltype(Q)(size(Q, 1))))) + D, N_q, H, B = size(Q) + O = similar(Q) + for b in 1:B, h in 1:H + Qm, Km, Vm = Q[:, :, h, b], K[:, :, h, b], V[:, :, h, b] + S = (Qm' * Km) .* scale + S .-= maximum(S; dims = 2) + P = exp.(S) + P ./= sum(P; dims = 2) + O[:, :, h, b] = Vm * P' + end + return O +end + +function main() + T = Float16 # simdgroup path requires fp16 + D = N = 8 # constrained by the simdgroup kernel + + Q = MtlArray(randn(T, D, N, 1, 1)) + K = MtlArray(randn(T, D, N, 1, 1)) + V = MtlArray(randn(T, D, N, 1, 1)) + + O_cpu = attention_cpu(Array(Q), Array(K), Array(V)) + O_mps = attention_mps(Q, K, V) + O_mpsgraph = attention_mpsgraph(Q, K, V) + O_simdgroup = attention_simdgroup(Q, K, V) + + @test Array(O_mps) ≈ O_cpu rtol = 1e-2 + @test Array(O_mpsgraph) ≈ O_cpu rtol = 1e-2 + @test Array(O_simdgroup) ≈ O_cpu rtol = 1e-2 + + if get(ENV, "TESTING", "false") != "true" + println("\nattention_mps:") + @btime Metal.@sync attention_mps($Q, $K, $V) + println("attention_mpsgraph:") + @btime Metal.@sync attention_mpsgraph($Q, $K, $V) + println("attention_simdgroup:") + @btime Metal.@sync attention_simdgroup($Q, $K, $V) + end +end + +isinteractive() || main() From e80db1d82fb88e47ab9b7c50c0c2fdd9eb8cd310 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Sun, 24 May 2026 13:45:06 +0200 Subject: [PATCH 07/24] Drop benchmark section from flashattention example. --- examples/flashattention.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/flashattention.jl b/examples/flashattention.jl index bd98f5520..4fe28df02 100644 --- a/examples/flashattention.jl +++ b/examples/flashattention.jl @@ -50,7 +50,6 @@ using Metal using Test -using BenchmarkTools using Metal.MPS: MPSCommandBuffer, commit!, wait_completed using Metal.MPSGraphs: MPSGraph, MPSGraphTensor, MPSGraphTensorData, @@ -216,15 +215,6 @@ function main() @test Array(O_mps) ≈ O_cpu rtol = 1e-2 @test Array(O_mpsgraph) ≈ O_cpu rtol = 1e-2 @test Array(O_simdgroup) ≈ O_cpu rtol = 1e-2 - - if get(ENV, "TESTING", "false") != "true" - println("\nattention_mps:") - @btime Metal.@sync attention_mps($Q, $K, $V) - println("attention_mpsgraph:") - @btime Metal.@sync attention_mpsgraph($Q, $K, $V) - println("attention_simdgroup:") - @btime Metal.@sync attention_simdgroup($Q, $K, $V) - end end isinteractive() || main() From b4da06e3b14172f922e65ba88e8f28869f7620d9 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 15:57:29 +0200 Subject: [PATCH 08/24] Add device-side Metal 4 tensor_ops::matmul2d wrappers. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wraps the tensor_inline form of Metal 4 tensor_ops. Kernel args stay buffer-shaped MtlDeviceArrays — no host-side MTLTensor or MTL4 command encoder wrapping is needed. The descriptors and matmul2d run helper lower to the externally-defined __tensorops_impl_* family in the MetalPerformancePrimitives runtime; AIR construction uses the air.*_private_tensor intrinsics. Per-thread descriptor storage is a Ref{NTuple{N, UInt8}} that allocopt promotes to a stack alloca (every constructor is @inline'd into the kernel, and the gc-managed object only escapes via pointer_from_objref). Reference IR for the three kernel shapes (handle / cooperative / inline) is in bin/{simple,coop,inline}_matmul.*. --- ISSUE-tensor-ops.md | 133 +++++++++++++++ bin/coop_matmul.ll | 223 ++++++++++++++++++++++++ bin/coop_matmul.metal | 32 ++++ bin/inline_matmul.ll | 290 ++++++++++++++++++++++++++++++++ bin/inline_matmul.metal | 32 ++++ bin/simple_matmul.ll | 128 ++++++++++++++ bin/simple_matmul.metal | 22 +++ src/Metal.jl | 1 + src/device/intrinsics/tensor.jl | 219 ++++++++++++++++++++++++ 9 files changed, 1080 insertions(+) create mode 100644 ISSUE-tensor-ops.md create mode 100644 bin/coop_matmul.ll create mode 100644 bin/coop_matmul.metal create mode 100644 bin/inline_matmul.ll create mode 100644 bin/inline_matmul.metal create mode 100644 bin/simple_matmul.ll create mode 100644 bin/simple_matmul.metal create mode 100644 src/device/intrinsics/tensor.jl diff --git a/ISSUE-tensor-ops.md b/ISSUE-tensor-ops.md new file mode 100644 index 000000000..3153cab47 --- /dev/null +++ b/ISSUE-tensor-ops.md @@ -0,0 +1,133 @@ +# Metal 4 tensor ops (matmul2d / cooperative_tensor) — status + +## What's working + +`examples/flashattention.jl` now has an `attention_tensor(Q, K, V)` path that +dispatches the two attention matmuls via the Metal 4 `tensor_ops::matmul2d` +primitives. It matches the CPU reference at `D = N = 64`, single head, single +batch. Requires macOS 26+. + +The device-side wrappers live in `src/device/intrinsics/tensor.jl`: + +- `MtlInlineTensor{T, R}` — kernel-stack tensor view (`tensor_inline` form) + over an `MtlDeviceArray`. Built via `air.init_strided_private_tensor`. The + per-thread tensor descriptor is held by a `Ref{NTuple{64, UInt8}}` — + Julia's `llvm-alloc-opt` pass promotes it to a stack alloca because every + use is `@inline`d into the kernel and the gc-managed object only escapes + via `pointer_from_objref` (which `allocopt` treats as `addrescaped`, not + `escaped`). `GC.@preserve` around the ccalls keeps the buffer alive + across the runtime calls. +- `matmul2d_descriptor(m, n, k=-1; transpose_left, transpose_right, + relaxed_precision, mode)` — 20-byte POD matching + `mpp::tensor_ops::matmul2d_descriptor`. +- `tensor_ops_matmul2d!(desc, left, right, dest, threads)` — dispatches one + of `__tensorops_impl_matmul2d_op_run_dv_{tl}_dv_{tr}_dv_{td}` based on the + element types of the operand tensors. `threads` must equal + `simdgroup_size * num_simdgroups` for the descriptor's scope. + +The inline-tensor route lets us reuse the existing Metal.jl kernel ABI: +kernel args are still `MtlDeviceArray`s, so no host-side `MTLTensor` / +`MTL4ComputeCommandEncoder` wrapping is needed. + +The GPUCompiler bits: + +- `GPUCompiler/src/metal.jl` `isintrinsic` whitelists `__tensorops_impl_` + symbols (alongside `air.`). +- `annotate_air_intrinsics!` attaches `section "air.externally_defined"` and + `(convergent, nounwind)` attributes to `__tensorops_impl_*` declarations. + Without the section attribute, the metallib back-end won't resolve the + symbol from the MetalPerformancePrimitives runtime. + +## What's not working / known limitations + +- **Two `__tensorops_impl_matmul2d_op_run_*` calls in one kernel crash the + Metal back-end** (`XPC_ERROR_CONNECTION_INTERRUPTED` from + `AGXMetalG15X_M1`). The attention example sidesteps this by splitting QK + and PV into two dispatches. This is likely an Apple compiler bug — the IR + we emit looks structurally identical to single-matmul kernels that compile + fine. Worth filing upstream. +- **No `cooperative_tensor` yet.** That means the softmax epilogue can't be + done in registers — the scores tile is materialized in device memory. A + proper Flash Attention would fuse the softmax into the cooperative tensor + between the two matmuls. +- **No `tensor_handle` kernel args.** Apple's matmul samples (and the bulk of + the MPP docs) describe tensors as host-bound `MTLTensor` parameters that + arrive in the kernel as opaque `%struct._tensor_t addrspace(1)*`. That + requires both a host-side `MTL4ArgumentTable` / `MTLTensor` wrapping and a + Metal.jl kernel-ABI rewrite. Inline tensors give us most of the + expressiveness without any of that. +- **No threadgroup-memory matmul.** Only `dv_*` (device-memory) variants of + the run helpers are wrapped. `tg_*` variants would let us stage tiles into + threadgroup memory. +- **`D == N` only.** The attention example uses one matmul descriptor sized + to a single 64×64 tile; supporting arbitrary `D, N` means dispatching + multiple threadgroups and tiling on the host. + +## Reverse-engineering reference + +Annotated AIR for the kernels we generate Apple-style equivalents for: + +- `bin/simple_matmul.metal` / `bin/simple_matmul.ll` — minimal NN matmul, + device-memory destination, `tensor_handle` parameters. +- `bin/coop_matmul.metal` / `bin/coop_matmul.ll` — cooperative-tensor + destination with a trivial scale-by-2 postfix epilogue. Closest template + for the proper Flash Attention path. +- `bin/inline_matmul.metal` / `bin/inline_matmul.ll` — the `tensor_inline` + form that Metal.jl actually uses. Matches the IR shape our wrappers emit. + +Apple's headers: + +- `/usr/metal//lib/clang//include/metal/{metal_tensor,metal_cooperative_tensor}` +- `/System/Library/Frameworks/MetalPerformancePrimitives.framework/Versions/A/Headers/{MPPTensorOpsMatMul2d.h,__impl/MPPTensorOpsMatMul2dImpl.h}` + +### AIR shapes used by our wrappers + +Inline tensor construction (`air.*` intrinsics, in `i32`-indexed flavor): + +```llvm +i16 @air.get_descriptor_size_tensor(i16 rank, i16 index_size) +void @air.init_strided_private_tensor.i32.global(i8* %handle, i16 rank, + i8 addrspace(1)* %data, + i8* %extents, i8* %strides, + i8 %contiguous) +i32 @air.get_extent_private_tensor.i32(i8* %handle, i16 rank, i16 dim) +void @air.slice_private_tensor_private_tensor.s.i32(i8* %dst, i8* %src, + i16 rank, i8* %origin, + i8* %extents) +``` + +Matmul run (externally-defined, `section "air.externally_defined"`): + +```llvm +void @__tensorops_impl_matmul2d_op_run_dv_{tl}_dv_{tr}_dv_{td}( + %"struct.matmul2d_descriptor"* %desc, + i8* %left, i32 %left_desc_type, + i8* %right, i32 %right_desc_type, + i8* %destination, i32 %dest_desc_type, + i32 %threads) +``` + +`{tl}, {tr}, {td}` are element-type suffixes (`f16`, `f32`, `bf16`, `i8`, …) +and the descriptor types are `1` for `tensor_handle`, `2` for +`tensor_inline`. + +## What's still TODO + +In rough order of value: + +1. **`MtlCooperativeTensor`** — would enable the proper Flash Attention + postfix-fusion path. Needs dynamic stack allocation (the Apple compiler + emits `alloca i8, i64 %sz` where `%sz` comes from + `__tensorops_impl_matmul2d_op_cooperative_tensor_data_size` and is marked + `"deferred-static-alloca-size"`). Workaround: reserve a conservative + upper bound at compile time. +2. **Threadgroup-memory matmul variants.** Wrap `_tg_*` flavors of the run + helpers and let `MtlInlineTensor` accept a `MtlThreadGroupArray`. +3. **Tile decomposition.** Drop the `D == N == tile` constraint by + dispatching multiple threadgroups per matmul and slicing on `tgid`. +4. **`tensor_handle` kernel args + host-side `MTLTensor` / `MTL4` wrappers.** + The biggest piece, and the closest path to what Apple's samples + demonstrate. Inline tensors get us most of the way without it, so this + is now only worth doing if we want first-class interop with Apple's + tensor APIs (e.g., to consume an `MTLTensor` produced by some other + framework). diff --git a/bin/coop_matmul.ll b/bin/coop_matmul.ll new file mode 100644 index 000000000..80227be6b --- /dev/null +++ b/bin/coop_matmul.ll @@ -0,0 +1,223 @@ +; ModuleID = 'coop_matmul.metal' +source_filename = "coop_matmul.metal" +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32" +target triple = "air64_v28-apple-macosx26.0.0" + +%"struct.mpp::tensor_ops::matmul2d_descriptor" = type { i32, i32, i32, i8, i8, i8, i32 } +%struct._tensor_t = type opaque +%"struct.metal::tensor.6" = type { %"struct.metal::__tensor_base.7", %struct._tensor_t addrspace(1)* } +%"struct.metal::__tensor_base.7" = type { %"struct.metal::__tensor_offsets.8" } +%"struct.metal::__tensor_offsets.8" = type { %"struct.metal::array" } +%"struct.metal::array" = type { [2 x i32] } +%"struct.metal::tensor.3" = type { %"struct.metal::__tensor_base.4", %struct._tensor_t addrspace(1)* } +%"struct.metal::__tensor_base.4" = type { %"struct.metal::__tensor_offsets.5" } +%"struct.metal::__tensor_offsets.5" = type { %"struct.metal::array" } + +@_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE = linkonce_odr local_unnamed_addr constant %"struct.mpp::tensor_ops::matmul2d_descriptor" { i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0 } + +; Function Attrs: convergent nounwind +define void @coop_matmul(%struct._tensor_t addrspace(1)* %0, %struct._tensor_t addrspace(1)* %1, %struct._tensor_t addrspace(1)* %2, <2 x i32> noundef %3) local_unnamed_addr #0 { + %5 = alloca %"struct.mpp::tensor_ops::matmul2d_descriptor", align 4 + %6 = alloca %"struct.metal::tensor.6", align 8 + %7 = alloca %"struct.metal::tensor.3", align 8 + %8 = alloca %"struct.metal::tensor.3", align 8 + %9 = tail call i64 @_ZN5metal18cooperative_tensorIfNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEEN3mpp10tensor_ops17__mutmul2d_detail16__operand_layoutIXtlNS4_19matmul2d_descriptorELi64ELi32ELin1EEELNS5_36__matmul2d_cooperative_operand_indexE2ENS_20execution_simdgroupsILm4EEEDhDhfiJEEEEE.MTL_SIZEAS() #7 + %10 = alloca i8, i64 %9, align 4 + %11 = bitcast %"struct.metal::tensor.3"* %7 to i8* + call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %11) #7 + %12 = extractelement <2 x i32> %3, i64 1 + %13 = shl i32 %12, 6 + %14 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 + store i32 0, i32* %14, align 8 + %15 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 + store i32 %13, i32* %15, align 4 + %16 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 1 + store %struct._tensor_t addrspace(1)* %0, %struct._tensor_t addrspace(1)** %16, align 8 + %17 = bitcast %"struct.metal::tensor.3"* %8 to i8* + call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %17) #7 + %18 = extractelement <2 x i32> %3, i64 0 + %19 = shl i32 %18, 5 + %20 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %8, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 + store i32 %19, i32* %20, align 8 + %21 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %8, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 + store i32 0, i32* %21, align 4 + %22 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %8, i64 0, i32 1 + store %struct._tensor_t addrspace(1)* %1, %struct._tensor_t addrspace(1)** %22, align 8 + call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %10) + %23 = tail call i32 @air.get_simdgroup_size.i32() #8 + %24 = shl i32 %23, 2 + call void @__tensorops_impl_matmul2d_op_cooperative_tensor_init(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %24) #9 + br label %25 + +25: ; preds = %37, %4 + %26 = phi i16 [ 0, %4 ], [ %38, %37 ] + %27 = call zeroext i16 @__tensorops_impl_matmul2d_op_cooperative_tensor_num_elements(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i32 noundef 268435472, i32 noundef 268435472, i32 noundef %24) #9 + %28 = icmp ult i16 %26, %27 + br i1 %28, label %32, label %29 + +29: ; preds = %25 + %30 = bitcast %"struct.mpp::tensor_ops::matmul2d_descriptor"* %5 to i8* + call void @llvm.lifetime.start.p0i8(i64 20, i8* nonnull %30) #7 + call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 4 dereferenceable(20) %30, i8* noundef nonnull align 4 dereferenceable(20) bitcast (%"struct.mpp::tensor_ops::matmul2d_descriptor"* @_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE to i8*), i64 20, i1 false) #7, !tbaa.struct !23 + %31 = call i8* @__tensorops_impl_matmul2d_op_cooperative_tensor_get_element_pointer(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext -1, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488) #9 + call void @__tensorops_impl_matmul2d_op_run_cooperative_dv_f16_dv_f16_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20) %5, i8* noundef nonnull %11, i32 noundef 1, i8* noundef nonnull %17, i32 noundef 1, i8* noundef %31, i32 noundef %24) #9 + call void @llvm.lifetime.end.p0i8(i64 20, i8* nonnull %30) #7 + br label %39 + +32: ; preds = %25 + %33 = call zeroext i1 @__tensorops_impl_matmul2d_op_cooperative_tensor_is_valid_element(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext %26, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %24) #9 + br i1 %33, label %34, label %37 + +34: ; preds = %32 + %35 = call i8* @__tensorops_impl_matmul2d_op_cooperative_tensor_get_element_pointer(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext %26, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488) #9 + %36 = bitcast i8* %35 to float* + store float 0.000000e+00, float* %36, align 4, !tbaa !32 + br label %37 + +37: ; preds = %32, %34 + %38 = add nuw i16 %26, 1 + br label %25, !llvm.loop !34 + +39: ; preds = %55, %29 + %40 = phi i16 [ 0, %29 ], [ %56, %55 ] + %41 = call zeroext i16 @__tensorops_impl_matmul2d_op_cooperative_tensor_num_elements(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i32 noundef 268435472, i32 noundef 268435472, i32 noundef %24) #9 + %42 = icmp ult i16 %40, %41 + br i1 %42, label %48, label %43 + +43: ; preds = %39 + %44 = bitcast %"struct.metal::tensor.6"* %6 to i8* + call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %44) + %45 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %6, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 + store i32 %19, i32* %45, align 8 + %46 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %6, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 + store i32 %13, i32* %46, align 4 + %47 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %6, i64 0, i32 1 + store %struct._tensor_t addrspace(1)* %2, %struct._tensor_t addrspace(1)** %47, align 8 + call void @__tensorops_impl_matmul2d_op_cooperative_tensor_store_dv_f32(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i8* noundef nonnull %44, i32 noundef 1, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %24) #9 + call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %44) + call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %10) #7 + call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %17) #7 + call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %11) #7 + ret void + +48: ; preds = %39 + %49 = call zeroext i1 @__tensorops_impl_matmul2d_op_cooperative_tensor_is_valid_element(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext %40, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %24) #9 + br i1 %49, label %50, label %55 + +50: ; preds = %48 + %51 = call i8* @__tensorops_impl_matmul2d_op_cooperative_tensor_get_element_pointer(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext %40, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488) #9 + %52 = bitcast i8* %51 to float* + %53 = load float, float* %52, align 4, !tbaa !32 + %54 = fmul fast float %53, 2.000000e+00 + store float %54, float* %52, align 4, !tbaa !32 + br label %55 + +55: ; preds = %48, %50 + %56 = add nuw i16 %40, 1 + br label %39, !llvm.loop !36 +} + +; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1 + +; Function Attrs: mustprogress nofree nosync readnone speculatable willreturn +define linkonce_odr hidden i64 @_ZN5metal18cooperative_tensorIfNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEEN3mpp10tensor_ops17__mutmul2d_detail16__operand_layoutIXtlNS4_19matmul2d_descriptorELi64ELi32ELin1EEELNS5_36__matmul2d_cooperative_operand_indexE2ENS_20execution_simdgroupsILm4EEEDhDhfiJEEEEE.MTL_SIZEAS() local_unnamed_addr #2 { + %1 = tail call i64 @_ZN3mpp10tensor_ops17__mutmul2d_detail16__operand_layoutIXtlNS0_19matmul2d_descriptorELi64ELi32ELin1EEELNS1_36__matmul2d_cooperative_operand_indexE2EN5metal20execution_simdgroupsILm4EEEDhDhfiJEE19thread_storage_sizeEv() #10 + ret i64 %1 +} + +; Function Attrs: convergent nounwind +define linkonce_odr i64 @_ZN3mpp10tensor_ops17__mutmul2d_detail16__operand_layoutIXtlNS0_19matmul2d_descriptorELi64ELi32ELin1EEELNS1_36__matmul2d_cooperative_operand_indexE2EN5metal20execution_simdgroupsILm4EEEDhDhfiJEE19thread_storage_sizeEv() local_unnamed_addr #3 align 2 { + %1 = tail call i32 @air.get_simdgroup_size.i32() #8 + %2 = shl i32 %1, 2 + %3 = tail call i64 @__tensorops_impl_matmul2d_op_cooperative_tensor_data_size(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %2) #9 + ret i64 %3 +} + +; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1 + +; Function Attrs: argmemonly mustprogress nofree nounwind willreturn +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #4 + +; Function Attrs: convergent +declare void @__tensorops_impl_matmul2d_op_cooperative_tensor_init(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" + +; Function Attrs: mustprogress nofree nosync nounwind readnone willreturn +declare i32 @air.get_simdgroup_size.i32() local_unnamed_addr #6 + +; Function Attrs: convergent +declare i64 @__tensorops_impl_matmul2d_op_cooperative_tensor_data_size(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" + +; Function Attrs: convergent +declare zeroext i16 @__tensorops_impl_matmul2d_op_cooperative_tensor_num_elements(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" + +; Function Attrs: convergent +declare zeroext i1 @__tensorops_impl_matmul2d_op_cooperative_tensor_is_valid_element(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i16 noundef zeroext, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" + +; Function Attrs: convergent +declare i8* @__tensorops_impl_matmul2d_op_cooperative_tensor_get_element_pointer(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i16 noundef zeroext, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" + +; Function Attrs: convergent +declare void @__tensorops_impl_matmul2d_op_run_cooperative_dv_f16_dv_f16_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20), i8* noundef, i32 noundef, i8* noundef, i32 noundef, i8* noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" + +; Function Attrs: convergent +declare void @__tensorops_impl_matmul2d_op_cooperative_tensor_store_dv_f32(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i8* noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" + +attributes #0 = { convergent nounwind "approx-func-fp-math"="true" "frame-pointer"="all" "min-legal-vector-width"="64" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } +attributes #1 = { argmemonly mustprogress nocallback nofree nosync nounwind willreturn } +attributes #2 = { mustprogress nofree nosync readnone speculatable willreturn "deferred-static-alloca-size" } +attributes #3 = { convergent nounwind "approx-func-fp-math"="true" "frame-pointer"="all" "min-legal-vector-width"="0" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } +attributes #4 = { argmemonly mustprogress nofree nounwind willreturn } +attributes #5 = { convergent "approx-func-fp-math"="true" "frame-pointer"="all" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } +attributes #6 = { mustprogress nofree nosync nounwind readnone willreturn } +attributes #7 = { nounwind } +attributes #8 = { nounwind readnone willreturn } +attributes #9 = { convergent nobuiltin nounwind "no-builtins" } +attributes #10 = { convergent nobuiltin "no-builtins" } + +!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7, !8} +!air.kernel = !{!9} +!air.compile_options = !{!16, !17, !18} +!llvm.ident = !{!19} +!air.version = !{!20} +!air.language_version = !{!21} +!air.source_file_name = !{!22} + +!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 26, i32 2]} +!1 = !{i32 1, !"wchar_size", i32 4} +!2 = !{i32 7, !"frame-pointer", i32 2} +!3 = !{i32 7, !"air.max_device_buffers", i32 31} +!4 = !{i32 7, !"air.max_constant_buffers", i32 31} +!5 = !{i32 7, !"air.max_threadgroup_buffers", i32 31} +!6 = !{i32 7, !"air.max_textures", i32 128} +!7 = !{i32 7, !"air.max_read_write_textures", i32 8} +!8 = !{i32 7, !"air.max_samplers", i32 16} +!9 = !{void (%struct._tensor_t addrspace(1)*, %struct._tensor_t addrspace(1)*, %struct._tensor_t addrspace(1)*, <2 x i32>)* @coop_matmul, !10, !11} +!10 = !{} +!11 = !{!12, !13, !14, !15} +!12 = !{i32 0, !"air.tensor", !"air.location_index", i32 0, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"A"} +!13 = !{i32 1, !"air.tensor", !"air.location_index", i32 1, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"B"} +!14 = !{i32 2, !"air.tensor", !"air.location_index", i32 2, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"C"} +!15 = !{i32 3, !"air.threadgroup_position_in_grid", !"air.arg_type_name", !"uint2", !"air.arg_name", !"tgid"} +!16 = !{!"air.compile.denorms_disable"} +!17 = !{!"air.compile.fast_math_enable"} +!18 = !{!"air.compile.framebuffer_fetch_enable"} +!19 = !{!"Apple metal version 32023.864 (metalfe-32023.864)"} +!20 = !{i32 2, i32 8, i32 0} +!21 = !{!"Metal", i32 4, i32 0, i32 0} +!22 = !{!"/private/tmp/metaltest/coop_matmul.metal"} +!23 = !{i64 0, i64 4, !24, i64 4, i64 4, !24, i64 8, i64 4, !24, i64 12, i64 1, !28, i64 13, i64 1, !28, i64 14, i64 1, !28, i64 16, i64 4, !30} +!24 = !{!25, !25, i64 0} +!25 = !{!"int", !26, i64 0} +!26 = !{!"omnipotent char", !27, i64 0} +!27 = !{!"Simple C++ TBAA"} +!28 = !{!29, !29, i64 0} +!29 = !{!"bool", !26, i64 0} +!30 = !{!31, !31, i64 0} +!31 = !{!"_ZTSN3mpp10tensor_ops19matmul2d_descriptor4modeE", !26, i64 0} +!32 = !{!33, !33, i64 0} +!33 = !{!"float", !26, i64 0} +!34 = distinct !{!34, !35} +!35 = !{!"llvm.loop.mustprogress"} +!36 = distinct !{!36, !35} diff --git a/bin/coop_matmul.metal b/bin/coop_matmul.metal new file mode 100644 index 000000000..26c51faa9 --- /dev/null +++ b/bin/coop_matmul.metal @@ -0,0 +1,32 @@ +#include +#include +#include +#include + +using namespace metal; +using namespace mpp::tensor_ops; + +kernel void coop_matmul(tensor> A, + tensor> B, + tensor> C, + uint2 tgid [[threadgroup_position_in_grid]]) +{ + constexpr auto desc = matmul2d_descriptor(64, 32, static_cast(dynamic_extent)); + matmul2d> op; + + auto mA = A.slice(0, tgid.y * 64); + auto mB = B.slice(tgid.x * 32, 0); + auto mC = C.slice(tgid.x * 32, tgid.y * 64); + + auto cT = op.get_destination_cooperative_tensor(); + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) cT[i] = 0; + } + op.run(mA, mB, cT); + + // postfix-fuse: just scale + cast as a stand-in for softmax epilogue + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) cT[i] *= 2.0f; + } + cT.store(mC); +} diff --git a/bin/inline_matmul.ll b/bin/inline_matmul.ll new file mode 100644 index 000000000..a9ed2aef6 --- /dev/null +++ b/bin/inline_matmul.ll @@ -0,0 +1,290 @@ +; ModuleID = 'inline_matmul.metal' +source_filename = "inline_matmul.metal" +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32" +target triple = "air64_v28-apple-macosx26.0.0" + +%"struct.mpp::tensor_ops::matmul2d_descriptor" = type { i32, i32, i32, i8, i8, i8, i32 } +%"struct.metal::array" = type { [2 x i32] } +%struct._tensor_t = type opaque + +@_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE = linkonce_odr local_unnamed_addr constant %"struct.mpp::tensor_ops::matmul2d_descriptor" { i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0 } + +; Function Attrs: convergent nounwind +define void @inline_matmul(half addrspace(1)* noundef "air-buffer-no-alias" %0, half addrspace(1)* noundef "air-buffer-no-alias" %1, float addrspace(1)* noundef "air-buffer-no-alias" %2, i32 addrspace(2)* nocapture noundef readonly align 4 dereferenceable(4) "air-buffer-no-alias" %3, i32 addrspace(2)* nocapture noundef readonly align 4 dereferenceable(4) "air-buffer-no-alias" %4, i32 addrspace(2)* nocapture noundef readonly align 4 dereferenceable(4) "air-buffer-no-alias" %5, <2 x i32> noundef %6) local_unnamed_addr #0 { + %8 = alloca %"struct.mpp::tensor_ops::matmul2d_descriptor", align 4 + %9 = alloca %"struct.metal::array", align 4 + %10 = alloca %"struct.metal::array", align 4 + %11 = alloca %"struct.metal::array", align 4 + %12 = alloca %"struct.metal::array", align 4 + %13 = alloca %"struct.metal::array", align 4 + %14 = alloca %"struct.metal::array", align 4 + %15 = alloca %"struct.metal::array", align 4 + %16 = alloca %"struct.metal::array", align 4 + %17 = alloca %"struct.metal::array", align 4 + %18 = alloca %"struct.metal::array", align 4 + %19 = alloca %"struct.metal::array", align 4 + %20 = alloca %"struct.metal::array", align 4 + %21 = tail call i64 @_ZN5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEEE.MTL_SIZEAS() #7 + %22 = alloca i8, i64 %21, align 8 + %23 = alloca i8, i64 %21, align 8 + %24 = tail call i64 @_ZN5metal6tensorIU9MTLdevicefNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEEE.MTL_SIZEAS() #7 + %25 = alloca i8, i64 %24, align 8 + %26 = alloca i8, i64 %21, align 8 + %27 = alloca i8, i64 %21, align 8 + %28 = alloca i8, i64 %24, align 8 + call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %22) + %29 = load i32, i32 addrspace(2)* %5, align 4, !tbaa !26, !alias.scope !30, !noalias !33 + %30 = load i32, i32 addrspace(2)* %3, align 4, !tbaa !26, !alias.scope !39, !noalias !40 + %31 = bitcast i8* %22 to %struct._tensor_t* + %32 = bitcast half addrspace(1)* %0 to i8 addrspace(1)* + %33 = bitcast %"struct.metal::array"* %13 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %33) #7 + %34 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %13, i64 0, i32 0, i64 0 + store i32 %29, i32* %34, align 4, !tbaa !26 + %35 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %13, i64 0, i32 0, i64 1 + store i32 %30, i32* %35, align 4, !tbaa !26 + %36 = bitcast %"struct.metal::array"* %14 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %36) #7 + %37 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %14, i64 0, i32 0, i64 0 + store i32 1, i32* %37, align 4, !tbaa !26 + %38 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %14, i64 0, i32 0, i64 1 + store i32 %29, i32* %38, align 4, !tbaa !26 + call void @air.init_strided_private_tensor.i32.global(%struct._tensor_t* nocapture nonnull writeonly %31, i16 2, i8 addrspace(1)* readnone %32, i8* nocapture nonnull readonly %33, i8* nocapture nonnull readonly %36, i8 1) #8 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %36) #7 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %33) #7 + call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %23) + %39 = load i32, i32 addrspace(2)* %4, align 4, !tbaa !26, !alias.scope !41, !noalias !42 + %40 = bitcast i8* %23 to %struct._tensor_t* + %41 = bitcast half addrspace(1)* %1 to i8 addrspace(1)* + %42 = bitcast %"struct.metal::array"* %11 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %42) #7 + %43 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %11, i64 0, i32 0, i64 0 + store i32 %39, i32* %43, align 4, !tbaa !26 + %44 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %11, i64 0, i32 0, i64 1 + store i32 %29, i32* %44, align 4, !tbaa !26 + %45 = bitcast %"struct.metal::array"* %12 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %45) #7 + %46 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %12, i64 0, i32 0, i64 0 + store i32 1, i32* %46, align 4, !tbaa !26 + %47 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %12, i64 0, i32 0, i64 1 + store i32 %39, i32* %47, align 4, !tbaa !26 + call void @air.init_strided_private_tensor.i32.global(%struct._tensor_t* nocapture nonnull writeonly %40, i16 2, i8 addrspace(1)* readnone %41, i8* nocapture nonnull readonly %42, i8* nocapture nonnull readonly %45, i8 1) #8 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %45) #7 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %42) #7 + call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %25) + %48 = bitcast i8* %25 to %struct._tensor_t* + %49 = bitcast float addrspace(1)* %2 to i8 addrspace(1)* + %50 = bitcast %"struct.metal::array"* %9 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %50) #7 + %51 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %9, i64 0, i32 0, i64 0 + store i32 %39, i32* %51, align 4, !tbaa !26 + %52 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %9, i64 0, i32 0, i64 1 + store i32 %30, i32* %52, align 4, !tbaa !26 + %53 = bitcast %"struct.metal::array"* %10 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %53) #7 + %54 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %10, i64 0, i32 0, i64 0 + store i32 1, i32* %54, align 4, !tbaa !26 + %55 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %10, i64 0, i32 0, i64 1 + store i32 %39, i32* %55, align 4, !tbaa !26 + call void @air.init_strided_private_tensor.i32.global(%struct._tensor_t* nocapture nonnull writeonly %48, i16 2, i8 addrspace(1)* readnone %49, i8* nocapture nonnull readonly %50, i8* nocapture nonnull readonly %53, i8 0) #8 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %53) #7 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %50) #7 + call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %26) + %56 = extractelement <2 x i32> %6, i64 1 + %57 = shl i32 %56, 6 + %58 = bitcast %"struct.metal::array"* %20 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %58) #7, !noalias !43 + %59 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %20, i64 0, i32 0, i64 0 + store i32 0, i32* %59, align 4, !tbaa !26, !noalias !43 + %60 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %20, i64 0, i32 0, i64 1 + store i32 %57, i32* %60, align 4, !tbaa !26, !noalias !43 + %61 = bitcast %"struct.metal::array"* %16 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %61) #7 + %62 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %31, i16 2, i16 0) #8 + %63 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %31, i16 2, i16 1) #8 + %64 = sub i32 %63, %57 + %65 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %16, i64 0, i32 0, i64 0 + store i32 %62, i32* %65, align 4 + %66 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %16, i64 0, i32 0, i64 1 + store i32 %64, i32* %66, align 4 + %67 = bitcast i8* %26 to %struct._tensor_t* + call void @air.slice_private_tensor_private_tensor.s.i32(%struct._tensor_t* nocapture nonnull writeonly %67, %struct._tensor_t* nocapture nonnull readonly %31, i16 2, i8* nocapture nonnull readonly %58, i8* nocapture nonnull readonly %61) #8 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %61) #7 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %58) #7, !noalias !43 + call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %27) + %68 = extractelement <2 x i32> %6, i64 0 + %69 = shl i32 %68, 5 + %70 = bitcast %"struct.metal::array"* %19 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %70) #7, !noalias !46 + %71 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %19, i64 0, i32 0, i64 0 + store i32 %69, i32* %71, align 4, !tbaa !26, !noalias !46 + %72 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %19, i64 0, i32 0, i64 1 + store i32 0, i32* %72, align 4, !tbaa !26, !noalias !46 + %73 = bitcast %"struct.metal::array"* %17 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %73) #7 + %74 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %40, i16 2, i16 0) #8 + %75 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %40, i16 2, i16 1) #8 + %76 = sub i32 %74, %69 + %77 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %17, i64 0, i32 0, i64 0 + store i32 %76, i32* %77, align 4 + %78 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %17, i64 0, i32 0, i64 1 + store i32 %75, i32* %78, align 4 + %79 = bitcast i8* %27 to %struct._tensor_t* + call void @air.slice_private_tensor_private_tensor.s.i32(%struct._tensor_t* nocapture nonnull writeonly %79, %struct._tensor_t* nocapture nonnull readonly %40, i16 2, i8* nocapture nonnull readonly %70, i8* nocapture nonnull readonly %73) #8 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %73) #7 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %70) #7, !noalias !46 + call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %28) + %80 = bitcast %"struct.metal::array"* %18 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %80) #7, !noalias !49 + %81 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %18, i64 0, i32 0, i64 0 + store i32 %69, i32* %81, align 4, !tbaa !26, !noalias !49 + %82 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %18, i64 0, i32 0, i64 1 + store i32 %57, i32* %82, align 4, !tbaa !26, !noalias !49 + %83 = bitcast %"struct.metal::array"* %15 to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %83) #7 + %84 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %48, i16 2, i16 0) #8 + %85 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %48, i16 2, i16 1) #8 + %86 = sub i32 %84, %69 + %87 = sub i32 %85, %57 + %88 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %15, i64 0, i32 0, i64 0 + store i32 %86, i32* %88, align 4 + %89 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %15, i64 0, i32 0, i64 1 + store i32 %87, i32* %89, align 4 + %90 = bitcast i8* %28 to %struct._tensor_t* + call void @air.slice_private_tensor_private_tensor.s.i32(%struct._tensor_t* nocapture nonnull writeonly %90, %struct._tensor_t* nocapture nonnull readonly %48, i16 2, i8* nocapture nonnull readonly %80, i8* nocapture nonnull readonly %83) #8 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %83) #7 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %80) #7, !noalias !49 + %91 = tail call i32 @air.get_simdgroup_size.i32() #9 + %92 = shl i32 %91, 2 + %93 = bitcast %"struct.mpp::tensor_ops::matmul2d_descriptor"* %8 to i8* + call void @llvm.lifetime.start.p0i8(i64 20, i8* nonnull %93) #7 + call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 4 dereferenceable(20) %93, i8* noundef nonnull align 4 dereferenceable(20) bitcast (%"struct.mpp::tensor_ops::matmul2d_descriptor"* @_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE to i8*), i64 20, i1 false) #7, !tbaa.struct !52 + call void @__tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20) %8, i8* noundef nonnull %26, i32 noundef 2, i8* noundef nonnull %27, i32 noundef 2, i8* noundef nonnull %28, i32 noundef 2, i32 noundef %92) #10 + call void @llvm.lifetime.end.p0i8(i64 20, i8* nonnull %93) #7 + call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %28) #7 + call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %27) #7 + call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %26) #7 + call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %25) #7 + call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %23) #7 + call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %22) #7 + ret void +} + +; Function Attrs: mustprogress nofree nosync readnone speculatable willreturn +define linkonce_odr hidden i64 @_ZN5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEEE.MTL_SIZEAS() local_unnamed_addr #1 { + %1 = tail call i16 @air.get_descriptor_size_tensor(i16 2, i16 4) #9 + %2 = zext i16 %1 to i64 + ret i64 %2 +} + +; Function Attrs: mustprogress nofree nosync nounwind readnone willreturn +declare i16 @air.get_descriptor_size_tensor(i16, i16) local_unnamed_addr #2 + +; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #3 + +; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #3 + +; Function Attrs: mustprogress nofree nosync readnone speculatable willreturn +define linkonce_odr hidden i64 @_ZN5metal6tensorIU9MTLdevicefNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEEE.MTL_SIZEAS() local_unnamed_addr #1 { + %1 = tail call i16 @air.get_descriptor_size_tensor(i16 2, i16 4) #9 + %2 = zext i16 %1 to i64 + ret i64 %2 +} + +; Function Attrs: argmemonly mustprogress nounwind willreturn +declare void @air.init_strided_private_tensor.i32.global(%struct._tensor_t* nocapture writeonly, i16, i8 addrspace(1)* readnone, i8* nocapture readonly, i8* nocapture readonly, i8) local_unnamed_addr #4 + +; Function Attrs: argmemonly mustprogress nounwind willreturn +declare i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture readonly, i16, i16) local_unnamed_addr #4 + +; Function Attrs: argmemonly mustprogress nounwind willreturn +declare void @air.slice_private_tensor_private_tensor.s.i32(%struct._tensor_t* nocapture writeonly, %struct._tensor_t* nocapture readonly, i16, i8* nocapture readonly, i8* nocapture readonly) local_unnamed_addr #4 + +; Function Attrs: argmemonly mustprogress nofree nounwind willreturn +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #5 + +; Function Attrs: convergent +declare void @__tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20), i8* noundef, i32 noundef, i8* noundef, i32 noundef, i8* noundef, i32 noundef, i32 noundef) local_unnamed_addr #6 section "air.externally_defined" + +; Function Attrs: mustprogress nofree nosync nounwind readnone willreturn +declare i32 @air.get_simdgroup_size.i32() local_unnamed_addr #2 + +attributes #0 = { convergent nounwind "approx-func-fp-math"="true" "frame-pointer"="all" "min-legal-vector-width"="64" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } +attributes #1 = { mustprogress nofree nosync readnone speculatable willreturn "deferred-static-alloca-size" } +attributes #2 = { mustprogress nofree nosync nounwind readnone willreturn } +attributes #3 = { argmemonly mustprogress nocallback nofree nosync nounwind willreturn } +attributes #4 = { argmemonly mustprogress nounwind willreturn } +attributes #5 = { argmemonly mustprogress nofree nounwind willreturn } +attributes #6 = { convergent "approx-func-fp-math"="true" "frame-pointer"="all" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } +attributes #7 = { nounwind } +attributes #8 = { argmemonly nounwind willreturn } +attributes #9 = { nounwind readnone willreturn } +attributes #10 = { convergent nobuiltin nounwind "no-builtins" } + +!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7, !8} +!air.kernel = !{!9} +!air.compile_options = !{!19, !20, !21} +!llvm.ident = !{!22} +!air.version = !{!23} +!air.language_version = !{!24} +!air.source_file_name = !{!25} + +!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 26, i32 2]} +!1 = !{i32 1, !"wchar_size", i32 4} +!2 = !{i32 7, !"frame-pointer", i32 2} +!3 = !{i32 7, !"air.max_device_buffers", i32 31} +!4 = !{i32 7, !"air.max_constant_buffers", i32 31} +!5 = !{i32 7, !"air.max_threadgroup_buffers", i32 31} +!6 = !{i32 7, !"air.max_textures", i32 128} +!7 = !{i32 7, !"air.max_read_write_textures", i32 8} +!8 = !{i32 7, !"air.max_samplers", i32 16} +!9 = !{void (half addrspace(1)*, half addrspace(1)*, float addrspace(1)*, i32 addrspace(2)*, i32 addrspace(2)*, i32 addrspace(2)*, <2 x i32>)* @inline_matmul, !10, !11} +!10 = !{} +!11 = !{!12, !13, !14, !15, !16, !17, !18} +!12 = !{i32 0, !"air.buffer", !"air.location_index", i32 0, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_size", i32 2, !"air.arg_type_align_size", i32 2, !"air.arg_type_name", !"half", !"air.arg_name", !"Abuf"} +!13 = !{i32 1, !"air.buffer", !"air.location_index", i32 1, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_size", i32 2, !"air.arg_type_align_size", i32 2, !"air.arg_type_name", !"half", !"air.arg_name", !"Bbuf"} +!14 = !{i32 2, !"air.buffer", !"air.location_index", i32 2, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"float", !"air.arg_name", !"Cbuf"} +!15 = !{i32 3, !"air.buffer", !"air.buffer_size", i32 4, !"air.location_index", i32 3, i32 1, !"air.read", !"air.address_space", i32 2, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"uint", !"air.arg_name", !"M"} +!16 = !{i32 4, !"air.buffer", !"air.buffer_size", i32 4, !"air.location_index", i32 4, i32 1, !"air.read", !"air.address_space", i32 2, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"uint", !"air.arg_name", !"N"} +!17 = !{i32 5, !"air.buffer", !"air.buffer_size", i32 4, !"air.location_index", i32 5, i32 1, !"air.read", !"air.address_space", i32 2, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"uint", !"air.arg_name", !"K"} +!18 = !{i32 6, !"air.threadgroup_position_in_grid", !"air.arg_type_name", !"uint2", !"air.arg_name", !"tgid"} +!19 = !{!"air.compile.denorms_disable"} +!20 = !{!"air.compile.fast_math_enable"} +!21 = !{!"air.compile.framebuffer_fetch_enable"} +!22 = !{!"Apple metal version 32023.864 (metalfe-32023.864)"} +!23 = !{i32 2, i32 8, i32 0} +!24 = !{!"Metal", i32 4, i32 0, i32 0} +!25 = !{!"/private/tmp/metaltest/inline_matmul.metal"} +!26 = !{!27, !27, i64 0} +!27 = !{!"int", !28, i64 0} +!28 = !{!"omnipotent char", !29, i64 0} +!29 = !{!"Simple C++ TBAA"} +!30 = !{!31} +!31 = distinct !{!31, !32, !"air-alias-scope-arg(5)"} +!32 = distinct !{!32, !"air-alias-scopes(inline_matmul)"} +!33 = !{!34, !35, !36, !37, !38} +!34 = distinct !{!34, !32, !"air-alias-scope-arg(0)"} +!35 = distinct !{!35, !32, !"air-alias-scope-arg(1)"} +!36 = distinct !{!36, !32, !"air-alias-scope-arg(2)"} +!37 = distinct !{!37, !32, !"air-alias-scope-arg(3)"} +!38 = distinct !{!38, !32, !"air-alias-scope-arg(4)"} +!39 = !{!37} +!40 = !{!34, !35, !36, !38, !31} +!41 = !{!38} +!42 = !{!34, !35, !36, !37, !31} +!43 = !{!44} +!44 = distinct !{!44, !45, !"_ZNK5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJijEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_: argument 0"} +!45 = distinct !{!45, !"_ZNK5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJijEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_"} +!46 = !{!47} +!47 = distinct !{!47, !48, !"_ZNK5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJjiEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_: argument 0"} +!48 = distinct !{!48, !"_ZNK5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJjiEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_"} +!49 = !{!50} +!50 = distinct !{!50, !51, !"_ZNK5metal6tensorIU9MTLdevicefNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJjjEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_: argument 0"} +!51 = distinct !{!51, !"_ZNK5metal6tensorIU9MTLdevicefNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJjjEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_"} +!52 = !{i64 0, i64 4, !26, i64 4, i64 4, !26, i64 8, i64 4, !26, i64 12, i64 1, !53, i64 13, i64 1, !53, i64 14, i64 1, !53, i64 16, i64 4, !55} +!53 = !{!54, !54, i64 0} +!54 = !{!"bool", !28, i64 0} +!55 = !{!56, !56, i64 0} +!56 = !{!"_ZTSN3mpp10tensor_ops19matmul2d_descriptor4modeE", !28, i64 0} diff --git a/bin/inline_matmul.metal b/bin/inline_matmul.metal new file mode 100644 index 000000000..f1dd66565 --- /dev/null +++ b/bin/inline_matmul.metal @@ -0,0 +1,32 @@ +#include +#include +#include + +using namespace metal; +using namespace mpp::tensor_ops; + +kernel void inline_matmul(device half* Abuf, + device half* Bbuf, + device float* Cbuf, + constant uint& M, + constant uint& N, + constant uint& K, + uint2 tgid [[threadgroup_position_in_grid]]) +{ + // Build tensor_inline views over raw buffers. + auto A = tensor, tensor_inline>( + Abuf, dextents{int32_t(K), int32_t(M)}); + auto B = tensor, tensor_inline>( + Bbuf, dextents{int32_t(N), int32_t(K)}); + auto C = tensor, tensor_inline>( + Cbuf, dextents{int32_t(N), int32_t(M)}); + + constexpr auto desc = matmul2d_descriptor(64, 32, static_cast(dynamic_extent)); + matmul2d> op; + + auto mA = A.slice(0, tgid.y * 64); + auto mB = B.slice(tgid.x * 32, 0); + auto mC = C.slice(tgid.x * 32, tgid.y * 64); + + op.run(mA, mB, mC); +} diff --git a/bin/simple_matmul.ll b/bin/simple_matmul.ll new file mode 100644 index 000000000..c3a3bb3ff --- /dev/null +++ b/bin/simple_matmul.ll @@ -0,0 +1,128 @@ +; ModuleID = 'simple_matmul.metal' +source_filename = "simple_matmul.metal" +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32" +target triple = "air64_v28-apple-macosx26.0.0" + +%"struct.mpp::tensor_ops::matmul2d_descriptor" = type { i32, i32, i32, i8, i8, i8, i32 } +%struct._tensor_t = type opaque +%"struct.metal::tensor.3" = type { %"struct.metal::__tensor_base.4", %struct._tensor_t addrspace(1)* } +%"struct.metal::__tensor_base.4" = type { %"struct.metal::__tensor_offsets.5" } +%"struct.metal::__tensor_offsets.5" = type { %"struct.metal::array" } +%"struct.metal::array" = type { [2 x i32] } +%"struct.metal::tensor.6" = type { %"struct.metal::__tensor_base.7", %struct._tensor_t addrspace(1)* } +%"struct.metal::__tensor_base.7" = type { %"struct.metal::__tensor_offsets.8" } +%"struct.metal::__tensor_offsets.8" = type { %"struct.metal::array" } + +@_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE = linkonce_odr local_unnamed_addr constant %"struct.mpp::tensor_ops::matmul2d_descriptor" { i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0 } + +; Function Attrs: convergent nounwind +define void @simple_matmul(%struct._tensor_t addrspace(1)* %0, %struct._tensor_t addrspace(1)* %1, %struct._tensor_t addrspace(1)* %2, <2 x i32> noundef %3) local_unnamed_addr #0 { + %5 = alloca %"struct.mpp::tensor_ops::matmul2d_descriptor", align 4 + %6 = alloca %"struct.metal::tensor.3", align 8 + %7 = alloca %"struct.metal::tensor.3", align 8 + %8 = alloca %"struct.metal::tensor.6", align 8 + %9 = bitcast %"struct.metal::tensor.3"* %6 to i8* + call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %9) #5 + %10 = extractelement <2 x i32> %3, i64 1 + %11 = shl i32 %10, 6 + %12 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %6, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 + store i32 0, i32* %12, align 8 + %13 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %6, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 + store i32 %11, i32* %13, align 4 + %14 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %6, i64 0, i32 1 + store %struct._tensor_t addrspace(1)* %0, %struct._tensor_t addrspace(1)** %14, align 8 + %15 = bitcast %"struct.metal::tensor.3"* %7 to i8* + call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %15) #5 + %16 = extractelement <2 x i32> %3, i64 0 + %17 = shl i32 %16, 5 + %18 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 + store i32 %17, i32* %18, align 8 + %19 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 + store i32 0, i32* %19, align 4 + %20 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 1 + store %struct._tensor_t addrspace(1)* %1, %struct._tensor_t addrspace(1)** %20, align 8 + %21 = bitcast %"struct.metal::tensor.6"* %8 to i8* + call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %21) #5 + %22 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %8, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 + store i32 %17, i32* %22, align 8 + %23 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %8, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 + store i32 %11, i32* %23, align 4 + %24 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %8, i64 0, i32 1 + store %struct._tensor_t addrspace(1)* %2, %struct._tensor_t addrspace(1)** %24, align 8 + %25 = tail call i32 @air.get_simdgroup_size.i32() #6 + %26 = shl i32 %25, 2 + %27 = bitcast %"struct.mpp::tensor_ops::matmul2d_descriptor"* %5 to i8* + call void @llvm.lifetime.start.p0i8(i64 20, i8* nonnull %27) #5 + call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 4 dereferenceable(20) %27, i8* noundef nonnull align 4 dereferenceable(20) bitcast (%"struct.mpp::tensor_ops::matmul2d_descriptor"* @_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE to i8*), i64 20, i1 false) #5, !tbaa.struct !23 + call void @__tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20) %5, i8* noundef nonnull %9, i32 noundef 1, i8* noundef nonnull %15, i32 noundef 1, i8* noundef nonnull %21, i32 noundef 1, i32 noundef %26) #7 + call void @llvm.lifetime.end.p0i8(i64 20, i8* nonnull %27) #5 + call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %21) #5 + call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %15) #5 + call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %9) #5 + ret void +} + +; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1 + +; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1 + +; Function Attrs: argmemonly mustprogress nofree nounwind willreturn +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #2 + +; Function Attrs: convergent +declare void @__tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20), i8* noundef, i32 noundef, i8* noundef, i32 noundef, i8* noundef, i32 noundef, i32 noundef) local_unnamed_addr #3 section "air.externally_defined" + +; Function Attrs: mustprogress nofree nosync nounwind readnone willreturn +declare i32 @air.get_simdgroup_size.i32() local_unnamed_addr #4 + +attributes #0 = { convergent nounwind "approx-func-fp-math"="true" "frame-pointer"="all" "min-legal-vector-width"="64" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } +attributes #1 = { argmemonly mustprogress nocallback nofree nosync nounwind willreturn } +attributes #2 = { argmemonly mustprogress nofree nounwind willreturn } +attributes #3 = { convergent "approx-func-fp-math"="true" "frame-pointer"="all" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } +attributes #4 = { mustprogress nofree nosync nounwind readnone willreturn } +attributes #5 = { nounwind } +attributes #6 = { nounwind readnone willreturn } +attributes #7 = { convergent nobuiltin nounwind "no-builtins" } + +!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7, !8} +!air.kernel = !{!9} +!air.compile_options = !{!16, !17, !18} +!llvm.ident = !{!19} +!air.version = !{!20} +!air.language_version = !{!21} +!air.source_file_name = !{!22} + +!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 26, i32 2]} +!1 = !{i32 1, !"wchar_size", i32 4} +!2 = !{i32 7, !"frame-pointer", i32 2} +!3 = !{i32 7, !"air.max_device_buffers", i32 31} +!4 = !{i32 7, !"air.max_constant_buffers", i32 31} +!5 = !{i32 7, !"air.max_threadgroup_buffers", i32 31} +!6 = !{i32 7, !"air.max_textures", i32 128} +!7 = !{i32 7, !"air.max_read_write_textures", i32 8} +!8 = !{i32 7, !"air.max_samplers", i32 16} +!9 = !{void (%struct._tensor_t addrspace(1)*, %struct._tensor_t addrspace(1)*, %struct._tensor_t addrspace(1)*, <2 x i32>)* @simple_matmul, !10, !11} +!10 = !{} +!11 = !{!12, !13, !14, !15} +!12 = !{i32 0, !"air.tensor", !"air.location_index", i32 0, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"A"} +!13 = !{i32 1, !"air.tensor", !"air.location_index", i32 1, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"B"} +!14 = !{i32 2, !"air.tensor", !"air.location_index", i32 2, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"C"} +!15 = !{i32 3, !"air.threadgroup_position_in_grid", !"air.arg_type_name", !"uint2", !"air.arg_name", !"tgid"} +!16 = !{!"air.compile.denorms_disable"} +!17 = !{!"air.compile.fast_math_enable"} +!18 = !{!"air.compile.framebuffer_fetch_enable"} +!19 = !{!"Apple metal version 32023.864 (metalfe-32023.864)"} +!20 = !{i32 2, i32 8, i32 0} +!21 = !{!"Metal", i32 4, i32 0, i32 0} +!22 = !{!"/private/tmp/metaltest/simple_matmul.metal"} +!23 = !{i64 0, i64 4, !24, i64 4, i64 4, !24, i64 8, i64 4, !24, i64 12, i64 1, !28, i64 13, i64 1, !28, i64 14, i64 1, !28, i64 16, i64 4, !30} +!24 = !{!25, !25, i64 0} +!25 = !{!"int", !26, i64 0} +!26 = !{!"omnipotent char", !27, i64 0} +!27 = !{!"Simple C++ TBAA"} +!28 = !{!29, !29, i64 0} +!29 = !{!"bool", !26, i64 0} +!30 = !{!31, !31, i64 0} +!31 = !{!"_ZTSN3mpp10tensor_ops19matmul2d_descriptor4modeE", !26, i64 0} diff --git a/bin/simple_matmul.metal b/bin/simple_matmul.metal new file mode 100644 index 000000000..5b9f94f33 --- /dev/null +++ b/bin/simple_matmul.metal @@ -0,0 +1,22 @@ +#include +#include +#include + +using namespace metal; +using namespace mpp::tensor_ops; + +kernel void simple_matmul(tensor> A, + tensor> B, + tensor> C, + uint2 tgid [[threadgroup_position_in_grid]]) +{ + constexpr auto desc = matmul2d_descriptor(64, 32, static_cast(dynamic_extent), + false, false, false); + matmul2d> op; + + auto mA = A.slice(0, tgid.y * 64); + auto mB = B.slice(tgid.x * 32, 0); + auto mC = C.slice(tgid.x * 32, tgid.y * 64); + + op.run(mA, mB, mC); +} diff --git a/src/Metal.jl b/src/Metal.jl index ecc403a73..72def02f5 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -40,6 +40,7 @@ include("device/intrinsics/math.jl") include("device/intrinsics/synchronization.jl") include("device/intrinsics/memory.jl") include("device/intrinsics/simd.jl") +include("device/intrinsics/tensor.jl") include("device/intrinsics/atomics.jl") include("device/malloc.jl") include("device/random.jl") diff --git a/src/device/intrinsics/tensor.jl b/src/device/intrinsics/tensor.jl new file mode 100644 index 000000000..903f9134e --- /dev/null +++ b/src/device/intrinsics/tensor.jl @@ -0,0 +1,219 @@ +export MtlInlineTensor, matmul2d_descriptor, tensor_ops_matmul2d! + +using Core: LLVMPtr + +# Wrappers for Metal 4 tensor-ops / `mpp::tensor_ops` device-side APIs. +# +# The host-bound `tensor_handle` form needs opaque kernel arguments and a +# different ABI (see `ISSUE-tensor-ops.md`). The `tensor_inline` form, which +# is what this file targets, constructs the tensor on the kernel stack from +# a buffer pointer and a set of extents/strides — kernel signature stays the +# same as a plain `MtlDeviceArray` kernel. +# +# Each tensor descriptor lives in a per-thread byte buffer held by a +# `Ref{NTuple{N, UInt8}}`. Julia's `llvm-alloc-opt` pass promotes the Ref to +# a stack alloca once everything is inlined into the kernel. + +const _TENSOR_DESCRIPTOR_SIZE = 64 + +const _TensorDescriptorStorage = Base.RefValue{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}} + + +## Tensor descriptor primitives (`air.*` intrinsics). + +# Returns the per-thread tensor descriptor size for the given rank/index-size. +@device_function get_descriptor_size_tensor(rank::Int16, index_size::Int16) = + ccall("extern air.get_descriptor_size_tensor", llvmcall, + Int16, (Int16, Int16), rank, index_size) + +# Build an `i32`-indexed strided tensor view over a device-memory buffer. +# Ccall arg types use `Ref{T}` so that NTuple values passed in get auto-boxed +# into temporary Refs (via `cconvert(::Type{Ref{T}}, ::NTuple{N,T})` from +# `base/refpointer.jl`). `llvm-alloc-opt` promotes those temporaries to +# stack allocas. The element type must match what we pass: a mismatched +# `Ref{T}` would force ccall to emit a `jl_f_throw_methoderror` path, and +# the dead heap alloc on that path defeats the promotion. +@device_function init_strided_tensor_device!( + handle::_TensorDescriptorStorage, + rank::Int16, + data::LLVMPtr{UInt8, AS.Device}, + extents::NTuple{2, Int32}, + strides::NTuple{2, Int32}, + contiguous::Int8, +) = ccall("extern air.init_strided_private_tensor.i32.global", llvmcall, + Cvoid, + (Ref{UInt8}, Int16, LLVMPtr{UInt8, AS.Device}, + Ref{Int32}, Ref{Int32}, Int8), + handle, rank, data, extents, strides, contiguous) + +@device_function get_extent_private_tensor(handle::_TensorDescriptorStorage, + rank::Int16, dim::Int16) = + ccall("extern air.get_extent_private_tensor.i32", llvmcall, + Int32, (Ref{UInt8}, Int16, Int16), + handle, rank, dim) + +@device_function slice_private_tensor!( + dst::_TensorDescriptorStorage, + src::_TensorDescriptorStorage, + rank::Int16, + origin::NTuple{2, Int32}, + extents::NTuple{2, Int32}, +) = ccall("extern air.slice_private_tensor_private_tensor.s.i32", llvmcall, + Cvoid, + (Ref{UInt8}, Ref{UInt8}, Int16, Ref{Int32}, Ref{Int32}), + dst, src, rank, origin, extents) + + +## High-level inline-tensor wrapper. + +""" + MtlInlineTensor{T, R} + +Kernel-stack tensor view over an `MtlDeviceArray`, suitable for use as an +operand of [`tensor_ops_matmul2d!`](@ref). `T` is the element type; `R` is +the rank (only rank 2 is supported today). Backed by a thread-private byte +buffer (an inline `Ref`) that the runtime initializes at construction. + +Note: extents follow the MSL `dextents{e1, e2, ...}` convention +(innermost dimension first), which is the row-major view the matmul kernel +expects. For a Julia column-major `MtlMatrix(M, N)`, pass extents `(M, N)` +if you want to treat columns as the inner dimension. +""" +struct MtlInlineTensor{T, R} + storage::_TensorDescriptorStorage +end + +# In-kernel constructor: build a packed-stride rank-2 tensor over `data`. +@device_function @inline function MtlInlineTensor{T, 2}( + data::MtlDeviceArray{T, <:Any, AS.Device}, + e1::Integer, e2::Integer) where {T} + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + init_strided_tensor_device!(storage, Int16(2), + reinterpret(LLVMPtr{UInt8, AS.Device}, pointer(data)), + (Int32(e1), Int32(e2)), + (Int32(1), Int32(e1)), + Int8(1)) + return MtlInlineTensor{T, 2}(storage) +end + +@inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, AS.Device}, + extents::NTuple{2, <:Integer}) where {T} = + MtlInlineTensor{T, 2}(data, extents[1], extents[2]) + +# In-kernel constructor with explicit strides. +@device_function @inline function MtlInlineTensor{T, 2}( + data::MtlDeviceArray{T, <:Any, AS.Device}, + e1::Integer, e2::Integer, + s1::Integer, s2::Integer) where {T} + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + init_strided_tensor_device!(storage, Int16(2), + reinterpret(LLVMPtr{UInt8, AS.Device}, pointer(data)), + (Int32(e1), Int32(e2)), + (Int32(s1), Int32(s2)), + Int8(0)) + return MtlInlineTensor{T, 2}(storage) +end + +@inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, AS.Device}, + extents::NTuple{2, <:Integer}, + strides::NTuple{2, <:Integer}) where {T} = + MtlInlineTensor{T, 2}(data, extents[1], extents[2], strides[1], strides[2]) + +Base.eltype(::Type{<:MtlInlineTensor{T}}) where {T} = T +Base.eltype(::MtlInlineTensor{T}) where {T} = T + +# Slice. Origins are zero-based to match MSL semantics. +@device_function @inline function _slice_inline_tensor( + t::MtlInlineTensor{T, 2}, + o1::Integer, o2::Integer, + e1::Integer, e2::Integer) where {T} + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + slice_private_tensor!(storage, t.storage, Int16(2), + (Int32(o1), Int32(o2)), + (Int32(e1), Int32(e2))) + return MtlInlineTensor{T, 2}(storage) +end + +@inline Base.view(t::MtlInlineTensor{T, 2}, origin::NTuple{2, <:Integer}, + extents::NTuple{2, <:Integer}) where {T} = + _slice_inline_tensor(t, origin[1], origin[2], extents[1], extents[2]) + + +## matmul2d descriptor (mirrors `mpp::tensor_ops::matmul2d_descriptor`). + +@enum Matmul2DMode::Int32 begin + matmul2d_multiply = 0 + matmul2d_multiply_accumulate = 1 +end + +""" + matmul2d_descriptor(m, n, [k]; transpose_left=false, transpose_right=false, + relaxed_precision=false, mode=matmul2d_multiply) + +Configuration descriptor for a `tensor_ops::matmul2d` operation. `k` +defaults to `-1` (dynamic — inferred from the input tensors at runtime). +Layout matches the 20-byte `mpp::tensor_ops::matmul2d_descriptor` POD. +""" +struct matmul2d_descriptor + m::Int32 + n::Int32 + k::Int32 + transpose_left::Int8 + transpose_right::Int8 + relaxed_precision::Int8 + matmul_mode::Matmul2DMode +end + +matmul2d_descriptor(m::Integer, n::Integer, k::Integer = -1; + transpose_left::Bool = false, + transpose_right::Bool = false, + relaxed_precision::Bool = false, + mode::Matmul2DMode = matmul2d_multiply) = + matmul2d_descriptor(Int32(m), Int32(n), Int32(k), + Int8(transpose_left), Int8(transpose_right), + Int8(relaxed_precision), mode) + + +## matmul2d run (inline-tensor → inline-tensor variant). + +const _TENSOR_DESC_INLINE = Int32(2) # `__tensor_ops_tensor_descriptor_type::tensor_inline` + +# Element-type suffix for `__tensorops_impl_matmul2d_op_run_*` symbols. +_tensorops_suffix(::Type{Float16}) = "f16" +_tensorops_suffix(::Type{Float32}) = "f32" + +""" + tensor_ops_matmul2d!(desc, left, right, dest, threads) + +`dest = left * right (+ dest if mode=multiply_accumulate)` executed +cooperatively by `threads` participating threads (i.e. +`simdgroup_size * num_simdgroups`). Each operand is an +[`MtlInlineTensor`](@ref). + +Supported element-type combinations follow `MPPTensorOpsMatMul2d.h`; only +`(f16, f16, f16)`, `(f16, f16, f32)`, `(f32, f32, f32)` are wired up here. +""" +@generated function tensor_ops_matmul2d!( + desc::matmul2d_descriptor, + left::MtlInlineTensor{TL, 2}, + right::MtlInlineTensor{TR, 2}, + dest::MtlInlineTensor{TD, 2}, + threads::Int32) where {TL, TR, TD} + sym = "__tensorops_impl_matmul2d_op_run_dv_$(_tensorops_suffix(TL))" * + "_dv_$(_tensorops_suffix(TR))" * + "_dv_$(_tensorops_suffix(TD))" + quote + ccall($"extern $sym", llvmcall, Cvoid, + (Ref{matmul2d_descriptor}, + Ref{UInt8}, Int32, + Ref{UInt8}, Int32, + Ref{UInt8}, Int32, + Int32), + desc, + left.storage, $_TENSOR_DESC_INLINE, + right.storage, $_TENSOR_DESC_INLINE, + dest.storage, $_TENSOR_DESC_INLINE, + threads) + return nothing + end +end From 2ab5b347270fe6b9c9c30af14120fc61fa1e7275 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 15:57:37 +0200 Subject: [PATCH 09/24] Add Flash Attention example using Metal 4 tensor ops. Fourth path alongside the MPS / MPSGraph / simdgroup_matrix implementations. Builds tensor_inline views over the MtlDeviceArray inputs and dispatches the two matmuls through tensor_ops::matmul2d. The kernel stays buffer-shaped so the existing kernel ABI is unchanged. The forward pass is split across two dispatches (QK+softmax, then PV) to work around a Metal back-end crash on two __tensorops_impl_matmul2d_op_run_* calls in a single kernel. The scores tile is therefore materialized in device memory rather than fused into a cooperative_tensor. Limited to D == N == 64 and a single (head, batch) block; on macOS 26+. --- examples/flashattention.jl | 173 +++++++++++++++++++++++++++++++------ 1 file changed, 146 insertions(+), 27 deletions(-) diff --git a/examples/flashattention.jl b/examples/flashattention.jl index 4fe28df02..a14e4240c 100644 --- a/examples/flashattention.jl +++ b/examples/flashattention.jl @@ -1,6 +1,6 @@ # Flash Attention reference implementations on Apple Silicon. # -# Three ways to spell scaled dot-product attention on Metal, illustrating +# Four ways to spell scaled dot-product attention on Metal, illustrating # the programming models Metal.jl exposes: # # attention_mps(Q, K, V) @@ -31,22 +31,28 @@ # https://github.com/philipturner/metal-flash-attention for a # tuned reference. Works on macOS 13+ / M1+. # -# A fourth path would use the Metal 4 `cooperative_tensor` / -# `tensor_ops::matmul2d` primitives with postfix-fusion of the softmax -# epilogue. Apple positions this as the preferred programming model for -# ML on M5; on M3/M4 it lowers to the same simdgroup MMA hardware the -# `attention_simdgroup` path already drives. That path isn't yet wired up -# in Metal.jl — the ObjC classes are generated in `lib/mtl/libmtl.jl` -# (gated on `macos(v"26.0.0")`), but the host-side `MTLTensor` / -# `MTL4ComputeCommandEncoder` wrappers and the device-side -# `MtlCooperativeTensor` are not. Note that the device-side ops lower to -# externally-defined `__tensorops_impl_matmul2d_op_*` symbols rather than -# `air.*` intrinsics, so the binding pattern differs from the simdgroup -# case. +# attention_tensor(Q, K, V) +# Two kernels (QKᵀ + softmax, then PV) using the Metal 4 +# `tensor_ops::matmul2d` primitives. Each kernel builds +# `tensor_inline` views over the `MtlDeviceArray` inputs, so the +# kernel signature stays buffer-shaped — no host-side `MTLTensor` +# / `MTL4ComputeCommandEncoder` wrapping is needed. The matmuls +# lower to externally-defined `__tensorops_impl_matmul2d_op_*` +# symbols (linked from the MetalPerformancePrimitives runtime), +# not `air.*` intrinsics. Requires macOS 26+; on M3/M4 the runtime +# still lowers to the same simdgroup MMA hardware. Limited to N = +# D = 64 because the matmul descriptor is specialized to that +# single 64x64 tile. Splitting QK and PV across two dispatches — +# rather than one fused kernel — works around an Apple back-end +# crash on two `__tensorops_impl_matmul2d_op_run_*` calls in a +# single kernel; it also means the scores tile is materialized in +# device memory (`cooperative_tensor` would keep it in registers +# for true postfix-fusion, but the device-side dynamic-alloca +# support that requires isn't wired up yet). # -# All three implementations take Julia 4-D `(head_dim, seq, num_heads, -# batch)` inputs — MPSGraph sees these reversed as `(batch, num_heads, -# seq, head_dim)`, the layout Apple's SDPA expects. +# All implementations take Julia 4-D `(head_dim, seq, num_heads, batch)` +# inputs — MPSGraph sees these reversed as `(batch, num_heads, seq, +# head_dim)`, the layout Apple's SDPA expects. using Metal using Test @@ -183,6 +189,101 @@ function attention_simdgroup(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, end +## Custom kernel with Metal 4 tensor ops (matmul2d, inline tensors) + +# Step 1: compute Q^T K into a Float32 scores buffer, then a row-wise softmax +# (cast to Float16) into a P buffer. The matmul writes its (M, N) output in a +# layout that Julia reads as K^T Q (the transpose of Q^T K), so we apply a +# *column*-wise softmax — that's what corresponds to row-wise softmax of the +# implicit Q^T K, and it's the right direction for column-major contiguous +# memory access. +function _fa_tensor_qk_softmax!(Q::AbstractMatrix{Float16}, + K::AbstractMatrix{Float16}, + S::AbstractMatrix{Float32}, + P::AbstractMatrix{Float16}, + D::UInt32, N::UInt32, scale::Float32) + threads = Int32(threads_per_threadgroup_3d().x) + tid = Int32(thread_position_in_threadgroup_3d().x) - Int32(1) + + A = MtlInlineTensor(Q, (D, N)) + B = MtlInlineTensor(K, (D, N)) + C = MtlInlineTensor(S, (N, N)) + desc = matmul2d_descriptor(N, N, D; transpose_right = true) + tensor_ops_matmul2d!(desc, A, B, C, threads) + threadgroup_barrier(Metal.MemoryFlagDevice) + + # Column-wise softmax. 64 of 128 threads do real work; the rest wait. + @inbounds if tid < Int32(N) + col = tid + Int32(1) + m = -Inf32 + for i in Int32(1):Int32(N) + v = S[i, col] * scale + m = v > m ? v : m + end + s = 0.0f0 + for i in Int32(1):Int32(N) + p = exp(S[i, col] * scale - m) + S[i, col] = p + s += p + end + inv_s = 1.0f0 / s + for i in Int32(1):Int32(N) + P[i, col] = Float16(S[i, col] * inv_s) + end + end + return +end + +# Step 2: O = V · P (in Julia view; equivalent to V · P_attn^T because the +# softmax output is stored in the transposed layout). +function _fa_tensor_pv!(O::AbstractMatrix{Float16}, + V::AbstractMatrix{Float16}, + P::AbstractMatrix{Float16}, + D::UInt32, N::UInt32) + threads = Int32(threads_per_threadgroup_3d().x) + A = MtlInlineTensor(P, (N, N)) + B = MtlInlineTensor(V, (D, N)) + C = MtlInlineTensor(O, (D, N)) + desc = matmul2d_descriptor(N, D, N) + tensor_ops_matmul2d!(desc, A, B, C, threads) + return +end + +function attention_tensor(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, + V::MtlArray{Float16,4}; + scale = inv(sqrt(Float32(size(Q, 1))))) + @assert size(Q) == size(K) == size(V) + D, N, H, B = size(Q) + # MPP requires a real tile, and the (m, n, k) descriptor below is + # specialized to (N, N, D); allowing other shapes would mean dispatching + # multiple threadgroups. + @assert D == N "tensor-ops kernel currently expects D == N" + O = similar(Q) + + # Allocate persistent scratch for the scores / softmax outputs. One per + # (head, batch) pair would let us overlap; for clarity we reuse a single + # pair across all dispatches. + S = MtlArray{Float32}(undef, N, N) + P = MtlArray{Float16}(undef, N, N) + + simdgroup_size = 32 + threads = 4 * simdgroup_size # matmul descriptor wants execution_simdgroups<4> + + for b in 1:B, h in 1:H + Qm = view(Q, :, :, h, b) + Km = view(K, :, :, h, b) + Vm = view(V, :, :, h, b) + Om = view(O, :, :, h, b) + @metal threads = threads _fa_tensor_qk_softmax!(Qm, Km, S, P, + UInt32(D), UInt32(N), + Float32(scale)) + @metal threads = threads _fa_tensor_pv!(Om, Vm, P, UInt32(D), UInt32(N)) + end + Metal.synchronize() + return O +end + + ## CPU reference + driver function attention_cpu(Q, K, V; scale = inv(sqrt(eltype(Q)(size(Q, 1))))) @@ -201,20 +302,38 @@ end function main() T = Float16 # simdgroup path requires fp16 - D = N = 8 # constrained by the simdgroup kernel - Q = MtlArray(randn(T, D, N, 1, 1)) - K = MtlArray(randn(T, D, N, 1, 1)) - V = MtlArray(randn(T, D, N, 1, 1)) + # The simdgroup kernel is locked to 8x8 tiles, and the tensor-ops kernel + # uses a 64x64 matmul descriptor. Run each at its natural shape. + let D = N = 8 + Q = MtlArray(randn(T, D, N, 1, 1)) + K = MtlArray(randn(T, D, N, 1, 1)) + V = MtlArray(randn(T, D, N, 1, 1)) + + O_cpu = attention_cpu(Array(Q), Array(K), Array(V)) + O_mps = attention_mps(Q, K, V) + O_mpsgraph = attention_mpsgraph(Q, K, V) + O_simdgroup = attention_simdgroup(Q, K, V) + + @test Array(O_mps) ≈ O_cpu rtol = 1e-2 + @test Array(O_mpsgraph) ≈ O_cpu rtol = 1e-2 + @test Array(O_simdgroup) ≈ O_cpu rtol = 1e-2 + end + + if Metal.macos_version() >= v"26.0.0" + let D = N = 64 + Q = MtlArray(randn(T, D, N, 1, 1)) + K = MtlArray(randn(T, D, N, 1, 1)) + V = MtlArray(randn(T, D, N, 1, 1)) - O_cpu = attention_cpu(Array(Q), Array(K), Array(V)) - O_mps = attention_mps(Q, K, V) - O_mpsgraph = attention_mpsgraph(Q, K, V) - O_simdgroup = attention_simdgroup(Q, K, V) + O_cpu = attention_cpu(Array(Q), Array(K), Array(V)) + O_mps = attention_mps(Q, K, V) + O_tensor = attention_tensor(Q, K, V) - @test Array(O_mps) ≈ O_cpu rtol = 1e-2 - @test Array(O_mpsgraph) ≈ O_cpu rtol = 1e-2 - @test Array(O_simdgroup) ≈ O_cpu rtol = 1e-2 + @test Array(O_mps) ≈ O_cpu rtol = 1e-2 + @test Array(O_tensor) ≈ O_cpu rtol = 1e-2 + end + end end isinteractive() || main() From cc6c1b0e6f08376085bec707f5b4e1bc7f74b90b Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 16:06:45 +0200 Subject: [PATCH 10/24] Expose more element types for tensor_ops_matmul2d!. Adds bf16 (Core.BFloat16), i8 (Int8), ui8 (UInt8), and i32 (Int32) to the suffix dispatch table, covering the common dense-precision and quantized matmul combinations from MPPTensorOpsMatMul2d.h. Verified against a CPU reference for {f16, f16, f16}, {f16, f16, f32}, {f32, f32, f32}, {bf16, bf16, bf16}, {bf16, bf16, f32}, {f16, i8, f16}, and {i8, i8, i32}. The 4-bit formats (i4/ui4) need a custom packed type and are skipped. --- src/device/intrinsics/tensor.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/device/intrinsics/tensor.jl b/src/device/intrinsics/tensor.jl index 903f9134e..d2a272d5d 100644 --- a/src/device/intrinsics/tensor.jl +++ b/src/device/intrinsics/tensor.jl @@ -179,8 +179,15 @@ matmul2d_descriptor(m::Integer, n::Integer, k::Integer = -1; const _TENSOR_DESC_INLINE = Int32(2) # `__tensor_ops_tensor_descriptor_type::tensor_inline` # Element-type suffix for `__tensorops_impl_matmul2d_op_run_*` symbols. -_tensorops_suffix(::Type{Float16}) = "f16" -_tensorops_suffix(::Type{Float32}) = "f32" +# The 4-bit integer formats (`i4`, `ui4`) aren't exposed yet — Julia has no +# native 4-bit integer type. `int32` is only valid as the destination of +# an `i8`/`ui8` × `i4`/`ui4` matmul. +_tensorops_suffix(::Type{Float16}) = "f16" +_tensorops_suffix(::Type{Float32}) = "f32" +_tensorops_suffix(::Type{Core.BFloat16}) = "b16" +_tensorops_suffix(::Type{Int8}) = "i8" +_tensorops_suffix(::Type{UInt8}) = "ui8" +_tensorops_suffix(::Type{Int32}) = "i32" """ tensor_ops_matmul2d!(desc, left, right, dest, threads) From a2d6878d3308deab8caa27f685aba972653c7e1c Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 16:09:26 +0200 Subject: [PATCH 11/24] Address-space-aware inline tensors for tensor_ops_matmul2d!. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MtlInlineTensor gains an ASpace type parameter; constructors dispatch between air.init_strided_private_tensor.i32.global (device data) and the .local flavor (threadgroup data). tensor_ops_matmul2d! picks the dv/tg prefix per operand to name the __tensorops_impl_matmul2d_op_run_{aspace}_{type}_..._* symbol. That lets us stage tiles in threadgroup memory between matmuls (e.g. between QK and PV in attention) — verified against a CPU reference for half × half → float with the left operand staged to threadgroup. --- src/device/intrinsics/tensor.jl | 123 ++++++++++++++++++++++---------- 1 file changed, 85 insertions(+), 38 deletions(-) diff --git a/src/device/intrinsics/tensor.jl b/src/device/intrinsics/tensor.jl index d2a272d5d..6c9d8711b 100644 --- a/src/device/intrinsics/tensor.jl +++ b/src/device/intrinsics/tensor.jl @@ -46,6 +46,19 @@ const _TensorDescriptorStorage = Base.RefValue{NTuple{_TENSOR_DESCRIPTOR_SIZE, U Ref{Int32}, Ref{Int32}, Int8), handle, rank, data, extents, strides, contiguous) +@device_function init_strided_tensor_threadgroup!( + handle::_TensorDescriptorStorage, + rank::Int16, + data::LLVMPtr{UInt8, AS.ThreadGroup}, + extents::NTuple{2, Int32}, + strides::NTuple{2, Int32}, + contiguous::Int8, +) = ccall("extern air.init_strided_private_tensor.i32.local", llvmcall, + Cvoid, + (Ref{UInt8}, Int16, LLVMPtr{UInt8, AS.ThreadGroup}, + Ref{Int32}, Ref{Int32}, Int8), + handle, rank, data, extents, strides, contiguous) + @device_function get_extent_private_tensor(handle::_TensorDescriptorStorage, rank::Int16, dim::Int16) = ccall("extern air.get_extent_private_tensor.i32", llvmcall, @@ -67,24 +80,28 @@ const _TensorDescriptorStorage = Base.RefValue{NTuple{_TENSOR_DESCRIPTOR_SIZE, U ## High-level inline-tensor wrapper. """ - MtlInlineTensor{T, R} + MtlInlineTensor{T, R, ASpace} -Kernel-stack tensor view over an `MtlDeviceArray`, suitable for use as an -operand of [`tensor_ops_matmul2d!`](@ref). `T` is the element type; `R` is -the rank (only rank 2 is supported today). Backed by a thread-private byte -buffer (an inline `Ref`) that the runtime initializes at construction. +Kernel-stack tensor view over an `MtlDeviceArray` or `MtlThreadGroupArray`, +suitable for use as an operand of [`tensor_ops_matmul2d!`](@ref). `T` is the +element type; `R` is the rank (only rank 2 is supported today); `ASpace` is +the address space of the underlying data (`AS.Device` or `AS.ThreadGroup`). +Backed by a thread-private byte buffer (an inline `Ref`) that the runtime +initializes at construction. Note: extents follow the MSL `dextents{e1, e2, ...}` convention (innermost dimension first), which is the row-major view the matmul kernel expects. For a Julia column-major `MtlMatrix(M, N)`, pass extents `(M, N)` if you want to treat columns as the inner dimension. """ -struct MtlInlineTensor{T, R} +struct MtlInlineTensor{T, R, ASpace} storage::_TensorDescriptorStorage end -# In-kernel constructor: build a packed-stride rank-2 tensor over `data`. -@device_function @inline function MtlInlineTensor{T, 2}( +# In-kernel constructors: packed-stride rank-2 tensor over device or +# threadgroup memory. `contiguous` is `1` (packed) by default; the +# explicit-stride methods below pass `0`. +@device_function @inline function MtlInlineTensor{T, 2, AS.Device}( data::MtlDeviceArray{T, <:Any, AS.Device}, e1::Integer, e2::Integer) where {T} storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() @@ -93,15 +110,23 @@ end (Int32(e1), Int32(e2)), (Int32(1), Int32(e1)), Int8(1)) - return MtlInlineTensor{T, 2}(storage) + return MtlInlineTensor{T, 2, AS.Device}(storage) end -@inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, AS.Device}, - extents::NTuple{2, <:Integer}) where {T} = - MtlInlineTensor{T, 2}(data, extents[1], extents[2]) +@device_function @inline function MtlInlineTensor{T, 2, AS.ThreadGroup}( + data::MtlDeviceArray{T, <:Any, AS.ThreadGroup}, + e1::Integer, e2::Integer) where {T} + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + init_strided_tensor_threadgroup!(storage, Int16(2), + reinterpret(LLVMPtr{UInt8, AS.ThreadGroup}, pointer(data)), + (Int32(e1), Int32(e2)), + (Int32(1), Int32(e1)), + Int8(1)) + return MtlInlineTensor{T, 2, AS.ThreadGroup}(storage) +end -# In-kernel constructor with explicit strides. -@device_function @inline function MtlInlineTensor{T, 2}( +# Explicit-stride variants (mark the tensor as non-packed for the runtime). +@device_function @inline function MtlInlineTensor{T, 2, AS.Device}( data::MtlDeviceArray{T, <:Any, AS.Device}, e1::Integer, e2::Integer, s1::Integer, s2::Integer) where {T} @@ -111,27 +136,45 @@ end (Int32(e1), Int32(e2)), (Int32(s1), Int32(s2)), Int8(0)) - return MtlInlineTensor{T, 2}(storage) + return MtlInlineTensor{T, 2, AS.Device}(storage) +end + +@device_function @inline function MtlInlineTensor{T, 2, AS.ThreadGroup}( + data::MtlDeviceArray{T, <:Any, AS.ThreadGroup}, + e1::Integer, e2::Integer, + s1::Integer, s2::Integer) where {T} + storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() + init_strided_tensor_threadgroup!(storage, Int16(2), + reinterpret(LLVMPtr{UInt8, AS.ThreadGroup}, pointer(data)), + (Int32(e1), Int32(e2)), + (Int32(s1), Int32(s2)), + Int8(0)) + return MtlInlineTensor{T, 2, AS.ThreadGroup}(storage) end -@inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, AS.Device}, - extents::NTuple{2, <:Integer}, - strides::NTuple{2, <:Integer}) where {T} = - MtlInlineTensor{T, 2}(data, extents[1], extents[2], strides[1], strides[2]) +# Convenience: infer address space from the array. +@inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, A}, + extents::NTuple{2, <:Integer}) where {T, A} = + MtlInlineTensor{T, 2, A}(data, extents[1], extents[2]) + +@inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, A}, + extents::NTuple{2, <:Integer}, + strides::NTuple{2, <:Integer}) where {T, A} = + MtlInlineTensor{T, 2, A}(data, extents[1], extents[2], strides[1], strides[2]) Base.eltype(::Type{<:MtlInlineTensor{T}}) where {T} = T Base.eltype(::MtlInlineTensor{T}) where {T} = T # Slice. Origins are zero-based to match MSL semantics. @device_function @inline function _slice_inline_tensor( - t::MtlInlineTensor{T, 2}, + t::MtlInlineTensor{T, 2, A}, o1::Integer, o2::Integer, - e1::Integer, e2::Integer) where {T} + e1::Integer, e2::Integer) where {T, A} storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() slice_private_tensor!(storage, t.storage, Int16(2), (Int32(o1), Int32(o2)), (Int32(e1), Int32(e2))) - return MtlInlineTensor{T, 2}(storage) + return MtlInlineTensor{T, 2, A}(storage) end @inline Base.view(t::MtlInlineTensor{T, 2}, origin::NTuple{2, <:Integer}, @@ -182,12 +225,16 @@ const _TENSOR_DESC_INLINE = Int32(2) # `__tensor_ops_tensor_descriptor_type::t # The 4-bit integer formats (`i4`, `ui4`) aren't exposed yet — Julia has no # native 4-bit integer type. `int32` is only valid as the destination of # an `i8`/`ui8` × `i4`/`ui4` matmul. -_tensorops_suffix(::Type{Float16}) = "f16" -_tensorops_suffix(::Type{Float32}) = "f32" +_tensorops_suffix(::Type{Float16}) = "f16" +_tensorops_suffix(::Type{Float32}) = "f32" _tensorops_suffix(::Type{Core.BFloat16}) = "b16" -_tensorops_suffix(::Type{Int8}) = "i8" -_tensorops_suffix(::Type{UInt8}) = "ui8" -_tensorops_suffix(::Type{Int32}) = "i32" +_tensorops_suffix(::Type{Int8}) = "i8" +_tensorops_suffix(::Type{UInt8}) = "ui8" +_tensorops_suffix(::Type{Int32}) = "i32" + +# Address-space prefix for the run helpers (`dv` for device, `tg` for threadgroup). +_tensorops_aspace_prefix(::Val{AS.Device}) = "dv" +_tensorops_aspace_prefix(::Val{AS.ThreadGroup}) = "tg" """ tensor_ops_matmul2d!(desc, left, right, dest, threads) @@ -195,20 +242,20 @@ _tensorops_suffix(::Type{Int32}) = "i32" `dest = left * right (+ dest if mode=multiply_accumulate)` executed cooperatively by `threads` participating threads (i.e. `simdgroup_size * num_simdgroups`). Each operand is an -[`MtlInlineTensor`](@ref). - -Supported element-type combinations follow `MPPTensorOpsMatMul2d.h`; only -`(f16, f16, f16)`, `(f16, f16, f32)`, `(f32, f32, f32)` are wired up here. +[`MtlInlineTensor`](@ref) over either device or threadgroup memory; the +right `__tensorops_impl_matmul2d_op_run_{aspace}_{type}_..._*` symbol is +picked based on the operand types and address spaces. """ @generated function tensor_ops_matmul2d!( desc::matmul2d_descriptor, - left::MtlInlineTensor{TL, 2}, - right::MtlInlineTensor{TR, 2}, - dest::MtlInlineTensor{TD, 2}, - threads::Int32) where {TL, TR, TD} - sym = "__tensorops_impl_matmul2d_op_run_dv_$(_tensorops_suffix(TL))" * - "_dv_$(_tensorops_suffix(TR))" * - "_dv_$(_tensorops_suffix(TD))" + left::MtlInlineTensor{TL, 2, AL}, + right::MtlInlineTensor{TR, 2, AR}, + dest::MtlInlineTensor{TD, 2, AD}, + threads::Int32) where {TL, TR, TD, AL, AR, AD} + sym = "__tensorops_impl_matmul2d_op_run" * + "_$(_tensorops_aspace_prefix(Val(AL)))_$(_tensorops_suffix(TL))" * + "_$(_tensorops_aspace_prefix(Val(AR)))_$(_tensorops_suffix(TR))" * + "_$(_tensorops_aspace_prefix(Val(AD)))_$(_tensorops_suffix(TD))" quote ccall($"extern $sym", llvmcall, Cvoid, (Ref{matmul2d_descriptor}, From 09a85ce94ddb830737f0f24efd62878a4e8ba53d Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 16:12:10 +0200 Subject: [PATCH 12/24] Document why static-extent inline tensors aren't exposed. Apple's static_slice<> only works on tensor_handle, not tensor_inline, and building an inline tensor with static extents emits the same AIR as one with dynamic extents (same air.init_strided_private_tensor + runtime extents tuple). Encoding static extents in MtlInlineTensor's type would only shave a few bytes off the extents alloca without enabling any optimization, so we keep extents dynamic. --- ISSUE-tensor-ops.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ISSUE-tensor-ops.md b/ISSUE-tensor-ops.md index 3153cab47..1429f8106 100644 --- a/ISSUE-tensor-ops.md +++ b/ISSUE-tensor-ops.md @@ -38,6 +38,18 @@ The GPUCompiler bits: Without the section attribute, the metallib back-end won't resolve the symbol from the MetalPerformancePrimitives runtime. +## What's intentionally not exposed + +- **`static_slice<>` / compile-time extents.** Apple's tensor API only + exposes `static_slice` on `tensor_handle` operands, not `tensor_inline`. + An inline tensor built with static extents (e.g. + `tensor, tensor_inline>`) emits + identical AIR to one built with dynamic extents — same + `air.init_strided_private_tensor` + runtime extents arrays. So encoding + static extents in the `MtlInlineTensor` type would only buy us a slightly + smaller alloca for the extents tuple; it would not enable bounds-check + elision in the matmul or in the slice path. We leave it dynamic. + ## What's not working / known limitations - **Two `__tensorops_impl_matmul2d_op_run_*` calls in one kernel crash the From 38507ab2508d37f64f348e93e90c82d7435e3488 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 16:14:45 +0200 Subject: [PATCH 13/24] Document and export multiply_accumulate K-loop pattern. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Exports the Matmul2DMode constants and adds a docstring example showing the K-loop pattern: zero C, then loop with mode = matmul2d_multiply_accumulate, slicing the K dimension. Keeps the loop trip count dynamic to avoid full unrolling into multiple tensor_ops_matmul2d! call sites — that hits Apple's back-end crash (see ISSUE-tensor-ops.md). --- src/device/intrinsics/tensor.jl | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/device/intrinsics/tensor.jl b/src/device/intrinsics/tensor.jl index 6c9d8711b..cadb60d4e 100644 --- a/src/device/intrinsics/tensor.jl +++ b/src/device/intrinsics/tensor.jl @@ -1,4 +1,5 @@ -export MtlInlineTensor, matmul2d_descriptor, tensor_ops_matmul2d! +export MtlInlineTensor, matmul2d_descriptor, tensor_ops_matmul2d!, + matmul2d_multiply, matmul2d_multiply_accumulate using Core: LLVMPtr @@ -196,6 +197,23 @@ end Configuration descriptor for a `tensor_ops::matmul2d` operation. `k` defaults to `-1` (dynamic — inferred from the input tensors at runtime). Layout matches the 20-byte `mpp::tensor_ops::matmul2d_descriptor` POD. + +For an outer `K`-loop where each iteration accumulates a partial product +into the destination, set `mode = matmul2d_multiply_accumulate` and zero +the destination before the loop. A typical pattern: + +```julia +desc = matmul2d_descriptor(M, N, TileK; mode = matmul2d_multiply_accumulate) +for s in 0:(nslices - 1) + sA = view(tA, (Int32(s) * Int32(TileK), Int32(0)), (Int32(TileK), Int32(M))) + sB = view(tB, (Int32(0), Int32(s) * Int32(TileK)), (Int32(N), Int32(TileK))) + tensor_ops_matmul2d!(desc, sA, sB, tC, threads) +end +``` + +Keep the loop's trip count dynamic — a compile-time-known trip count +that fully unrolls into multiple `tensor_ops_matmul2d!` call sites +currently crashes Apple's back-end (see `ISSUE-tensor-ops.md`). """ struct matmul2d_descriptor m::Int32 From a0fbbc32a14c1ca8a2a5a7768b53ae557bca987b Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 16:20:53 +0200 Subject: [PATCH 14/24] Expand the two-matmul crash note with the AIR-level diagnosis. Apple-compiled MSL with two matmul2d calls builds a working pipeline state, so the crash is triggered specifically by our IR pattern: the matmul2d_descriptor ends up as a series of per-field stores rather than the memcpy-from-constant-global pattern Apple's compiler emits. Local reproducer + diff is in bugs/two_matmul_crash/ (gitignored). --- ISSUE-tensor-ops.md | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ISSUE-tensor-ops.md b/ISSUE-tensor-ops.md index 1429f8106..3477a11ec 100644 --- a/ISSUE-tensor-ops.md +++ b/ISSUE-tensor-ops.md @@ -53,11 +53,18 @@ The GPUCompiler bits: ## What's not working / known limitations - **Two `__tensorops_impl_matmul2d_op_run_*` calls in one kernel crash the - Metal back-end** (`XPC_ERROR_CONNECTION_INTERRUPTED` from - `AGXMetalG15X_M1`). The attention example sidesteps this by splitting QK - and PV into two dispatches. This is likely an Apple compiler bug — the IR - we emit looks structurally identical to single-matmul kernels that compile - fine. Worth filing upstream. + Metal back-end** at pipeline-state creation + (`XPC_ERROR_CONNECTION_INTERRUPTED` from `AGXMetalG15X_M1`). MSL-compiled + metallibs of the same kernel shape build pipeline states fine, so the + crash is triggered by our specific AIR pattern: the + `matmul2d_descriptor` ends up populated as a sequence of per-field + stores (via Julia's lowering of `Ref(::matmul2d_descriptor)` and SROA), + rather than Apple's pattern of `memcpy` from a `linkonce_odr` constant + global. The likely fix is to emit the constant-global + memcpy pattern + for descriptors whose fields are compile-time constants. Local + reproducer in `bugs/two_matmul_crash/` (gitignored — see the README + there for the AIR diff and what's been tried). The attention example + sidesteps this by splitting QK and PV into two dispatches. - **No `cooperative_tensor` yet.** That means the softmax epilogue can't be done in registers — the scores tile is materialized in device memory. A proper Flash Attention would fuse the softmax into the cooperative tensor From 4cfcb94acde63b83158c4593d4f44cca73f73114 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 16:28:06 +0200 Subject: [PATCH 15/24] Add tensor_matmul!: tile-decomposed matmul over tensor_ops::matmul2d. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit C = A * B with natural Julia matrix-product semantics, dispatched as (M/tile_m, N/tile_n) threadgroups with a K-loop inside via multiply_accumulate. Each tile is one matmul2d call site, so the two-matmul-per-kernel back-end crash never fires. Tested against a CPU reference at (64,64,64), (64,128,64), (128,64,64), (128,128,64), (64,64,128) and (256,192,128) shapes — max relative error ~3e-7 (Float16 inputs, Float32 accumulator). The wrapper hides the matmul2d operand-swap from the user: matmul2d's output buffer is laid out as Julia's transpose of (M, N), so we put Julia's B in matmul2d's left slot and Julia's A in the right slot. Two swaps cancel and every operand uses packed strides. --- src/Metal.jl | 1 + src/tensor.jl | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 src/tensor.jl diff --git a/src/Metal.jl b/src/Metal.jl index 72def02f5..f6ab9af37 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -71,6 +71,7 @@ include("utilities.jl") include("broadcast.jl") include("mapreduce.jl") include("accumulate.jl") +include("tensor.jl") include("indexing.jl") include("random.jl") include("fft.jl") diff --git a/src/tensor.jl b/src/tensor.jl new file mode 100644 index 000000000..15f79a94f --- /dev/null +++ b/src/tensor.jl @@ -0,0 +1,85 @@ +# High-level tile-decomposed matmul on top of `tensor_ops::matmul2d`. + +function _tensor_matmul_kernel!(C::MtlDeviceArray, A::MtlDeviceArray, B::MtlDeviceArray, + M::UInt32, N::UInt32, K::UInt32, + tm::UInt32, tn::UInt32, tk::UInt32) + threads = Int32(threads_per_threadgroup_3d().x) + tgid = threadgroup_position_in_grid_3d() + n_tile = Int32(tgid.x) - Int32(1) + m_tile = Int32(tgid.y) - Int32(1) + n_off = n_tile * Int32(tn) + m_off = m_tile * Int32(tm) + + # In the matmul ABI, output is laid out as Julia col-major (apple_N, apple_M). + # We swap operands at the call site (apple_A buf = Julia B, apple_B buf = + # Julia A) so the natural Julia semantics `C = A * B` come out; see the + # derivation in `tensor_matmul!`. + tA = MtlInlineTensor(B, (K, M)) + tB = MtlInlineTensor(A, (N, K)) + tC = MtlInlineTensor(C, (N, M)) + + mC = view(tC, (n_off, m_off), (Int32(tn), Int32(tm))) + + desc = matmul2d_descriptor(tm, tn, tk; mode = matmul2d_multiply_accumulate) + nslices = Int32(K ÷ tk) + for s in Int32(0):(nslices - Int32(1)) + k_off = s * Int32(tk) + mA = view(tA, (k_off, m_off), (Int32(tk), Int32(tm))) + mB = view(tB, (n_off, k_off), (Int32(tn), Int32(tk))) + tensor_ops_matmul2d!(desc, mA, mB, mC, threads) + end + return +end + +""" + tensor_matmul!(C, A, B; tile_m=64, tile_n=64, tile_k=32) + +Compute `C = A * B` (Julia matrix-product semantics) using +`tensor_ops::matmul2d` with a tile decomposition. `A` is `(m, k)`, `B` is +`(k, n)`, `C` is `(m, n)`, all column-major `MtlMatrix`. `C` is zeroed +before the matmul (the kernel accumulates). + +`tile_m`, `tile_n`, `tile_k` set the per-threadgroup tile shape. The +matrix dimensions must be evenly divisible by their respective tiles +(`m % tile_m == 0`, `n % tile_n == 0`, `k % tile_k == 0`). Each +threadgroup uses 4 SIMD-groups (128 threads on the M1/M2 hardware) and +covers one `(tile_m, tile_n)` output tile; the K dimension is looped over +inside the kernel via `multiply_accumulate`. + +The implementation maps Julia's natural `C = A * B` to the matmul2d +operand convention by swapping `A`/`B` at the kernel level: matmul2d's +`apple_A` slot receives Julia's `B`, its `apple_B` slot receives Julia's +`A`. This lets every operand use packed strides without an explicit +transpose flag — matmul2d's storage order for the (M, N) output happens +to be the transpose of how Julia reads the same buffer column-major, so +two swaps cancel out and the user gets the answer they expect. + +Requires macOS 26+. +""" +function tensor_matmul!(C::MtlMatrix{TC}, A::MtlMatrix{TA}, B::MtlMatrix{TB}; + tile_m::Integer = 64, tile_n::Integer = 64, + tile_k::Integer = 32) where {TA, TB, TC} + m, k = size(A) + k2, n = size(B) + k == k2 || throw(DimensionMismatch( + "A is ($m, $k), B is ($k2, $n) — inner dims must match")) + size(C) == (m, n) || throw(DimensionMismatch( + "C is $(size(C)), expected ($m, $n)")) + + # Apple-side dims (see above for the swap derivation). + aM, aN, aK = n, m, k + aM % tile_m == 0 || throw(ArgumentError( + "tile_m=$tile_m must divide n=$n")) + aN % tile_n == 0 || throw(ArgumentError( + "tile_n=$tile_n must divide m=$m")) + aK % tile_k == 0 || throw(ArgumentError( + "tile_k=$tile_k must divide k=$k")) + + fill!(C, zero(TC)) + groups = (aN ÷ tile_n, aM ÷ tile_m, 1) + @metal threads = 4 * 32 groups = groups _tensor_matmul_kernel!( + C, A, B, + UInt32(aM), UInt32(aN), UInt32(aK), + UInt32(tile_m), UInt32(tile_n), UInt32(tile_k)) + return C +end From e0e3f4c07365da830c23207ced6036c85fdce944 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 16:32:52 +0200 Subject: [PATCH 16/24] Generalize MtlInlineTensor to arbitrary rank. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Constructors now take an NTuple{R, Integer} of extents (and optionally strides); view() takes NTuple{R} for origin and extents. Strides default to packed (computed via the prefix product of extents). The air.* intrinsics already take rank as an Int16, so the Julia wrappers just forward through. matmul2d itself is rank-2 only — but rank-3/4 tensors are useful for slicing into matmul operands and for the future convolution2d op. The per-thread descriptor buffer is bumped to 128 bytes to cover ranks up to ~8 with i32 indices, since the dynamic-alloca pattern Apple uses (`alloca i8, i64 %sz` keyed on get_descriptor_size_tensor) can't be expressed via Julia's static typing. Tested at rank 3 and rank 4: construct, slice, and read back extents via air.get_extent_private_tensor — values round-trip exactly. --- src/device/intrinsics/tensor.jl | 142 +++++++++++++++++--------------- 1 file changed, 74 insertions(+), 68 deletions(-) diff --git a/src/device/intrinsics/tensor.jl b/src/device/intrinsics/tensor.jl index cadb60d4e..19bb3cfa0 100644 --- a/src/device/intrinsics/tensor.jl +++ b/src/device/intrinsics/tensor.jl @@ -1,5 +1,5 @@ export MtlInlineTensor, matmul2d_descriptor, tensor_ops_matmul2d!, - matmul2d_multiply, matmul2d_multiply_accumulate + matmul2d_multiply, matmul2d_multiply_accumulate, tensor_matmul! using Core: LLVMPtr @@ -15,7 +15,13 @@ using Core: LLVMPtr # `Ref{NTuple{N, UInt8}}`. Julia's `llvm-alloc-opt` pass promotes the Ref to # a stack alloca once everything is inlined into the kernel. -const _TENSOR_DESCRIPTOR_SIZE = 64 +# Conservative upper bound on the size of an i32-indexed tensor descriptor. +# Apple's compiler emits a dynamic `alloca i8, i64 %sz` where `%sz` comes +# from `air.get_descriptor_size_tensor` (marked deferred-static-alloca-size +# so it's resolved at metallib build time). We use one static size that +# covers the ranks we care about (≤ 8 for i32-indexed tensors); higher +# ranks would need a bigger buffer. +const _TENSOR_DESCRIPTOR_SIZE = 128 const _TensorDescriptorStorage = Base.RefValue{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}} @@ -38,8 +44,8 @@ const _TensorDescriptorStorage = Base.RefValue{NTuple{_TENSOR_DESCRIPTOR_SIZE, U handle::_TensorDescriptorStorage, rank::Int16, data::LLVMPtr{UInt8, AS.Device}, - extents::NTuple{2, Int32}, - strides::NTuple{2, Int32}, + extents::NTuple{<:Any, Int32}, + strides::NTuple{<:Any, Int32}, contiguous::Int8, ) = ccall("extern air.init_strided_private_tensor.i32.global", llvmcall, Cvoid, @@ -51,8 +57,8 @@ const _TensorDescriptorStorage = Base.RefValue{NTuple{_TENSOR_DESCRIPTOR_SIZE, U handle::_TensorDescriptorStorage, rank::Int16, data::LLVMPtr{UInt8, AS.ThreadGroup}, - extents::NTuple{2, Int32}, - strides::NTuple{2, Int32}, + extents::NTuple{<:Any, Int32}, + strides::NTuple{<:Any, Int32}, contiguous::Int8, ) = ccall("extern air.init_strided_private_tensor.i32.local", llvmcall, Cvoid, @@ -70,8 +76,8 @@ const _TensorDescriptorStorage = Base.RefValue{NTuple{_TENSOR_DESCRIPTOR_SIZE, U dst::_TensorDescriptorStorage, src::_TensorDescriptorStorage, rank::Int16, - origin::NTuple{2, Int32}, - extents::NTuple{2, Int32}, + origin::NTuple{<:Any, Int32}, + extents::NTuple{<:Any, Int32}, ) = ccall("extern air.slice_private_tensor_private_tensor.s.i32", llvmcall, Cvoid, (Ref{UInt8}, Ref{UInt8}, Int16, Ref{Int32}, Ref{Int32}), @@ -83,105 +89,105 @@ const _TensorDescriptorStorage = Base.RefValue{NTuple{_TENSOR_DESCRIPTOR_SIZE, U """ MtlInlineTensor{T, R, ASpace} -Kernel-stack tensor view over an `MtlDeviceArray` or `MtlThreadGroupArray`, -suitable for use as an operand of [`tensor_ops_matmul2d!`](@ref). `T` is the -element type; `R` is the rank (only rank 2 is supported today); `ASpace` is -the address space of the underlying data (`AS.Device` or `AS.ThreadGroup`). -Backed by a thread-private byte buffer (an inline `Ref`) that the runtime -initializes at construction. +Kernel-stack tensor view over an `MtlDeviceArray` or `MtlThreadGroupArray`. +`T` is the element type, `R` the rank, `ASpace` the address space of the +underlying data (`AS.Device` or `AS.ThreadGroup`). Backed by a thread- +private byte buffer (an inline `Ref`) that the runtime initializes at +construction. -Note: extents follow the MSL `dextents{e1, e2, ...}` convention +Extents follow the MSL `dextents{e1, e2, ...}` convention (innermost dimension first), which is the row-major view the matmul kernel expects. For a Julia column-major `MtlMatrix(M, N)`, pass extents `(M, N)` if you want to treat columns as the inner dimension. + +`tensor_ops::matmul2d` itself is rank-2 — higher-rank tensors are useful +for slicing, multi-batch lifts, and the future `convolution2d` op. """ struct MtlInlineTensor{T, R, ASpace} storage::_TensorDescriptorStorage end -# In-kernel constructors: packed-stride rank-2 tensor over device or -# threadgroup memory. `contiguous` is `1` (packed) by default; the -# explicit-stride methods below pass `0`. -@device_function @inline function MtlInlineTensor{T, 2, AS.Device}( +# `tensor_inline` packed-stride strides: stride(0) = 1, stride(k) = prod(extents[1:k]). +@inline function _packed_strides(extents::NTuple{R, Int32}) where {R} + ntuple(Val(R)) do k + s = Int32(1) + for j in 1:(k - 1) + s *= extents[j] + end + s + end +end + +# In-kernel constructors (packed strides): +@device_function @inline function MtlInlineTensor{T, R, AS.Device}( data::MtlDeviceArray{T, <:Any, AS.Device}, - e1::Integer, e2::Integer) where {T} + extents::NTuple{R, <:Integer}) where {T, R} + e = Int32.(extents) storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() - init_strided_tensor_device!(storage, Int16(2), + init_strided_tensor_device!(storage, Int16(R), reinterpret(LLVMPtr{UInt8, AS.Device}, pointer(data)), - (Int32(e1), Int32(e2)), - (Int32(1), Int32(e1)), - Int8(1)) - return MtlInlineTensor{T, 2, AS.Device}(storage) + e, _packed_strides(e), Int8(1)) + return MtlInlineTensor{T, R, AS.Device}(storage) end -@device_function @inline function MtlInlineTensor{T, 2, AS.ThreadGroup}( +@device_function @inline function MtlInlineTensor{T, R, AS.ThreadGroup}( data::MtlDeviceArray{T, <:Any, AS.ThreadGroup}, - e1::Integer, e2::Integer) where {T} + extents::NTuple{R, <:Integer}) where {T, R} + e = Int32.(extents) storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() - init_strided_tensor_threadgroup!(storage, Int16(2), + init_strided_tensor_threadgroup!(storage, Int16(R), reinterpret(LLVMPtr{UInt8, AS.ThreadGroup}, pointer(data)), - (Int32(e1), Int32(e2)), - (Int32(1), Int32(e1)), - Int8(1)) - return MtlInlineTensor{T, 2, AS.ThreadGroup}(storage) + e, _packed_strides(e), Int8(1)) + return MtlInlineTensor{T, R, AS.ThreadGroup}(storage) end -# Explicit-stride variants (mark the tensor as non-packed for the runtime). -@device_function @inline function MtlInlineTensor{T, 2, AS.Device}( +# Explicit-stride variants: +@device_function @inline function MtlInlineTensor{T, R, AS.Device}( data::MtlDeviceArray{T, <:Any, AS.Device}, - e1::Integer, e2::Integer, - s1::Integer, s2::Integer) where {T} + extents::NTuple{R, <:Integer}, + strides::NTuple{R, <:Integer}) where {T, R} storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() - init_strided_tensor_device!(storage, Int16(2), + init_strided_tensor_device!(storage, Int16(R), reinterpret(LLVMPtr{UInt8, AS.Device}, pointer(data)), - (Int32(e1), Int32(e2)), - (Int32(s1), Int32(s2)), - Int8(0)) - return MtlInlineTensor{T, 2, AS.Device}(storage) + Int32.(extents), Int32.(strides), Int8(0)) + return MtlInlineTensor{T, R, AS.Device}(storage) end -@device_function @inline function MtlInlineTensor{T, 2, AS.ThreadGroup}( +@device_function @inline function MtlInlineTensor{T, R, AS.ThreadGroup}( data::MtlDeviceArray{T, <:Any, AS.ThreadGroup}, - e1::Integer, e2::Integer, - s1::Integer, s2::Integer) where {T} + extents::NTuple{R, <:Integer}, + strides::NTuple{R, <:Integer}) where {T, R} storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() - init_strided_tensor_threadgroup!(storage, Int16(2), + init_strided_tensor_threadgroup!(storage, Int16(R), reinterpret(LLVMPtr{UInt8, AS.ThreadGroup}, pointer(data)), - (Int32(e1), Int32(e2)), - (Int32(s1), Int32(s2)), - Int8(0)) - return MtlInlineTensor{T, 2, AS.ThreadGroup}(storage) + Int32.(extents), Int32.(strides), Int8(0)) + return MtlInlineTensor{T, R, AS.ThreadGroup}(storage) end -# Convenience: infer address space from the array. +# Convenience: infer rank and address space from the inputs. @inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, A}, - extents::NTuple{2, <:Integer}) where {T, A} = - MtlInlineTensor{T, 2, A}(data, extents[1], extents[2]) + extents::NTuple{R, <:Integer}) where {T, R, A} = + MtlInlineTensor{T, R, A}(data, extents) @inline MtlInlineTensor(data::MtlDeviceArray{T, <:Any, A}, - extents::NTuple{2, <:Integer}, - strides::NTuple{2, <:Integer}) where {T, A} = - MtlInlineTensor{T, 2, A}(data, extents[1], extents[2], strides[1], strides[2]) + extents::NTuple{R, <:Integer}, + strides::NTuple{R, <:Integer}) where {T, R, A} = + MtlInlineTensor{T, R, A}(data, extents, strides) Base.eltype(::Type{<:MtlInlineTensor{T}}) where {T} = T Base.eltype(::MtlInlineTensor{T}) where {T} = T # Slice. Origins are zero-based to match MSL semantics. -@device_function @inline function _slice_inline_tensor( - t::MtlInlineTensor{T, 2, A}, - o1::Integer, o2::Integer, - e1::Integer, e2::Integer) where {T, A} +@device_function @inline function Base.view( + t::MtlInlineTensor{T, R, A}, + origin::NTuple{R, <:Integer}, + extents::NTuple{R, <:Integer}) where {T, R, A} storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() - slice_private_tensor!(storage, t.storage, Int16(2), - (Int32(o1), Int32(o2)), - (Int32(e1), Int32(e2))) - return MtlInlineTensor{T, 2, A}(storage) + slice_private_tensor!(storage, t.storage, Int16(R), + Int32.(origin), Int32.(extents)) + return MtlInlineTensor{T, R, A}(storage) end -@inline Base.view(t::MtlInlineTensor{T, 2}, origin::NTuple{2, <:Integer}, - extents::NTuple{2, <:Integer}) where {T} = - _slice_inline_tensor(t, origin[1], origin[2], extents[1], extents[2]) - ## matmul2d descriptor (mirrors `mpp::tensor_ops::matmul2d_descriptor`). @@ -259,7 +265,7 @@ _tensorops_aspace_prefix(::Val{AS.ThreadGroup}) = "tg" `dest = left * right (+ dest if mode=multiply_accumulate)` executed cooperatively by `threads` participating threads (i.e. -`simdgroup_size * num_simdgroups`). Each operand is an +`simdgroup_size * num_simdgroups`). Each operand is a rank-2 [`MtlInlineTensor`](@ref) over either device or threadgroup memory; the right `__tensorops_impl_matmul2d_op_run_{aspace}_{type}_..._*` symbol is picked based on the operand types and address spaces. From 421d754b838d0716f43107481776c972abb709fe Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 19:18:32 +0200 Subject: [PATCH 17/24] Wrap tensor_ops::matmul2d in a TensorOpsMatmul2D{DESC, NSIMD} op type. Mirrors Apple's MSL `matmul2d>` template: both the descriptor and the simdgroup count are encoded as type parameters so the generated AIR carries them as compile-time constants. The simdgroup count specifically: the AGX register allocator runs out of stack registers ("LLVM ERROR: The shader is out of stack registers space") when the simdgroup count is a runtime value and two matmul calls live in the same kernel, which our old `tensor_ops_matmul2d!` signature forced. Removing the runtime-threads entry point closes the footgun since MSL has no equivalent. The callable needs all three of `@device_function`, `@inline`, and `@generated`: without `@device_function` the call site falls off the GPUCompiler overlay table and downstream MtlInlineTensor/view calls lose their stack-alloca lowering. `_tensor_matmul_kernel!` now takes Val{TM}, Val{TN}, Val{TK}, Val{NSIMD}; both flashattention kernels likewise lift their tile shape into Val parameters before dispatch. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/flashattention.jl | 33 +++++++++++++------- src/device/intrinsics/tensor.jl | 53 ++++++++++++++++++++++----------- src/tensor.jl | 37 ++++++++++++----------- 3 files changed, 77 insertions(+), 46 deletions(-) diff --git a/examples/flashattention.jl b/examples/flashattention.jl index a14e4240c..a7058bf89 100644 --- a/examples/flashattention.jl +++ b/examples/flashattention.jl @@ -201,15 +201,17 @@ function _fa_tensor_qk_softmax!(Q::AbstractMatrix{Float16}, K::AbstractMatrix{Float16}, S::AbstractMatrix{Float32}, P::AbstractMatrix{Float16}, - D::UInt32, N::UInt32, scale::Float32) - threads = Int32(threads_per_threadgroup_3d().x) + D::UInt32, N::UInt32, scale::Float32, + ::Val{TN}, ::Val{TD}, + ::Val{NSIMD}) where {TN, TD, NSIMD} tid = Int32(thread_position_in_threadgroup_3d().x) - Int32(1) A = MtlInlineTensor(Q, (D, N)) B = MtlInlineTensor(K, (D, N)) C = MtlInlineTensor(S, (N, N)) - desc = matmul2d_descriptor(N, N, D; transpose_right = true) - tensor_ops_matmul2d!(desc, A, B, C, threads) + op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TN, TD; transpose_right = true), + Int32(NSIMD)}() + op(A, B, C) threadgroup_barrier(Metal.MemoryFlagDevice) # Column-wise softmax. 64 of 128 threads do real work; the rest wait. @@ -239,13 +241,14 @@ end function _fa_tensor_pv!(O::AbstractMatrix{Float16}, V::AbstractMatrix{Float16}, P::AbstractMatrix{Float16}, - D::UInt32, N::UInt32) - threads = Int32(threads_per_threadgroup_3d().x) + D::UInt32, N::UInt32, + ::Val{TN}, ::Val{TD}, + ::Val{NSIMD}) where {TN, TD, NSIMD} A = MtlInlineTensor(P, (N, N)) B = MtlInlineTensor(V, (D, N)) C = MtlInlineTensor(O, (D, N)) - desc = matmul2d_descriptor(N, D, N) - tensor_ops_matmul2d!(desc, A, B, C, threads) + op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TD, TN), Int32(NSIMD)}() + op(A, B, C) return end @@ -267,7 +270,13 @@ function attention_tensor(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, P = MtlArray{Float16}(undef, N, N) simdgroup_size = 32 - threads = 4 * simdgroup_size # matmul descriptor wants execution_simdgroups<4> + nsimd = 4 # matches `execution_simdgroups<4>` in the op desc + threads = nsimd * simdgroup_size + + # The matmul descriptors carry (TN, TD) — the static tile shape per head. + TN_val = Val(Int32(N)) + TD_val = Val(Int32(D)) + NS_val = Val(Int32(nsimd)) for b in 1:B, h in 1:H Qm = view(Q, :, :, h, b) @@ -276,8 +285,10 @@ function attention_tensor(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, Om = view(O, :, :, h, b) @metal threads = threads _fa_tensor_qk_softmax!(Qm, Km, S, P, UInt32(D), UInt32(N), - Float32(scale)) - @metal threads = threads _fa_tensor_pv!(Om, Vm, P, UInt32(D), UInt32(N)) + Float32(scale), + TN_val, TD_val, NS_val) + @metal threads = threads _fa_tensor_pv!(Om, Vm, P, UInt32(D), UInt32(N), + TN_val, TD_val, NS_val) end Metal.synchronize() return O diff --git a/src/device/intrinsics/tensor.jl b/src/device/intrinsics/tensor.jl index 19bb3cfa0..c4ca960ea 100644 --- a/src/device/intrinsics/tensor.jl +++ b/src/device/intrinsics/tensor.jl @@ -1,4 +1,4 @@ -export MtlInlineTensor, matmul2d_descriptor, tensor_ops_matmul2d!, +export MtlInlineTensor, matmul2d_descriptor, TensorOpsMatmul2D, matmul2d_multiply, matmul2d_multiply_accumulate, tensor_matmul! using Core: LLVMPtr @@ -209,17 +209,19 @@ into the destination, set `mode = matmul2d_multiply_accumulate` and zero the destination before the loop. A typical pattern: ```julia -desc = matmul2d_descriptor(M, N, TileK; mode = matmul2d_multiply_accumulate) +op = TensorOpsMatmul2D{matmul2d_descriptor(M, N, TileK; + mode = matmul2d_multiply_accumulate), + Int32(NSIMD)}() for s in 0:(nslices - 1) sA = view(tA, (Int32(s) * Int32(TileK), Int32(0)), (Int32(TileK), Int32(M))) sB = view(tB, (Int32(0), Int32(s) * Int32(TileK)), (Int32(N), Int32(TileK))) - tensor_ops_matmul2d!(desc, sA, sB, tC, threads) + op(sA, sB, tC) end ``` Keep the loop's trip count dynamic — a compile-time-known trip count -that fully unrolls into multiple `tensor_ops_matmul2d!` call sites -currently crashes Apple's back-end (see `ISSUE-tensor-ops.md`). +that fully unrolls into multiple op call sites currently crashes Apple's +back-end (see `ISSUE-tensor-ops.md`). """ struct matmul2d_descriptor m::Int32 @@ -261,33 +263,47 @@ _tensorops_aspace_prefix(::Val{AS.Device}) = "dv" _tensorops_aspace_prefix(::Val{AS.ThreadGroup}) = "tg" """ - tensor_ops_matmul2d!(desc, left, right, dest, threads) - -`dest = left * right (+ dest if mode=multiply_accumulate)` executed -cooperatively by `threads` participating threads (i.e. -`simdgroup_size * num_simdgroups`). Each operand is a rank-2 -[`MtlInlineTensor`](@ref) over either device or threadgroup memory; the -right `__tensorops_impl_matmul2d_op_run_{aspace}_{type}_..._*` symbol is -picked based on the operand types and address spaces. + TensorOpsMatmul2D{DESC, NSIMD} + +Configured `tensor_ops::matmul2d` op. Mirrors Apple's MSL +`matmul2d>` template: `DESC` is the +[`matmul2d_descriptor`](@ref) value, `NSIMD` is the simdgroup count +(`execution_simdgroups`). Both are encoded as type parameters so the +generated AIR carries them as compile-time constants — `NSIMD` +specifically: the AGX register allocator runs out of stack registers +when the simdgroup count is a runtime value and two matmul calls live +in the same kernel. + +Construct with [`TensorOpsMatmul2D(desc, Val(N))`](@ref) and invoke +like a function: + +```julia +op = TensorOpsMatmul2D(matmul2d_descriptor(64, 32, -1), Val(4)) +op(left, right, dest) # run; mirrors `op.run(...)` in MSL +``` """ -@generated function tensor_ops_matmul2d!( - desc::matmul2d_descriptor, +struct TensorOpsMatmul2D{DESC, NSIMD} end + +TensorOpsMatmul2D(desc::matmul2d_descriptor, ::Val{NSIMD}) where {NSIMD} = + TensorOpsMatmul2D{desc, Int32(NSIMD)}() + +@device_function @inline @generated function (::TensorOpsMatmul2D{DESC, NSIMD})( left::MtlInlineTensor{TL, 2, AL}, right::MtlInlineTensor{TR, 2, AR}, - dest::MtlInlineTensor{TD, 2, AD}, - threads::Int32) where {TL, TR, TD, AL, AR, AD} + dest::MtlInlineTensor{TD, 2, AD}) where {DESC, NSIMD, TL, TR, TD, AL, AR, AD} sym = "__tensorops_impl_matmul2d_op_run" * "_$(_tensorops_aspace_prefix(Val(AL)))_$(_tensorops_suffix(TL))" * "_$(_tensorops_aspace_prefix(Val(AR)))_$(_tensorops_suffix(TR))" * "_$(_tensorops_aspace_prefix(Val(AD)))_$(_tensorops_suffix(TD))" quote + threads = Int32(NSIMD) * Int32(threads_per_simdgroup()) ccall($"extern $sym", llvmcall, Cvoid, (Ref{matmul2d_descriptor}, Ref{UInt8}, Int32, Ref{UInt8}, Int32, Ref{UInt8}, Int32, Int32), - desc, + $(QuoteNode(DESC)), left.storage, $_TENSOR_DESC_INLINE, right.storage, $_TENSOR_DESC_INLINE, dest.storage, $_TENSOR_DESC_INLINE, @@ -295,3 +311,4 @@ picked based on the operand types and address spaces. return nothing end end + diff --git a/src/tensor.jl b/src/tensor.jl index 15f79a94f..46e0409cd 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -1,32 +1,33 @@ # High-level tile-decomposed matmul on top of `tensor_ops::matmul2d`. +# Specialized on the tile shape (TM, TN, TK) and simdgroup count (NSIMD) so the +# matmul descriptor and execution width are compile-time constants — see +# `TensorOpsMatmul2D`. function _tensor_matmul_kernel!(C::MtlDeviceArray, A::MtlDeviceArray, B::MtlDeviceArray, M::UInt32, N::UInt32, K::UInt32, - tm::UInt32, tn::UInt32, tk::UInt32) - threads = Int32(threads_per_threadgroup_3d().x) + ::Val{TM}, ::Val{TN}, ::Val{TK}, + ::Val{NSIMD}) where {TM, TN, TK, NSIMD} tgid = threadgroup_position_in_grid_3d() n_tile = Int32(tgid.x) - Int32(1) m_tile = Int32(tgid.y) - Int32(1) - n_off = n_tile * Int32(tn) - m_off = m_tile * Int32(tm) + n_off = n_tile * Int32(TN) + m_off = m_tile * Int32(TM) - # In the matmul ABI, output is laid out as Julia col-major (apple_N, apple_M). - # We swap operands at the call site (apple_A buf = Julia B, apple_B buf = - # Julia A) so the natural Julia semantics `C = A * B` come out; see the - # derivation in `tensor_matmul!`. tA = MtlInlineTensor(B, (K, M)) tB = MtlInlineTensor(A, (N, K)) tC = MtlInlineTensor(C, (N, M)) - mC = view(tC, (n_off, m_off), (Int32(tn), Int32(tm))) + mC = view(tC, (n_off, m_off), (Int32(TN), Int32(TM))) - desc = matmul2d_descriptor(tm, tn, tk; mode = matmul2d_multiply_accumulate) - nslices = Int32(K ÷ tk) + op = TensorOpsMatmul2D{matmul2d_descriptor(TM, TN, TK; + mode = matmul2d_multiply_accumulate), + Int32(NSIMD)}() + nslices = Int32(K ÷ UInt32(TK)) for s in Int32(0):(nslices - Int32(1)) - k_off = s * Int32(tk) - mA = view(tA, (k_off, m_off), (Int32(tk), Int32(tm))) - mB = view(tB, (n_off, k_off), (Int32(tn), Int32(tk))) - tensor_ops_matmul2d!(desc, mA, mB, mC, threads) + k_off = s * Int32(TK) + mA = view(tA, (k_off, m_off), (Int32(TK), Int32(TM))) + mB = view(tB, (n_off, k_off), (Int32(TN), Int32(TK))) + op(mA, mB, mC) end return end @@ -77,9 +78,11 @@ function tensor_matmul!(C::MtlMatrix{TC}, A::MtlMatrix{TA}, B::MtlMatrix{TB}; fill!(C, zero(TC)) groups = (aN ÷ tile_n, aM ÷ tile_m, 1) - @metal threads = 4 * 32 groups = groups _tensor_matmul_kernel!( + nsimd = 4 + @metal threads = nsimd * 32 groups = groups _tensor_matmul_kernel!( C, A, B, UInt32(aM), UInt32(aN), UInt32(aK), - UInt32(tile_m), UInt32(tile_n), UInt32(tile_k)) + Val(Int32(tile_m)), Val(Int32(tile_n)), Val(Int32(tile_k)), + Val(Int32(nsimd))) return C end From 181a82f9d14ff40de1a0d60449502721f7cd8625 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 19:20:48 +0200 Subject: [PATCH 18/24] =?UTF-8?q?flashattention:=20fuse=20the=20tensor-ops?= =?UTF-8?q?=20kernels=20into=20a=20single=20QK=E2=86=92softmax=E2=86=92PV?= =?UTF-8?q?=20pass.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Now that two matmul callsites in one kernel compile cleanly (via the TensorOpsMatmul2D wrapper), the split that worked around the back-end crash is no longer needed. The scores and softmaxed P tiles move from device MtlArrays to threadgroup memory, so there's no host-side scratch allocation and no device-memory round-trip between the two matmuls. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/flashattention.jl | 121 ++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 61 deletions(-) diff --git a/examples/flashattention.jl b/examples/flashattention.jl index a7058bf89..ce40314a0 100644 --- a/examples/flashattention.jl +++ b/examples/flashattention.jl @@ -32,23 +32,27 @@ # tuned reference. Works on macOS 13+ / M1+. # # attention_tensor(Q, K, V) -# Two kernels (QKᵀ + softmax, then PV) using the Metal 4 -# `tensor_ops::matmul2d` primitives. Each kernel builds +# One fused kernel (QKᵀ → softmax → ·V) using the Metal 4 +# `tensor_ops::matmul2d` primitives. The kernel builds # `tensor_inline` views over the `MtlDeviceArray` inputs, so the # kernel signature stays buffer-shaped — no host-side `MTLTensor` # / `MTL4ComputeCommandEncoder` wrapping is needed. The matmuls # lower to externally-defined `__tensorops_impl_matmul2d_op_*` # symbols (linked from the MetalPerformancePrimitives runtime), -# not `air.*` intrinsics. Requires macOS 26+; on M3/M4 the runtime -# still lowers to the same simdgroup MMA hardware. Limited to N = -# D = 64 because the matmul descriptor is specialized to that -# single 64x64 tile. Splitting QK and PV across two dispatches — -# rather than one fused kernel — works around an Apple back-end -# crash on two `__tensorops_impl_matmul2d_op_run_*` calls in a -# single kernel; it also means the scores tile is materialized in -# device memory (`cooperative_tensor` would keep it in registers -# for true postfix-fusion, but the device-side dynamic-alloca -# support that requires isn't wired up yet). +# not `air.*` intrinsics. Scratch for the scores and softmaxed P +# lives in threadgroup memory for the lifetime of the kernel — no +# device-memory round-trip between the two matmuls. Requires +# macOS 26+; on M3/M4 the runtime still lowers to the same +# simdgroup MMA hardware. Limited to N = D = 64 because the +# matmul descriptor is specialized to that single 64×64 tile, and +# the two matmul callsites only avoid Apple's back-end +# "out of stack registers" crash when the `matmul2d` op is built +# through Metal.jl's `TensorOpsMatmul2D{DESC, NSIMD}` wrapper +# (descriptor + simdgroup count as type parameters, mirroring +# MSL's `matmul2d>`). +# `cooperative_tensor` would keep the scores tile in registers for +# true postfix-fusion, but the device-side dynamic-alloca support +# that requires isn't wired up yet. # # All implementations take Julia 4-D `(head_dim, seq, num_heads, batch)` # inputs — MPSGraph sees these reversed as `(batch, num_heads, seq, @@ -191,30 +195,40 @@ end ## Custom kernel with Metal 4 tensor ops (matmul2d, inline tensors) -# Step 1: compute Q^T K into a Float32 scores buffer, then a row-wise softmax -# (cast to Float16) into a P buffer. The matmul writes its (M, N) output in a -# layout that Julia reads as K^T Q (the transpose of Q^T K), so we apply a +# One fused kernel per (head, batch): QKᵀ → softmax → ·V, with scores and +# softmaxed P kept in threadgroup memory. The matmul writes its (M, N) output +# in a layout that Julia reads as KᵀQ (the transpose of QᵀK), so we apply a # *column*-wise softmax — that's what corresponds to row-wise softmax of the -# implicit Q^T K, and it's the right direction for column-major contiguous +# implicit QᵀK, and it's the right direction for column-major contiguous # memory access. -function _fa_tensor_qk_softmax!(Q::AbstractMatrix{Float16}, - K::AbstractMatrix{Float16}, - S::AbstractMatrix{Float32}, - P::AbstractMatrix{Float16}, - D::UInt32, N::UInt32, scale::Float32, - ::Val{TN}, ::Val{TD}, - ::Val{NSIMD}) where {TN, TD, NSIMD} +function _fa_tensor!(O::AbstractMatrix{Float16}, + Q::AbstractMatrix{Float16}, + K::AbstractMatrix{Float16}, + V::AbstractMatrix{Float16}, + D::UInt32, N::UInt32, scale::Float32, + ::Val{TN}, ::Val{TD}, + ::Val{NSIMD}) where {TN, TD, NSIMD} tid = Int32(thread_position_in_threadgroup_3d().x) - Int32(1) - A = MtlInlineTensor(Q, (D, N)) - B = MtlInlineTensor(K, (D, N)) - C = MtlInlineTensor(S, (N, N)) - op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TN, TD; transpose_right = true), - Int32(NSIMD)}() - op(A, B, C) - threadgroup_barrier(Metal.MemoryFlagDevice) + # Scratch lives in threadgroup memory for the entire kernel: scores tile + # (Float32 for accumulator precision) and the softmaxed P (Float16 for the + # second matmul). + S = MtlThreadGroupArray(Float32, (TN, TN)) + P = MtlThreadGroupArray(Float16, (TN, TN)) + + # Step 1: S = QᵀK (read as KᵀQ in Julia layout, see above). + let tA = MtlInlineTensor(Q, (D, N)), + tB = MtlInlineTensor(K, (D, N)), + tC = MtlInlineTensor(S, (N, N)) + op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TN, TD; + transpose_right = true), + Int32(NSIMD)}() + op(tA, tB, tC) + end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) - # Column-wise softmax. 64 of 128 threads do real work; the rest wait. + # Step 2: column-wise softmax. N of (NSIMD*32) threads do real work; the + # rest wait at the barrier below. @inbounds if tid < Int32(N) col = tid + Int32(1) m = -Inf32 @@ -233,22 +247,16 @@ function _fa_tensor_qk_softmax!(Q::AbstractMatrix{Float16}, P[i, col] = Float16(S[i, col] * inv_s) end end - return -end + threadgroup_barrier(Metal.MemoryFlagThreadGroup) -# Step 2: O = V · P (in Julia view; equivalent to V · P_attn^T because the -# softmax output is stored in the transposed layout). -function _fa_tensor_pv!(O::AbstractMatrix{Float16}, - V::AbstractMatrix{Float16}, - P::AbstractMatrix{Float16}, - D::UInt32, N::UInt32, - ::Val{TN}, ::Val{TD}, - ::Val{NSIMD}) where {TN, TD, NSIMD} - A = MtlInlineTensor(P, (N, N)) - B = MtlInlineTensor(V, (D, N)) - C = MtlInlineTensor(O, (D, N)) - op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TD, TN), Int32(NSIMD)}() - op(A, B, C) + # Step 3: O = V·P (Julia view; equivalent to V·Pᵀ in math notation because + # the softmax output is stored in the transposed layout). + let tA = MtlInlineTensor(P, (N, N)), + tB = MtlInlineTensor(V, (D, N)), + tC = MtlInlineTensor(O, (D, N)) + op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TD, TN), Int32(NSIMD)}() + op(tA, tB, tC) + end return end @@ -257,18 +265,11 @@ function attention_tensor(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, scale = inv(sqrt(Float32(size(Q, 1))))) @assert size(Q) == size(K) == size(V) D, N, H, B = size(Q) - # MPP requires a real tile, and the (m, n, k) descriptor below is - # specialized to (N, N, D); allowing other shapes would mean dispatching - # multiple threadgroups. + # The matmul descriptor below is specialized to (N, N, D); allowing other + # shapes would mean dispatching multiple threadgroups. @assert D == N "tensor-ops kernel currently expects D == N" O = similar(Q) - # Allocate persistent scratch for the scores / softmax outputs. One per - # (head, batch) pair would let us overlap; for clarity we reuse a single - # pair across all dispatches. - S = MtlArray{Float32}(undef, N, N) - P = MtlArray{Float16}(undef, N, N) - simdgroup_size = 32 nsimd = 4 # matches `execution_simdgroups<4>` in the op desc threads = nsimd * simdgroup_size @@ -283,12 +284,10 @@ function attention_tensor(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, Km = view(K, :, :, h, b) Vm = view(V, :, :, h, b) Om = view(O, :, :, h, b) - @metal threads = threads _fa_tensor_qk_softmax!(Qm, Km, S, P, - UInt32(D), UInt32(N), - Float32(scale), - TN_val, TD_val, NS_val) - @metal threads = threads _fa_tensor_pv!(Om, Vm, P, UInt32(D), UInt32(N), - TN_val, TD_val, NS_val) + @metal threads = threads _fa_tensor!(Om, Qm, Km, Vm, + UInt32(D), UInt32(N), + Float32(scale), + TN_val, TD_val, NS_val) end Metal.synchronize() return O From 1325b40b27f42bc6885548bb782db143d714419b Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 19:31:01 +0200 Subject: [PATCH 19/24] flashattention: dispatch (head, batch) pairs as one grid, not a host loop. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously the wrapper iterated over `(b, h)` on the host and submitted H*B separate kernel launches. Now a single dispatch with grid = (H, B) covers them all — each threadgroup reads its own `(h, b)` from `threadgroup_position_in_grid` and slices the 4-D buffers via pointer arithmetic. Heads run in parallel where the hardware can, no per-launch encoder overhead. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/flashattention.jl | 79 +++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/examples/flashattention.jl b/examples/flashattention.jl index ce40314a0..3177e986a 100644 --- a/examples/flashattention.jl +++ b/examples/flashattention.jl @@ -33,7 +33,10 @@ # # attention_tensor(Q, K, V) # One fused kernel (QKᵀ → softmax → ·V) using the Metal 4 -# `tensor_ops::matmul2d` primitives. The kernel builds +# `tensor_ops::matmul2d` primitives. Single dispatch with grid = +# (H, B), one threadgroup per (head, batch) pair, so all heads +# run in parallel — the kernel reads its own `(h, b)` from +# `threadgroup_position_in_grid`. The kernel builds # `tensor_inline` views over the `MtlDeviceArray` inputs, so the # kernel signature stays buffer-shaped — no host-side `MTLTensor` # / `MTL4ComputeCommandEncoder` wrapping is needed. The matmuls @@ -201,14 +204,27 @@ end # *column*-wise softmax — that's what corresponds to row-wise softmax of the # implicit QᵀK, and it's the right direction for column-major contiguous # memory access. -function _fa_tensor!(O::AbstractMatrix{Float16}, - Q::AbstractMatrix{Float16}, - K::AbstractMatrix{Float16}, - V::AbstractMatrix{Float16}, - D::UInt32, N::UInt32, scale::Float32, - ::Val{TN}, ::Val{TD}, - ::Val{NSIMD}) where {TN, TD, NSIMD} - tid = Int32(thread_position_in_threadgroup_3d().x) - Int32(1) +function _fa_tensor!(O::MtlDeviceArray{Float16, 4}, + Q::MtlDeviceArray{Float16, 4}, + K::MtlDeviceArray{Float16, 4}, + V::MtlDeviceArray{Float16, 4}, + scale::Float32, + ::Val{TD}, ::Val{TN}, + ::Val{NSIMD}) where {TD, TN, NSIMD} + # One threadgroup per (head, batch) pair. + tgid = threadgroup_position_in_grid_3d() + h = Int32(tgid.x) - Int32(1) + b = Int32(tgid.y) - Int32(1) + tid = Int32(thread_position_in_threadgroup_3d().x) - Int32(1) + + # Pointer arithmetic for the (h, b) slice of each 4-D buffer. + H = Int32(size(Q, 3)) + DN = Int32(TD) * Int32(TN) + slice_first = (b * H + h) * DN + Int32(1) + Qb = MtlDeviceArray{Float16, 2, Metal.AS.Device}((Int32(TD), Int32(TN)), pointer(Q, slice_first)) + Kb = MtlDeviceArray{Float16, 2, Metal.AS.Device}((Int32(TD), Int32(TN)), pointer(K, slice_first)) + Vb = MtlDeviceArray{Float16, 2, Metal.AS.Device}((Int32(TD), Int32(TN)), pointer(V, slice_first)) + Ob = MtlDeviceArray{Float16, 2, Metal.AS.Device}((Int32(TD), Int32(TN)), pointer(O, slice_first)) # Scratch lives in threadgroup memory for the entire kernel: scores tile # (Float32 for accumulator precision) and the softmaxed P (Float16 for the @@ -217,9 +233,9 @@ function _fa_tensor!(O::AbstractMatrix{Float16}, P = MtlThreadGroupArray(Float16, (TN, TN)) # Step 1: S = QᵀK (read as KᵀQ in Julia layout, see above). - let tA = MtlInlineTensor(Q, (D, N)), - tB = MtlInlineTensor(K, (D, N)), - tC = MtlInlineTensor(S, (N, N)) + let tA = MtlInlineTensor(Qb, (Int32(TD), Int32(TN))), + tB = MtlInlineTensor(Kb, (Int32(TD), Int32(TN))), + tC = MtlInlineTensor(S, (Int32(TN), Int32(TN))) op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TN, TD; transpose_right = true), Int32(NSIMD)}() @@ -227,23 +243,23 @@ function _fa_tensor!(O::AbstractMatrix{Float16}, end threadgroup_barrier(Metal.MemoryFlagThreadGroup) - # Step 2: column-wise softmax. N of (NSIMD*32) threads do real work; the + # Step 2: column-wise softmax. TN of (NSIMD*32) threads do real work; the # rest wait at the barrier below. - @inbounds if tid < Int32(N) + @inbounds if tid < Int32(TN) col = tid + Int32(1) m = -Inf32 - for i in Int32(1):Int32(N) + for i in Int32(1):Int32(TN) v = S[i, col] * scale m = v > m ? v : m end s = 0.0f0 - for i in Int32(1):Int32(N) + for i in Int32(1):Int32(TN) p = exp(S[i, col] * scale - m) S[i, col] = p s += p end inv_s = 1.0f0 / s - for i in Int32(1):Int32(N) + for i in Int32(1):Int32(TN) P[i, col] = Float16(S[i, col] * inv_s) end end @@ -251,9 +267,9 @@ function _fa_tensor!(O::AbstractMatrix{Float16}, # Step 3: O = V·P (Julia view; equivalent to V·Pᵀ in math notation because # the softmax output is stored in the transposed layout). - let tA = MtlInlineTensor(P, (N, N)), - tB = MtlInlineTensor(V, (D, N)), - tC = MtlInlineTensor(O, (D, N)) + let tA = MtlInlineTensor(P, (Int32(TN), Int32(TN))), + tB = MtlInlineTensor(Vb, (Int32(TD), Int32(TN))), + tC = MtlInlineTensor(Ob, (Int32(TD), Int32(TN))) op = TensorOpsMatmul2D{matmul2d_descriptor(TN, TD, TN), Int32(NSIMD)}() op(tA, tB, tC) end @@ -274,22 +290,13 @@ function attention_tensor(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, nsimd = 4 # matches `execution_simdgroups<4>` in the op desc threads = nsimd * simdgroup_size - # The matmul descriptors carry (TN, TD) — the static tile shape per head. - TN_val = Val(Int32(N)) - TD_val = Val(Int32(D)) - NS_val = Val(Int32(nsimd)) - - for b in 1:B, h in 1:H - Qm = view(Q, :, :, h, b) - Km = view(K, :, :, h, b) - Vm = view(V, :, :, h, b) - Om = view(O, :, :, h, b) - @metal threads = threads _fa_tensor!(Om, Qm, Km, Vm, - UInt32(D), UInt32(N), - Float32(scale), - TN_val, TD_val, NS_val) - end - Metal.synchronize() + # Single dispatch covering all (head, batch) pairs: one threadgroup each, + # grid = (H, B). The kernel uses `threadgroup_position_in_grid` to pick its + # slice. The matmul descriptors carry (TN, TD) — the static tile shape per + # head. + Metal.@sync @metal threads = threads groups = (H, B, 1) _fa_tensor!( + O, Q, K, V, Float32(scale), + Val(Int32(D)), Val(Int32(N)), Val(Int32(nsimd))) return O end From ca2737a2c4000e595458bb4f554be8055c2ee383 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 May 2026 19:56:00 +0200 Subject: [PATCH 20/24] Remove accidentally-committed files. --- ISSUE-tensor-ops.md | 152 --------------------- bin/coop_matmul.ll | 223 ------------------------------ bin/coop_matmul.metal | 32 ----- bin/inline_matmul.ll | 290 ---------------------------------------- bin/inline_matmul.metal | 32 ----- bin/simple_matmul.ll | 128 ------------------ bin/simple_matmul.metal | 22 --- 7 files changed, 879 deletions(-) delete mode 100644 ISSUE-tensor-ops.md delete mode 100644 bin/coop_matmul.ll delete mode 100644 bin/coop_matmul.metal delete mode 100644 bin/inline_matmul.ll delete mode 100644 bin/inline_matmul.metal delete mode 100644 bin/simple_matmul.ll delete mode 100644 bin/simple_matmul.metal diff --git a/ISSUE-tensor-ops.md b/ISSUE-tensor-ops.md deleted file mode 100644 index 3477a11ec..000000000 --- a/ISSUE-tensor-ops.md +++ /dev/null @@ -1,152 +0,0 @@ -# Metal 4 tensor ops (matmul2d / cooperative_tensor) — status - -## What's working - -`examples/flashattention.jl` now has an `attention_tensor(Q, K, V)` path that -dispatches the two attention matmuls via the Metal 4 `tensor_ops::matmul2d` -primitives. It matches the CPU reference at `D = N = 64`, single head, single -batch. Requires macOS 26+. - -The device-side wrappers live in `src/device/intrinsics/tensor.jl`: - -- `MtlInlineTensor{T, R}` — kernel-stack tensor view (`tensor_inline` form) - over an `MtlDeviceArray`. Built via `air.init_strided_private_tensor`. The - per-thread tensor descriptor is held by a `Ref{NTuple{64, UInt8}}` — - Julia's `llvm-alloc-opt` pass promotes it to a stack alloca because every - use is `@inline`d into the kernel and the gc-managed object only escapes - via `pointer_from_objref` (which `allocopt` treats as `addrescaped`, not - `escaped`). `GC.@preserve` around the ccalls keeps the buffer alive - across the runtime calls. -- `matmul2d_descriptor(m, n, k=-1; transpose_left, transpose_right, - relaxed_precision, mode)` — 20-byte POD matching - `mpp::tensor_ops::matmul2d_descriptor`. -- `tensor_ops_matmul2d!(desc, left, right, dest, threads)` — dispatches one - of `__tensorops_impl_matmul2d_op_run_dv_{tl}_dv_{tr}_dv_{td}` based on the - element types of the operand tensors. `threads` must equal - `simdgroup_size * num_simdgroups` for the descriptor's scope. - -The inline-tensor route lets us reuse the existing Metal.jl kernel ABI: -kernel args are still `MtlDeviceArray`s, so no host-side `MTLTensor` / -`MTL4ComputeCommandEncoder` wrapping is needed. - -The GPUCompiler bits: - -- `GPUCompiler/src/metal.jl` `isintrinsic` whitelists `__tensorops_impl_` - symbols (alongside `air.`). -- `annotate_air_intrinsics!` attaches `section "air.externally_defined"` and - `(convergent, nounwind)` attributes to `__tensorops_impl_*` declarations. - Without the section attribute, the metallib back-end won't resolve the - symbol from the MetalPerformancePrimitives runtime. - -## What's intentionally not exposed - -- **`static_slice<>` / compile-time extents.** Apple's tensor API only - exposes `static_slice` on `tensor_handle` operands, not `tensor_inline`. - An inline tensor built with static extents (e.g. - `tensor, tensor_inline>`) emits - identical AIR to one built with dynamic extents — same - `air.init_strided_private_tensor` + runtime extents arrays. So encoding - static extents in the `MtlInlineTensor` type would only buy us a slightly - smaller alloca for the extents tuple; it would not enable bounds-check - elision in the matmul or in the slice path. We leave it dynamic. - -## What's not working / known limitations - -- **Two `__tensorops_impl_matmul2d_op_run_*` calls in one kernel crash the - Metal back-end** at pipeline-state creation - (`XPC_ERROR_CONNECTION_INTERRUPTED` from `AGXMetalG15X_M1`). MSL-compiled - metallibs of the same kernel shape build pipeline states fine, so the - crash is triggered by our specific AIR pattern: the - `matmul2d_descriptor` ends up populated as a sequence of per-field - stores (via Julia's lowering of `Ref(::matmul2d_descriptor)` and SROA), - rather than Apple's pattern of `memcpy` from a `linkonce_odr` constant - global. The likely fix is to emit the constant-global + memcpy pattern - for descriptors whose fields are compile-time constants. Local - reproducer in `bugs/two_matmul_crash/` (gitignored — see the README - there for the AIR diff and what's been tried). The attention example - sidesteps this by splitting QK and PV into two dispatches. -- **No `cooperative_tensor` yet.** That means the softmax epilogue can't be - done in registers — the scores tile is materialized in device memory. A - proper Flash Attention would fuse the softmax into the cooperative tensor - between the two matmuls. -- **No `tensor_handle` kernel args.** Apple's matmul samples (and the bulk of - the MPP docs) describe tensors as host-bound `MTLTensor` parameters that - arrive in the kernel as opaque `%struct._tensor_t addrspace(1)*`. That - requires both a host-side `MTL4ArgumentTable` / `MTLTensor` wrapping and a - Metal.jl kernel-ABI rewrite. Inline tensors give us most of the - expressiveness without any of that. -- **No threadgroup-memory matmul.** Only `dv_*` (device-memory) variants of - the run helpers are wrapped. `tg_*` variants would let us stage tiles into - threadgroup memory. -- **`D == N` only.** The attention example uses one matmul descriptor sized - to a single 64×64 tile; supporting arbitrary `D, N` means dispatching - multiple threadgroups and tiling on the host. - -## Reverse-engineering reference - -Annotated AIR for the kernels we generate Apple-style equivalents for: - -- `bin/simple_matmul.metal` / `bin/simple_matmul.ll` — minimal NN matmul, - device-memory destination, `tensor_handle` parameters. -- `bin/coop_matmul.metal` / `bin/coop_matmul.ll` — cooperative-tensor - destination with a trivial scale-by-2 postfix epilogue. Closest template - for the proper Flash Attention path. -- `bin/inline_matmul.metal` / `bin/inline_matmul.ll` — the `tensor_inline` - form that Metal.jl actually uses. Matches the IR shape our wrappers emit. - -Apple's headers: - -- `/usr/metal//lib/clang//include/metal/{metal_tensor,metal_cooperative_tensor}` -- `/System/Library/Frameworks/MetalPerformancePrimitives.framework/Versions/A/Headers/{MPPTensorOpsMatMul2d.h,__impl/MPPTensorOpsMatMul2dImpl.h}` - -### AIR shapes used by our wrappers - -Inline tensor construction (`air.*` intrinsics, in `i32`-indexed flavor): - -```llvm -i16 @air.get_descriptor_size_tensor(i16 rank, i16 index_size) -void @air.init_strided_private_tensor.i32.global(i8* %handle, i16 rank, - i8 addrspace(1)* %data, - i8* %extents, i8* %strides, - i8 %contiguous) -i32 @air.get_extent_private_tensor.i32(i8* %handle, i16 rank, i16 dim) -void @air.slice_private_tensor_private_tensor.s.i32(i8* %dst, i8* %src, - i16 rank, i8* %origin, - i8* %extents) -``` - -Matmul run (externally-defined, `section "air.externally_defined"`): - -```llvm -void @__tensorops_impl_matmul2d_op_run_dv_{tl}_dv_{tr}_dv_{td}( - %"struct.matmul2d_descriptor"* %desc, - i8* %left, i32 %left_desc_type, - i8* %right, i32 %right_desc_type, - i8* %destination, i32 %dest_desc_type, - i32 %threads) -``` - -`{tl}, {tr}, {td}` are element-type suffixes (`f16`, `f32`, `bf16`, `i8`, …) -and the descriptor types are `1` for `tensor_handle`, `2` for -`tensor_inline`. - -## What's still TODO - -In rough order of value: - -1. **`MtlCooperativeTensor`** — would enable the proper Flash Attention - postfix-fusion path. Needs dynamic stack allocation (the Apple compiler - emits `alloca i8, i64 %sz` where `%sz` comes from - `__tensorops_impl_matmul2d_op_cooperative_tensor_data_size` and is marked - `"deferred-static-alloca-size"`). Workaround: reserve a conservative - upper bound at compile time. -2. **Threadgroup-memory matmul variants.** Wrap `_tg_*` flavors of the run - helpers and let `MtlInlineTensor` accept a `MtlThreadGroupArray`. -3. **Tile decomposition.** Drop the `D == N == tile` constraint by - dispatching multiple threadgroups per matmul and slicing on `tgid`. -4. **`tensor_handle` kernel args + host-side `MTLTensor` / `MTL4` wrappers.** - The biggest piece, and the closest path to what Apple's samples - demonstrate. Inline tensors get us most of the way without it, so this - is now only worth doing if we want first-class interop with Apple's - tensor APIs (e.g., to consume an `MTLTensor` produced by some other - framework). diff --git a/bin/coop_matmul.ll b/bin/coop_matmul.ll deleted file mode 100644 index 80227be6b..000000000 --- a/bin/coop_matmul.ll +++ /dev/null @@ -1,223 +0,0 @@ -; ModuleID = 'coop_matmul.metal' -source_filename = "coop_matmul.metal" -target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32" -target triple = "air64_v28-apple-macosx26.0.0" - -%"struct.mpp::tensor_ops::matmul2d_descriptor" = type { i32, i32, i32, i8, i8, i8, i32 } -%struct._tensor_t = type opaque -%"struct.metal::tensor.6" = type { %"struct.metal::__tensor_base.7", %struct._tensor_t addrspace(1)* } -%"struct.metal::__tensor_base.7" = type { %"struct.metal::__tensor_offsets.8" } -%"struct.metal::__tensor_offsets.8" = type { %"struct.metal::array" } -%"struct.metal::array" = type { [2 x i32] } -%"struct.metal::tensor.3" = type { %"struct.metal::__tensor_base.4", %struct._tensor_t addrspace(1)* } -%"struct.metal::__tensor_base.4" = type { %"struct.metal::__tensor_offsets.5" } -%"struct.metal::__tensor_offsets.5" = type { %"struct.metal::array" } - -@_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE = linkonce_odr local_unnamed_addr constant %"struct.mpp::tensor_ops::matmul2d_descriptor" { i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0 } - -; Function Attrs: convergent nounwind -define void @coop_matmul(%struct._tensor_t addrspace(1)* %0, %struct._tensor_t addrspace(1)* %1, %struct._tensor_t addrspace(1)* %2, <2 x i32> noundef %3) local_unnamed_addr #0 { - %5 = alloca %"struct.mpp::tensor_ops::matmul2d_descriptor", align 4 - %6 = alloca %"struct.metal::tensor.6", align 8 - %7 = alloca %"struct.metal::tensor.3", align 8 - %8 = alloca %"struct.metal::tensor.3", align 8 - %9 = tail call i64 @_ZN5metal18cooperative_tensorIfNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEEN3mpp10tensor_ops17__mutmul2d_detail16__operand_layoutIXtlNS4_19matmul2d_descriptorELi64ELi32ELin1EEELNS5_36__matmul2d_cooperative_operand_indexE2ENS_20execution_simdgroupsILm4EEEDhDhfiJEEEEE.MTL_SIZEAS() #7 - %10 = alloca i8, i64 %9, align 4 - %11 = bitcast %"struct.metal::tensor.3"* %7 to i8* - call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %11) #7 - %12 = extractelement <2 x i32> %3, i64 1 - %13 = shl i32 %12, 6 - %14 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 - store i32 0, i32* %14, align 8 - %15 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 - store i32 %13, i32* %15, align 4 - %16 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 1 - store %struct._tensor_t addrspace(1)* %0, %struct._tensor_t addrspace(1)** %16, align 8 - %17 = bitcast %"struct.metal::tensor.3"* %8 to i8* - call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %17) #7 - %18 = extractelement <2 x i32> %3, i64 0 - %19 = shl i32 %18, 5 - %20 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %8, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 - store i32 %19, i32* %20, align 8 - %21 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %8, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 - store i32 0, i32* %21, align 4 - %22 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %8, i64 0, i32 1 - store %struct._tensor_t addrspace(1)* %1, %struct._tensor_t addrspace(1)** %22, align 8 - call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %10) - %23 = tail call i32 @air.get_simdgroup_size.i32() #8 - %24 = shl i32 %23, 2 - call void @__tensorops_impl_matmul2d_op_cooperative_tensor_init(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %24) #9 - br label %25 - -25: ; preds = %37, %4 - %26 = phi i16 [ 0, %4 ], [ %38, %37 ] - %27 = call zeroext i16 @__tensorops_impl_matmul2d_op_cooperative_tensor_num_elements(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i32 noundef 268435472, i32 noundef 268435472, i32 noundef %24) #9 - %28 = icmp ult i16 %26, %27 - br i1 %28, label %32, label %29 - -29: ; preds = %25 - %30 = bitcast %"struct.mpp::tensor_ops::matmul2d_descriptor"* %5 to i8* - call void @llvm.lifetime.start.p0i8(i64 20, i8* nonnull %30) #7 - call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 4 dereferenceable(20) %30, i8* noundef nonnull align 4 dereferenceable(20) bitcast (%"struct.mpp::tensor_ops::matmul2d_descriptor"* @_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE to i8*), i64 20, i1 false) #7, !tbaa.struct !23 - %31 = call i8* @__tensorops_impl_matmul2d_op_cooperative_tensor_get_element_pointer(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext -1, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488) #9 - call void @__tensorops_impl_matmul2d_op_run_cooperative_dv_f16_dv_f16_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20) %5, i8* noundef nonnull %11, i32 noundef 1, i8* noundef nonnull %17, i32 noundef 1, i8* noundef %31, i32 noundef %24) #9 - call void @llvm.lifetime.end.p0i8(i64 20, i8* nonnull %30) #7 - br label %39 - -32: ; preds = %25 - %33 = call zeroext i1 @__tensorops_impl_matmul2d_op_cooperative_tensor_is_valid_element(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext %26, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %24) #9 - br i1 %33, label %34, label %37 - -34: ; preds = %32 - %35 = call i8* @__tensorops_impl_matmul2d_op_cooperative_tensor_get_element_pointer(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext %26, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488) #9 - %36 = bitcast i8* %35 to float* - store float 0.000000e+00, float* %36, align 4, !tbaa !32 - br label %37 - -37: ; preds = %32, %34 - %38 = add nuw i16 %26, 1 - br label %25, !llvm.loop !34 - -39: ; preds = %55, %29 - %40 = phi i16 [ 0, %29 ], [ %56, %55 ] - %41 = call zeroext i16 @__tensorops_impl_matmul2d_op_cooperative_tensor_num_elements(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i32 noundef 268435472, i32 noundef 268435472, i32 noundef %24) #9 - %42 = icmp ult i16 %40, %41 - br i1 %42, label %48, label %43 - -43: ; preds = %39 - %44 = bitcast %"struct.metal::tensor.6"* %6 to i8* - call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %44) - %45 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %6, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 - store i32 %19, i32* %45, align 8 - %46 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %6, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 - store i32 %13, i32* %46, align 4 - %47 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %6, i64 0, i32 1 - store %struct._tensor_t addrspace(1)* %2, %struct._tensor_t addrspace(1)** %47, align 8 - call void @__tensorops_impl_matmul2d_op_cooperative_tensor_store_dv_f32(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i8* noundef nonnull %44, i32 noundef 1, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %24) #9 - call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %44) - call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %10) #7 - call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %17) #7 - call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %11) #7 - ret void - -48: ; preds = %39 - %49 = call zeroext i1 @__tensorops_impl_matmul2d_op_cooperative_tensor_is_valid_element(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext %40, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %24) #9 - br i1 %49, label %50, label %55 - -50: ; preds = %48 - %51 = call i8* @__tensorops_impl_matmul2d_op_cooperative_tensor_get_element_pointer(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i8* noundef nonnull %10, i16 noundef zeroext %40, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488) #9 - %52 = bitcast i8* %51 to float* - %53 = load float, float* %52, align 4, !tbaa !32 - %54 = fmul fast float %53, 2.000000e+00 - store float %54, float* %52, align 4, !tbaa !32 - br label %55 - -55: ; preds = %48, %50 - %56 = add nuw i16 %40, 1 - br label %39, !llvm.loop !36 -} - -; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn -declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1 - -; Function Attrs: mustprogress nofree nosync readnone speculatable willreturn -define linkonce_odr hidden i64 @_ZN5metal18cooperative_tensorIfNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEEN3mpp10tensor_ops17__mutmul2d_detail16__operand_layoutIXtlNS4_19matmul2d_descriptorELi64ELi32ELin1EEELNS5_36__matmul2d_cooperative_operand_indexE2ENS_20execution_simdgroupsILm4EEEDhDhfiJEEEEE.MTL_SIZEAS() local_unnamed_addr #2 { - %1 = tail call i64 @_ZN3mpp10tensor_ops17__mutmul2d_detail16__operand_layoutIXtlNS0_19matmul2d_descriptorELi64ELi32ELin1EEELNS1_36__matmul2d_cooperative_operand_indexE2EN5metal20execution_simdgroupsILm4EEEDhDhfiJEE19thread_storage_sizeEv() #10 - ret i64 %1 -} - -; Function Attrs: convergent nounwind -define linkonce_odr i64 @_ZN3mpp10tensor_ops17__mutmul2d_detail16__operand_layoutIXtlNS0_19matmul2d_descriptorELi64ELi32ELin1EEELNS1_36__matmul2d_cooperative_operand_indexE2EN5metal20execution_simdgroupsILm4EEEDhDhfiJEE19thread_storage_sizeEv() local_unnamed_addr #3 align 2 { - %1 = tail call i32 @air.get_simdgroup_size.i32() #8 - %2 = shl i32 %1, 2 - %3 = tail call i64 @__tensorops_impl_matmul2d_op_cooperative_tensor_data_size(i32 noundef 2, i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0, i32 noundef 268435472, i32 noundef 268435472, i32 noundef 268435488, i32 noundef %2) #9 - ret i64 %3 -} - -; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn -declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1 - -; Function Attrs: argmemonly mustprogress nofree nounwind willreturn -declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #4 - -; Function Attrs: convergent -declare void @__tensorops_impl_matmul2d_op_cooperative_tensor_init(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" - -; Function Attrs: mustprogress nofree nosync nounwind readnone willreturn -declare i32 @air.get_simdgroup_size.i32() local_unnamed_addr #6 - -; Function Attrs: convergent -declare i64 @__tensorops_impl_matmul2d_op_cooperative_tensor_data_size(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" - -; Function Attrs: convergent -declare zeroext i16 @__tensorops_impl_matmul2d_op_cooperative_tensor_num_elements(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" - -; Function Attrs: convergent -declare zeroext i1 @__tensorops_impl_matmul2d_op_cooperative_tensor_is_valid_element(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i16 noundef zeroext, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" - -; Function Attrs: convergent -declare i8* @__tensorops_impl_matmul2d_op_cooperative_tensor_get_element_pointer(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i16 noundef zeroext, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" - -; Function Attrs: convergent -declare void @__tensorops_impl_matmul2d_op_run_cooperative_dv_f16_dv_f16_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20), i8* noundef, i32 noundef, i8* noundef, i32 noundef, i8* noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" - -; Function Attrs: convergent -declare void @__tensorops_impl_matmul2d_op_cooperative_tensor_store_dv_f32(i32 noundef, i32, i32, i32, i8, i8, i8, i32, i8* noundef, i8* noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #5 section "air.externally_defined" - -attributes #0 = { convergent nounwind "approx-func-fp-math"="true" "frame-pointer"="all" "min-legal-vector-width"="64" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } -attributes #1 = { argmemonly mustprogress nocallback nofree nosync nounwind willreturn } -attributes #2 = { mustprogress nofree nosync readnone speculatable willreturn "deferred-static-alloca-size" } -attributes #3 = { convergent nounwind "approx-func-fp-math"="true" "frame-pointer"="all" "min-legal-vector-width"="0" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } -attributes #4 = { argmemonly mustprogress nofree nounwind willreturn } -attributes #5 = { convergent "approx-func-fp-math"="true" "frame-pointer"="all" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } -attributes #6 = { mustprogress nofree nosync nounwind readnone willreturn } -attributes #7 = { nounwind } -attributes #8 = { nounwind readnone willreturn } -attributes #9 = { convergent nobuiltin nounwind "no-builtins" } -attributes #10 = { convergent nobuiltin "no-builtins" } - -!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7, !8} -!air.kernel = !{!9} -!air.compile_options = !{!16, !17, !18} -!llvm.ident = !{!19} -!air.version = !{!20} -!air.language_version = !{!21} -!air.source_file_name = !{!22} - -!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 26, i32 2]} -!1 = !{i32 1, !"wchar_size", i32 4} -!2 = !{i32 7, !"frame-pointer", i32 2} -!3 = !{i32 7, !"air.max_device_buffers", i32 31} -!4 = !{i32 7, !"air.max_constant_buffers", i32 31} -!5 = !{i32 7, !"air.max_threadgroup_buffers", i32 31} -!6 = !{i32 7, !"air.max_textures", i32 128} -!7 = !{i32 7, !"air.max_read_write_textures", i32 8} -!8 = !{i32 7, !"air.max_samplers", i32 16} -!9 = !{void (%struct._tensor_t addrspace(1)*, %struct._tensor_t addrspace(1)*, %struct._tensor_t addrspace(1)*, <2 x i32>)* @coop_matmul, !10, !11} -!10 = !{} -!11 = !{!12, !13, !14, !15} -!12 = !{i32 0, !"air.tensor", !"air.location_index", i32 0, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"A"} -!13 = !{i32 1, !"air.tensor", !"air.location_index", i32 1, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"B"} -!14 = !{i32 2, !"air.tensor", !"air.location_index", i32 2, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"C"} -!15 = !{i32 3, !"air.threadgroup_position_in_grid", !"air.arg_type_name", !"uint2", !"air.arg_name", !"tgid"} -!16 = !{!"air.compile.denorms_disable"} -!17 = !{!"air.compile.fast_math_enable"} -!18 = !{!"air.compile.framebuffer_fetch_enable"} -!19 = !{!"Apple metal version 32023.864 (metalfe-32023.864)"} -!20 = !{i32 2, i32 8, i32 0} -!21 = !{!"Metal", i32 4, i32 0, i32 0} -!22 = !{!"/private/tmp/metaltest/coop_matmul.metal"} -!23 = !{i64 0, i64 4, !24, i64 4, i64 4, !24, i64 8, i64 4, !24, i64 12, i64 1, !28, i64 13, i64 1, !28, i64 14, i64 1, !28, i64 16, i64 4, !30} -!24 = !{!25, !25, i64 0} -!25 = !{!"int", !26, i64 0} -!26 = !{!"omnipotent char", !27, i64 0} -!27 = !{!"Simple C++ TBAA"} -!28 = !{!29, !29, i64 0} -!29 = !{!"bool", !26, i64 0} -!30 = !{!31, !31, i64 0} -!31 = !{!"_ZTSN3mpp10tensor_ops19matmul2d_descriptor4modeE", !26, i64 0} -!32 = !{!33, !33, i64 0} -!33 = !{!"float", !26, i64 0} -!34 = distinct !{!34, !35} -!35 = !{!"llvm.loop.mustprogress"} -!36 = distinct !{!36, !35} diff --git a/bin/coop_matmul.metal b/bin/coop_matmul.metal deleted file mode 100644 index 26c51faa9..000000000 --- a/bin/coop_matmul.metal +++ /dev/null @@ -1,32 +0,0 @@ -#include -#include -#include -#include - -using namespace metal; -using namespace mpp::tensor_ops; - -kernel void coop_matmul(tensor> A, - tensor> B, - tensor> C, - uint2 tgid [[threadgroup_position_in_grid]]) -{ - constexpr auto desc = matmul2d_descriptor(64, 32, static_cast(dynamic_extent)); - matmul2d> op; - - auto mA = A.slice(0, tgid.y * 64); - auto mB = B.slice(tgid.x * 32, 0); - auto mC = C.slice(tgid.x * 32, tgid.y * 64); - - auto cT = op.get_destination_cooperative_tensor(); - for (uint16_t i = 0; i < cT.get_capacity(); ++i) { - if (cT.is_valid_element(i)) cT[i] = 0; - } - op.run(mA, mB, cT); - - // postfix-fuse: just scale + cast as a stand-in for softmax epilogue - for (uint16_t i = 0; i < cT.get_capacity(); ++i) { - if (cT.is_valid_element(i)) cT[i] *= 2.0f; - } - cT.store(mC); -} diff --git a/bin/inline_matmul.ll b/bin/inline_matmul.ll deleted file mode 100644 index a9ed2aef6..000000000 --- a/bin/inline_matmul.ll +++ /dev/null @@ -1,290 +0,0 @@ -; ModuleID = 'inline_matmul.metal' -source_filename = "inline_matmul.metal" -target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32" -target triple = "air64_v28-apple-macosx26.0.0" - -%"struct.mpp::tensor_ops::matmul2d_descriptor" = type { i32, i32, i32, i8, i8, i8, i32 } -%"struct.metal::array" = type { [2 x i32] } -%struct._tensor_t = type opaque - -@_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE = linkonce_odr local_unnamed_addr constant %"struct.mpp::tensor_ops::matmul2d_descriptor" { i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0 } - -; Function Attrs: convergent nounwind -define void @inline_matmul(half addrspace(1)* noundef "air-buffer-no-alias" %0, half addrspace(1)* noundef "air-buffer-no-alias" %1, float addrspace(1)* noundef "air-buffer-no-alias" %2, i32 addrspace(2)* nocapture noundef readonly align 4 dereferenceable(4) "air-buffer-no-alias" %3, i32 addrspace(2)* nocapture noundef readonly align 4 dereferenceable(4) "air-buffer-no-alias" %4, i32 addrspace(2)* nocapture noundef readonly align 4 dereferenceable(4) "air-buffer-no-alias" %5, <2 x i32> noundef %6) local_unnamed_addr #0 { - %8 = alloca %"struct.mpp::tensor_ops::matmul2d_descriptor", align 4 - %9 = alloca %"struct.metal::array", align 4 - %10 = alloca %"struct.metal::array", align 4 - %11 = alloca %"struct.metal::array", align 4 - %12 = alloca %"struct.metal::array", align 4 - %13 = alloca %"struct.metal::array", align 4 - %14 = alloca %"struct.metal::array", align 4 - %15 = alloca %"struct.metal::array", align 4 - %16 = alloca %"struct.metal::array", align 4 - %17 = alloca %"struct.metal::array", align 4 - %18 = alloca %"struct.metal::array", align 4 - %19 = alloca %"struct.metal::array", align 4 - %20 = alloca %"struct.metal::array", align 4 - %21 = tail call i64 @_ZN5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEEE.MTL_SIZEAS() #7 - %22 = alloca i8, i64 %21, align 8 - %23 = alloca i8, i64 %21, align 8 - %24 = tail call i64 @_ZN5metal6tensorIU9MTLdevicefNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEEE.MTL_SIZEAS() #7 - %25 = alloca i8, i64 %24, align 8 - %26 = alloca i8, i64 %21, align 8 - %27 = alloca i8, i64 %21, align 8 - %28 = alloca i8, i64 %24, align 8 - call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %22) - %29 = load i32, i32 addrspace(2)* %5, align 4, !tbaa !26, !alias.scope !30, !noalias !33 - %30 = load i32, i32 addrspace(2)* %3, align 4, !tbaa !26, !alias.scope !39, !noalias !40 - %31 = bitcast i8* %22 to %struct._tensor_t* - %32 = bitcast half addrspace(1)* %0 to i8 addrspace(1)* - %33 = bitcast %"struct.metal::array"* %13 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %33) #7 - %34 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %13, i64 0, i32 0, i64 0 - store i32 %29, i32* %34, align 4, !tbaa !26 - %35 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %13, i64 0, i32 0, i64 1 - store i32 %30, i32* %35, align 4, !tbaa !26 - %36 = bitcast %"struct.metal::array"* %14 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %36) #7 - %37 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %14, i64 0, i32 0, i64 0 - store i32 1, i32* %37, align 4, !tbaa !26 - %38 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %14, i64 0, i32 0, i64 1 - store i32 %29, i32* %38, align 4, !tbaa !26 - call void @air.init_strided_private_tensor.i32.global(%struct._tensor_t* nocapture nonnull writeonly %31, i16 2, i8 addrspace(1)* readnone %32, i8* nocapture nonnull readonly %33, i8* nocapture nonnull readonly %36, i8 1) #8 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %36) #7 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %33) #7 - call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %23) - %39 = load i32, i32 addrspace(2)* %4, align 4, !tbaa !26, !alias.scope !41, !noalias !42 - %40 = bitcast i8* %23 to %struct._tensor_t* - %41 = bitcast half addrspace(1)* %1 to i8 addrspace(1)* - %42 = bitcast %"struct.metal::array"* %11 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %42) #7 - %43 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %11, i64 0, i32 0, i64 0 - store i32 %39, i32* %43, align 4, !tbaa !26 - %44 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %11, i64 0, i32 0, i64 1 - store i32 %29, i32* %44, align 4, !tbaa !26 - %45 = bitcast %"struct.metal::array"* %12 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %45) #7 - %46 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %12, i64 0, i32 0, i64 0 - store i32 1, i32* %46, align 4, !tbaa !26 - %47 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %12, i64 0, i32 0, i64 1 - store i32 %39, i32* %47, align 4, !tbaa !26 - call void @air.init_strided_private_tensor.i32.global(%struct._tensor_t* nocapture nonnull writeonly %40, i16 2, i8 addrspace(1)* readnone %41, i8* nocapture nonnull readonly %42, i8* nocapture nonnull readonly %45, i8 1) #8 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %45) #7 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %42) #7 - call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %25) - %48 = bitcast i8* %25 to %struct._tensor_t* - %49 = bitcast float addrspace(1)* %2 to i8 addrspace(1)* - %50 = bitcast %"struct.metal::array"* %9 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %50) #7 - %51 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %9, i64 0, i32 0, i64 0 - store i32 %39, i32* %51, align 4, !tbaa !26 - %52 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %9, i64 0, i32 0, i64 1 - store i32 %30, i32* %52, align 4, !tbaa !26 - %53 = bitcast %"struct.metal::array"* %10 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %53) #7 - %54 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %10, i64 0, i32 0, i64 0 - store i32 1, i32* %54, align 4, !tbaa !26 - %55 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %10, i64 0, i32 0, i64 1 - store i32 %39, i32* %55, align 4, !tbaa !26 - call void @air.init_strided_private_tensor.i32.global(%struct._tensor_t* nocapture nonnull writeonly %48, i16 2, i8 addrspace(1)* readnone %49, i8* nocapture nonnull readonly %50, i8* nocapture nonnull readonly %53, i8 0) #8 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %53) #7 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %50) #7 - call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %26) - %56 = extractelement <2 x i32> %6, i64 1 - %57 = shl i32 %56, 6 - %58 = bitcast %"struct.metal::array"* %20 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %58) #7, !noalias !43 - %59 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %20, i64 0, i32 0, i64 0 - store i32 0, i32* %59, align 4, !tbaa !26, !noalias !43 - %60 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %20, i64 0, i32 0, i64 1 - store i32 %57, i32* %60, align 4, !tbaa !26, !noalias !43 - %61 = bitcast %"struct.metal::array"* %16 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %61) #7 - %62 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %31, i16 2, i16 0) #8 - %63 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %31, i16 2, i16 1) #8 - %64 = sub i32 %63, %57 - %65 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %16, i64 0, i32 0, i64 0 - store i32 %62, i32* %65, align 4 - %66 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %16, i64 0, i32 0, i64 1 - store i32 %64, i32* %66, align 4 - %67 = bitcast i8* %26 to %struct._tensor_t* - call void @air.slice_private_tensor_private_tensor.s.i32(%struct._tensor_t* nocapture nonnull writeonly %67, %struct._tensor_t* nocapture nonnull readonly %31, i16 2, i8* nocapture nonnull readonly %58, i8* nocapture nonnull readonly %61) #8 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %61) #7 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %58) #7, !noalias !43 - call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %27) - %68 = extractelement <2 x i32> %6, i64 0 - %69 = shl i32 %68, 5 - %70 = bitcast %"struct.metal::array"* %19 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %70) #7, !noalias !46 - %71 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %19, i64 0, i32 0, i64 0 - store i32 %69, i32* %71, align 4, !tbaa !26, !noalias !46 - %72 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %19, i64 0, i32 0, i64 1 - store i32 0, i32* %72, align 4, !tbaa !26, !noalias !46 - %73 = bitcast %"struct.metal::array"* %17 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %73) #7 - %74 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %40, i16 2, i16 0) #8 - %75 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %40, i16 2, i16 1) #8 - %76 = sub i32 %74, %69 - %77 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %17, i64 0, i32 0, i64 0 - store i32 %76, i32* %77, align 4 - %78 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %17, i64 0, i32 0, i64 1 - store i32 %75, i32* %78, align 4 - %79 = bitcast i8* %27 to %struct._tensor_t* - call void @air.slice_private_tensor_private_tensor.s.i32(%struct._tensor_t* nocapture nonnull writeonly %79, %struct._tensor_t* nocapture nonnull readonly %40, i16 2, i8* nocapture nonnull readonly %70, i8* nocapture nonnull readonly %73) #8 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %73) #7 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %70) #7, !noalias !46 - call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %28) - %80 = bitcast %"struct.metal::array"* %18 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %80) #7, !noalias !49 - %81 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %18, i64 0, i32 0, i64 0 - store i32 %69, i32* %81, align 4, !tbaa !26, !noalias !49 - %82 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %18, i64 0, i32 0, i64 1 - store i32 %57, i32* %82, align 4, !tbaa !26, !noalias !49 - %83 = bitcast %"struct.metal::array"* %15 to i8* - call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %83) #7 - %84 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %48, i16 2, i16 0) #8 - %85 = call i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture nonnull readonly %48, i16 2, i16 1) #8 - %86 = sub i32 %84, %69 - %87 = sub i32 %85, %57 - %88 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %15, i64 0, i32 0, i64 0 - store i32 %86, i32* %88, align 4 - %89 = getelementptr inbounds %"struct.metal::array", %"struct.metal::array"* %15, i64 0, i32 0, i64 1 - store i32 %87, i32* %89, align 4 - %90 = bitcast i8* %28 to %struct._tensor_t* - call void @air.slice_private_tensor_private_tensor.s.i32(%struct._tensor_t* nocapture nonnull writeonly %90, %struct._tensor_t* nocapture nonnull readonly %48, i16 2, i8* nocapture nonnull readonly %80, i8* nocapture nonnull readonly %83) #8 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %83) #7 - call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %80) #7, !noalias !49 - %91 = tail call i32 @air.get_simdgroup_size.i32() #9 - %92 = shl i32 %91, 2 - %93 = bitcast %"struct.mpp::tensor_ops::matmul2d_descriptor"* %8 to i8* - call void @llvm.lifetime.start.p0i8(i64 20, i8* nonnull %93) #7 - call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 4 dereferenceable(20) %93, i8* noundef nonnull align 4 dereferenceable(20) bitcast (%"struct.mpp::tensor_ops::matmul2d_descriptor"* @_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE to i8*), i64 20, i1 false) #7, !tbaa.struct !52 - call void @__tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20) %8, i8* noundef nonnull %26, i32 noundef 2, i8* noundef nonnull %27, i32 noundef 2, i8* noundef nonnull %28, i32 noundef 2, i32 noundef %92) #10 - call void @llvm.lifetime.end.p0i8(i64 20, i8* nonnull %93) #7 - call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %28) #7 - call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %27) #7 - call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %26) #7 - call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %25) #7 - call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %23) #7 - call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %22) #7 - ret void -} - -; Function Attrs: mustprogress nofree nosync readnone speculatable willreturn -define linkonce_odr hidden i64 @_ZN5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEEE.MTL_SIZEAS() local_unnamed_addr #1 { - %1 = tail call i16 @air.get_descriptor_size_tensor(i16 2, i16 4) #9 - %2 = zext i16 %1 to i64 - ret i64 %2 -} - -; Function Attrs: mustprogress nofree nosync nounwind readnone willreturn -declare i16 @air.get_descriptor_size_tensor(i16, i16) local_unnamed_addr #2 - -; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn -declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #3 - -; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn -declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #3 - -; Function Attrs: mustprogress nofree nosync readnone speculatable willreturn -define linkonce_odr hidden i64 @_ZN5metal6tensorIU9MTLdevicefNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEEE.MTL_SIZEAS() local_unnamed_addr #1 { - %1 = tail call i16 @air.get_descriptor_size_tensor(i16 2, i16 4) #9 - %2 = zext i16 %1 to i64 - ret i64 %2 -} - -; Function Attrs: argmemonly mustprogress nounwind willreturn -declare void @air.init_strided_private_tensor.i32.global(%struct._tensor_t* nocapture writeonly, i16, i8 addrspace(1)* readnone, i8* nocapture readonly, i8* nocapture readonly, i8) local_unnamed_addr #4 - -; Function Attrs: argmemonly mustprogress nounwind willreturn -declare i32 @air.get_extent_private_tensor.i32(%struct._tensor_t* nocapture readonly, i16, i16) local_unnamed_addr #4 - -; Function Attrs: argmemonly mustprogress nounwind willreturn -declare void @air.slice_private_tensor_private_tensor.s.i32(%struct._tensor_t* nocapture writeonly, %struct._tensor_t* nocapture readonly, i16, i8* nocapture readonly, i8* nocapture readonly) local_unnamed_addr #4 - -; Function Attrs: argmemonly mustprogress nofree nounwind willreturn -declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #5 - -; Function Attrs: convergent -declare void @__tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20), i8* noundef, i32 noundef, i8* noundef, i32 noundef, i8* noundef, i32 noundef, i32 noundef) local_unnamed_addr #6 section "air.externally_defined" - -; Function Attrs: mustprogress nofree nosync nounwind readnone willreturn -declare i32 @air.get_simdgroup_size.i32() local_unnamed_addr #2 - -attributes #0 = { convergent nounwind "approx-func-fp-math"="true" "frame-pointer"="all" "min-legal-vector-width"="64" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } -attributes #1 = { mustprogress nofree nosync readnone speculatable willreturn "deferred-static-alloca-size" } -attributes #2 = { mustprogress nofree nosync nounwind readnone willreturn } -attributes #3 = { argmemonly mustprogress nocallback nofree nosync nounwind willreturn } -attributes #4 = { argmemonly mustprogress nounwind willreturn } -attributes #5 = { argmemonly mustprogress nofree nounwind willreturn } -attributes #6 = { convergent "approx-func-fp-math"="true" "frame-pointer"="all" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } -attributes #7 = { nounwind } -attributes #8 = { argmemonly nounwind willreturn } -attributes #9 = { nounwind readnone willreturn } -attributes #10 = { convergent nobuiltin nounwind "no-builtins" } - -!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7, !8} -!air.kernel = !{!9} -!air.compile_options = !{!19, !20, !21} -!llvm.ident = !{!22} -!air.version = !{!23} -!air.language_version = !{!24} -!air.source_file_name = !{!25} - -!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 26, i32 2]} -!1 = !{i32 1, !"wchar_size", i32 4} -!2 = !{i32 7, !"frame-pointer", i32 2} -!3 = !{i32 7, !"air.max_device_buffers", i32 31} -!4 = !{i32 7, !"air.max_constant_buffers", i32 31} -!5 = !{i32 7, !"air.max_threadgroup_buffers", i32 31} -!6 = !{i32 7, !"air.max_textures", i32 128} -!7 = !{i32 7, !"air.max_read_write_textures", i32 8} -!8 = !{i32 7, !"air.max_samplers", i32 16} -!9 = !{void (half addrspace(1)*, half addrspace(1)*, float addrspace(1)*, i32 addrspace(2)*, i32 addrspace(2)*, i32 addrspace(2)*, <2 x i32>)* @inline_matmul, !10, !11} -!10 = !{} -!11 = !{!12, !13, !14, !15, !16, !17, !18} -!12 = !{i32 0, !"air.buffer", !"air.location_index", i32 0, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_size", i32 2, !"air.arg_type_align_size", i32 2, !"air.arg_type_name", !"half", !"air.arg_name", !"Abuf"} -!13 = !{i32 1, !"air.buffer", !"air.location_index", i32 1, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_size", i32 2, !"air.arg_type_align_size", i32 2, !"air.arg_type_name", !"half", !"air.arg_name", !"Bbuf"} -!14 = !{i32 2, !"air.buffer", !"air.location_index", i32 2, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"float", !"air.arg_name", !"Cbuf"} -!15 = !{i32 3, !"air.buffer", !"air.buffer_size", i32 4, !"air.location_index", i32 3, i32 1, !"air.read", !"air.address_space", i32 2, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"uint", !"air.arg_name", !"M"} -!16 = !{i32 4, !"air.buffer", !"air.buffer_size", i32 4, !"air.location_index", i32 4, i32 1, !"air.read", !"air.address_space", i32 2, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"uint", !"air.arg_name", !"N"} -!17 = !{i32 5, !"air.buffer", !"air.buffer_size", i32 4, !"air.location_index", i32 5, i32 1, !"air.read", !"air.address_space", i32 2, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"uint", !"air.arg_name", !"K"} -!18 = !{i32 6, !"air.threadgroup_position_in_grid", !"air.arg_type_name", !"uint2", !"air.arg_name", !"tgid"} -!19 = !{!"air.compile.denorms_disable"} -!20 = !{!"air.compile.fast_math_enable"} -!21 = !{!"air.compile.framebuffer_fetch_enable"} -!22 = !{!"Apple metal version 32023.864 (metalfe-32023.864)"} -!23 = !{i32 2, i32 8, i32 0} -!24 = !{!"Metal", i32 4, i32 0, i32 0} -!25 = !{!"/private/tmp/metaltest/inline_matmul.metal"} -!26 = !{!27, !27, i64 0} -!27 = !{!"int", !28, i64 0} -!28 = !{!"omnipotent char", !29, i64 0} -!29 = !{!"Simple C++ TBAA"} -!30 = !{!31} -!31 = distinct !{!31, !32, !"air-alias-scope-arg(5)"} -!32 = distinct !{!32, !"air-alias-scopes(inline_matmul)"} -!33 = !{!34, !35, !36, !37, !38} -!34 = distinct !{!34, !32, !"air-alias-scope-arg(0)"} -!35 = distinct !{!35, !32, !"air-alias-scope-arg(1)"} -!36 = distinct !{!36, !32, !"air-alias-scope-arg(2)"} -!37 = distinct !{!37, !32, !"air-alias-scope-arg(3)"} -!38 = distinct !{!38, !32, !"air-alias-scope-arg(4)"} -!39 = !{!37} -!40 = !{!34, !35, !36, !38, !31} -!41 = !{!38} -!42 = !{!34, !35, !36, !37, !31} -!43 = !{!44} -!44 = distinct !{!44, !45, !"_ZNK5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJijEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_: argument 0"} -!45 = distinct !{!45, !"_ZNK5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJijEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_"} -!46 = !{!47} -!47 = distinct !{!47, !48, !"_ZNK5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJjiEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_: argument 0"} -!48 = distinct !{!48, !"_ZNK5metal6tensorIU9MTLdeviceDhNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJjiEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_"} -!49 = !{!50} -!50 = distinct !{!50, !51, !"_ZNK5metal6tensorIU9MTLdevicefNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJjjEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_: argument 0"} -!51 = distinct !{!51, !"_ZNK5metal6tensorIU9MTLdevicefNS_7extentsIiJLm18446744073709551615ELm18446744073709551615EEEENS_13tensor_inlineEJEE5sliceIJjjEEENS_9enable_ifIXaafraa16is_convertible_vIT_iEeqsZT_clL_ZNS5_8get_rankEvEEES5_E4typeEDpS8_"} -!52 = !{i64 0, i64 4, !26, i64 4, i64 4, !26, i64 8, i64 4, !26, i64 12, i64 1, !53, i64 13, i64 1, !53, i64 14, i64 1, !53, i64 16, i64 4, !55} -!53 = !{!54, !54, i64 0} -!54 = !{!"bool", !28, i64 0} -!55 = !{!56, !56, i64 0} -!56 = !{!"_ZTSN3mpp10tensor_ops19matmul2d_descriptor4modeE", !28, i64 0} diff --git a/bin/inline_matmul.metal b/bin/inline_matmul.metal deleted file mode 100644 index f1dd66565..000000000 --- a/bin/inline_matmul.metal +++ /dev/null @@ -1,32 +0,0 @@ -#include -#include -#include - -using namespace metal; -using namespace mpp::tensor_ops; - -kernel void inline_matmul(device half* Abuf, - device half* Bbuf, - device float* Cbuf, - constant uint& M, - constant uint& N, - constant uint& K, - uint2 tgid [[threadgroup_position_in_grid]]) -{ - // Build tensor_inline views over raw buffers. - auto A = tensor, tensor_inline>( - Abuf, dextents{int32_t(K), int32_t(M)}); - auto B = tensor, tensor_inline>( - Bbuf, dextents{int32_t(N), int32_t(K)}); - auto C = tensor, tensor_inline>( - Cbuf, dextents{int32_t(N), int32_t(M)}); - - constexpr auto desc = matmul2d_descriptor(64, 32, static_cast(dynamic_extent)); - matmul2d> op; - - auto mA = A.slice(0, tgid.y * 64); - auto mB = B.slice(tgid.x * 32, 0); - auto mC = C.slice(tgid.x * 32, tgid.y * 64); - - op.run(mA, mB, mC); -} diff --git a/bin/simple_matmul.ll b/bin/simple_matmul.ll deleted file mode 100644 index c3a3bb3ff..000000000 --- a/bin/simple_matmul.ll +++ /dev/null @@ -1,128 +0,0 @@ -; ModuleID = 'simple_matmul.metal' -source_filename = "simple_matmul.metal" -target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32" -target triple = "air64_v28-apple-macosx26.0.0" - -%"struct.mpp::tensor_ops::matmul2d_descriptor" = type { i32, i32, i32, i8, i8, i8, i32 } -%struct._tensor_t = type opaque -%"struct.metal::tensor.3" = type { %"struct.metal::__tensor_base.4", %struct._tensor_t addrspace(1)* } -%"struct.metal::__tensor_base.4" = type { %"struct.metal::__tensor_offsets.5" } -%"struct.metal::__tensor_offsets.5" = type { %"struct.metal::array" } -%"struct.metal::array" = type { [2 x i32] } -%"struct.metal::tensor.6" = type { %"struct.metal::__tensor_base.7", %struct._tensor_t addrspace(1)* } -%"struct.metal::__tensor_base.7" = type { %"struct.metal::__tensor_offsets.8" } -%"struct.metal::__tensor_offsets.8" = type { %"struct.metal::array" } - -@_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE = linkonce_odr local_unnamed_addr constant %"struct.mpp::tensor_ops::matmul2d_descriptor" { i32 64, i32 32, i32 -1, i8 0, i8 0, i8 0, i32 0 } - -; Function Attrs: convergent nounwind -define void @simple_matmul(%struct._tensor_t addrspace(1)* %0, %struct._tensor_t addrspace(1)* %1, %struct._tensor_t addrspace(1)* %2, <2 x i32> noundef %3) local_unnamed_addr #0 { - %5 = alloca %"struct.mpp::tensor_ops::matmul2d_descriptor", align 4 - %6 = alloca %"struct.metal::tensor.3", align 8 - %7 = alloca %"struct.metal::tensor.3", align 8 - %8 = alloca %"struct.metal::tensor.6", align 8 - %9 = bitcast %"struct.metal::tensor.3"* %6 to i8* - call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %9) #5 - %10 = extractelement <2 x i32> %3, i64 1 - %11 = shl i32 %10, 6 - %12 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %6, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 - store i32 0, i32* %12, align 8 - %13 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %6, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 - store i32 %11, i32* %13, align 4 - %14 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %6, i64 0, i32 1 - store %struct._tensor_t addrspace(1)* %0, %struct._tensor_t addrspace(1)** %14, align 8 - %15 = bitcast %"struct.metal::tensor.3"* %7 to i8* - call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %15) #5 - %16 = extractelement <2 x i32> %3, i64 0 - %17 = shl i32 %16, 5 - %18 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 - store i32 %17, i32* %18, align 8 - %19 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 - store i32 0, i32* %19, align 4 - %20 = getelementptr inbounds %"struct.metal::tensor.3", %"struct.metal::tensor.3"* %7, i64 0, i32 1 - store %struct._tensor_t addrspace(1)* %1, %struct._tensor_t addrspace(1)** %20, align 8 - %21 = bitcast %"struct.metal::tensor.6"* %8 to i8* - call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %21) #5 - %22 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %8, i64 0, i32 0, i32 0, i32 0, i32 0, i64 0 - store i32 %17, i32* %22, align 8 - %23 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %8, i64 0, i32 0, i32 0, i32 0, i32 0, i64 1 - store i32 %11, i32* %23, align 4 - %24 = getelementptr inbounds %"struct.metal::tensor.6", %"struct.metal::tensor.6"* %8, i64 0, i32 1 - store %struct._tensor_t addrspace(1)* %2, %struct._tensor_t addrspace(1)** %24, align 8 - %25 = tail call i32 @air.get_simdgroup_size.i32() #6 - %26 = shl i32 %25, 2 - %27 = bitcast %"struct.mpp::tensor_ops::matmul2d_descriptor"* %5 to i8* - call void @llvm.lifetime.start.p0i8(i64 20, i8* nonnull %27) #5 - call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 4 dereferenceable(20) %27, i8* noundef nonnull align 4 dereferenceable(20) bitcast (%"struct.mpp::tensor_ops::matmul2d_descriptor"* @_ZTAXtlN3mpp10tensor_ops19matmul2d_descriptorELi64ELi32ELin1EEE to i8*), i64 20, i1 false) #5, !tbaa.struct !23 - call void @__tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20) %5, i8* noundef nonnull %9, i32 noundef 1, i8* noundef nonnull %15, i32 noundef 1, i8* noundef nonnull %21, i32 noundef 1, i32 noundef %26) #7 - call void @llvm.lifetime.end.p0i8(i64 20, i8* nonnull %27) #5 - call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %21) #5 - call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %15) #5 - call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %9) #5 - ret void -} - -; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn -declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1 - -; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn -declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1 - -; Function Attrs: argmemonly mustprogress nofree nounwind willreturn -declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #2 - -; Function Attrs: convergent -declare void @__tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f32(%"struct.mpp::tensor_ops::matmul2d_descriptor"* noundef nonnull align 4 dereferenceable(20), i8* noundef, i32 noundef, i8* noundef, i32 noundef, i8* noundef, i32 noundef, i32 noundef) local_unnamed_addr #3 section "air.externally_defined" - -; Function Attrs: mustprogress nofree nosync nounwind readnone willreturn -declare i32 @air.get_simdgroup_size.i32() local_unnamed_addr #4 - -attributes #0 = { convergent nounwind "approx-func-fp-math"="true" "frame-pointer"="all" "min-legal-vector-width"="64" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } -attributes #1 = { argmemonly mustprogress nocallback nofree nosync nounwind willreturn } -attributes #2 = { argmemonly mustprogress nofree nounwind willreturn } -attributes #3 = { convergent "approx-func-fp-math"="true" "frame-pointer"="all" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" } -attributes #4 = { mustprogress nofree nosync nounwind readnone willreturn } -attributes #5 = { nounwind } -attributes #6 = { nounwind readnone willreturn } -attributes #7 = { convergent nobuiltin nounwind "no-builtins" } - -!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7, !8} -!air.kernel = !{!9} -!air.compile_options = !{!16, !17, !18} -!llvm.ident = !{!19} -!air.version = !{!20} -!air.language_version = !{!21} -!air.source_file_name = !{!22} - -!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 26, i32 2]} -!1 = !{i32 1, !"wchar_size", i32 4} -!2 = !{i32 7, !"frame-pointer", i32 2} -!3 = !{i32 7, !"air.max_device_buffers", i32 31} -!4 = !{i32 7, !"air.max_constant_buffers", i32 31} -!5 = !{i32 7, !"air.max_threadgroup_buffers", i32 31} -!6 = !{i32 7, !"air.max_textures", i32 128} -!7 = !{i32 7, !"air.max_read_write_textures", i32 8} -!8 = !{i32 7, !"air.max_samplers", i32 16} -!9 = !{void (%struct._tensor_t addrspace(1)*, %struct._tensor_t addrspace(1)*, %struct._tensor_t addrspace(1)*, <2 x i32>)* @simple_matmul, !10, !11} -!10 = !{} -!11 = !{!12, !13, !14, !15} -!12 = !{i32 0, !"air.tensor", !"air.location_index", i32 0, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"A"} -!13 = !{i32 1, !"air.tensor", !"air.location_index", i32 1, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"B"} -!14 = !{i32 2, !"air.tensor", !"air.location_index", i32 2, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_name", !"tensor>", !"air.arg_name", !"C"} -!15 = !{i32 3, !"air.threadgroup_position_in_grid", !"air.arg_type_name", !"uint2", !"air.arg_name", !"tgid"} -!16 = !{!"air.compile.denorms_disable"} -!17 = !{!"air.compile.fast_math_enable"} -!18 = !{!"air.compile.framebuffer_fetch_enable"} -!19 = !{!"Apple metal version 32023.864 (metalfe-32023.864)"} -!20 = !{i32 2, i32 8, i32 0} -!21 = !{!"Metal", i32 4, i32 0, i32 0} -!22 = !{!"/private/tmp/metaltest/simple_matmul.metal"} -!23 = !{i64 0, i64 4, !24, i64 4, i64 4, !24, i64 8, i64 4, !24, i64 12, i64 1, !28, i64 13, i64 1, !28, i64 14, i64 1, !28, i64 16, i64 4, !30} -!24 = !{!25, !25, i64 0} -!25 = !{!"int", !26, i64 0} -!26 = !{!"omnipotent char", !27, i64 0} -!27 = !{!"Simple C++ TBAA"} -!28 = !{!29, !29, i64 0} -!29 = !{!"bool", !26, i64 0} -!30 = !{!31, !31, i64 0} -!31 = !{!"_ZTSN3mpp10tensor_ops19matmul2d_descriptor4modeE", !26, i64 0} diff --git a/bin/simple_matmul.metal b/bin/simple_matmul.metal deleted file mode 100644 index 5b9f94f33..000000000 --- a/bin/simple_matmul.metal +++ /dev/null @@ -1,22 +0,0 @@ -#include -#include -#include - -using namespace metal; -using namespace mpp::tensor_ops; - -kernel void simple_matmul(tensor> A, - tensor> B, - tensor> C, - uint2 tgid [[threadgroup_position_in_grid]]) -{ - constexpr auto desc = matmul2d_descriptor(64, 32, static_cast(dynamic_extent), - false, false, false); - matmul2d> op; - - auto mA = A.slice(0, tgid.y * 64); - auto mB = B.slice(tgid.x * 32, 0); - auto mC = C.slice(tgid.x * 32, tgid.y * 64); - - op.run(mA, mB, mC); -} From 9b9fdac7ad84fac039c5fcbafc6b6b6d571e1575 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 28 May 2026 17:58:56 +0200 Subject: [PATCH 21/24] tensor.jl: use BFloat16s.BFloat16 for Julia 1.10 compat. Core.BFloat16 only exists from Julia 1.11; Project.toml declares julia = "1.10". Route through BFloat16s.BFloat16 (already a direct dep) which aliases Core.BFloat16 on 1.11+. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/device/intrinsics/tensor.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/device/intrinsics/tensor.jl b/src/device/intrinsics/tensor.jl index c4ca960ea..77d3631b7 100644 --- a/src/device/intrinsics/tensor.jl +++ b/src/device/intrinsics/tensor.jl @@ -2,6 +2,7 @@ export MtlInlineTensor, matmul2d_descriptor, TensorOpsMatmul2D, matmul2d_multiply, matmul2d_multiply_accumulate, tensor_matmul! using Core: LLVMPtr +using BFloat16s: BFloat16 # Wrappers for Metal 4 tensor-ops / `mpp::tensor_ops` device-side APIs. # @@ -253,7 +254,7 @@ const _TENSOR_DESC_INLINE = Int32(2) # `__tensor_ops_tensor_descriptor_type::t # an `i8`/`ui8` × `i4`/`ui4` matmul. _tensorops_suffix(::Type{Float16}) = "f16" _tensorops_suffix(::Type{Float32}) = "f32" -_tensorops_suffix(::Type{Core.BFloat16}) = "b16" +_tensorops_suffix(::Type{BFloat16}) = "b16" _tensorops_suffix(::Type{Int8}) = "i8" _tensorops_suffix(::Type{UInt8}) = "ui8" _tensorops_suffix(::Type{Int32}) = "i32" From f269db8a281f1f0b3e2555c3c97db48aee04d5d6 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 28 May 2026 18:06:08 +0200 Subject: [PATCH 22/24] tensor: drop Int32 wrapping inside Val(...). MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Val is the right hammer (Julia methods specialize on types, not values, so this is still the way to lift a numeric to a kernel type parameter), but the value inside should just be the literal — kernels Int32-cast where they need it. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/flashattention.jl | 2 +- src/tensor.jl | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/flashattention.jl b/examples/flashattention.jl index 3177e986a..4a8e803a2 100644 --- a/examples/flashattention.jl +++ b/examples/flashattention.jl @@ -296,7 +296,7 @@ function attention_tensor(Q::MtlArray{Float16,4}, K::MtlArray{Float16,4}, # head. Metal.@sync @metal threads = threads groups = (H, B, 1) _fa_tensor!( O, Q, K, V, Float32(scale), - Val(Int32(D)), Val(Int32(N)), Val(Int32(nsimd))) + Val(D), Val(N), Val(nsimd)) return O end diff --git a/src/tensor.jl b/src/tensor.jl index 46e0409cd..543c86427 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -82,7 +82,6 @@ function tensor_matmul!(C::MtlMatrix{TC}, A::MtlMatrix{TA}, B::MtlMatrix{TB}; @metal threads = nsimd * 32 groups = groups _tensor_matmul_kernel!( C, A, B, UInt32(aM), UInt32(aN), UInt32(aK), - Val(Int32(tile_m)), Val(Int32(tile_n)), Val(Int32(tile_k)), - Val(Int32(nsimd))) + Val(tile_m), Val(tile_n), Val(tile_k), Val(nsimd)) return C end From ad3f0c1d47d4cbc37aa5b5f0afe033c920ead860 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 28 May 2026 18:08:33 +0200 Subject: [PATCH 23/24] tensor: make view(::MtlInlineTensor) 1-based. Origin now follows Julia indexing instead of MSL's zero-based scheme; the subtract-1 happens once inside `view` before handing off to the `air.slice_private_tensor` intrinsic. Callsites updated to match. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/device/intrinsics/tensor.jl | 9 +++++---- src/tensor.jl | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/device/intrinsics/tensor.jl b/src/device/intrinsics/tensor.jl index 77d3631b7..bbd2fcb57 100644 --- a/src/device/intrinsics/tensor.jl +++ b/src/device/intrinsics/tensor.jl @@ -178,14 +178,15 @@ end Base.eltype(::Type{<:MtlInlineTensor{T}}) where {T} = T Base.eltype(::MtlInlineTensor{T}) where {T} = T -# Slice. Origins are zero-based to match MSL semantics. +# Slice. `origin` is 1-based to match Julia indexing (the underlying +# `air.slice_private_tensor` intrinsic uses 0-based MSL coordinates). @device_function @inline function Base.view( t::MtlInlineTensor{T, R, A}, origin::NTuple{R, <:Integer}, extents::NTuple{R, <:Integer}) where {T, R, A} storage = Ref{NTuple{_TENSOR_DESCRIPTOR_SIZE, UInt8}}() slice_private_tensor!(storage, t.storage, Int16(R), - Int32.(origin), Int32.(extents)) + Int32.(origin) .- Int32(1), Int32.(extents)) return MtlInlineTensor{T, R, A}(storage) end @@ -214,8 +215,8 @@ op = TensorOpsMatmul2D{matmul2d_descriptor(M, N, TileK; mode = matmul2d_multiply_accumulate), Int32(NSIMD)}() for s in 0:(nslices - 1) - sA = view(tA, (Int32(s) * Int32(TileK), Int32(0)), (Int32(TileK), Int32(M))) - sB = view(tB, (Int32(0), Int32(s) * Int32(TileK)), (Int32(N), Int32(TileK))) + sA = view(tA, (Int32(s) * Int32(TileK) + Int32(1), Int32(1)), (Int32(TileK), Int32(M))) + sB = view(tB, (Int32(1), Int32(s) * Int32(TileK) + Int32(1)), (Int32(N), Int32(TileK))) op(sA, sB, tC) end ``` diff --git a/src/tensor.jl b/src/tensor.jl index 543c86427..2d2e68a17 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -10,8 +10,8 @@ function _tensor_matmul_kernel!(C::MtlDeviceArray, A::MtlDeviceArray, B::MtlDevi tgid = threadgroup_position_in_grid_3d() n_tile = Int32(tgid.x) - Int32(1) m_tile = Int32(tgid.y) - Int32(1) - n_off = n_tile * Int32(TN) - m_off = m_tile * Int32(TM) + n_off = n_tile * Int32(TN) + Int32(1) + m_off = m_tile * Int32(TM) + Int32(1) tA = MtlInlineTensor(B, (K, M)) tB = MtlInlineTensor(A, (N, K)) @@ -24,7 +24,7 @@ function _tensor_matmul_kernel!(C::MtlDeviceArray, A::MtlDeviceArray, B::MtlDevi Int32(NSIMD)}() nslices = Int32(K ÷ UInt32(TK)) for s in Int32(0):(nslices - Int32(1)) - k_off = s * Int32(TK) + k_off = s * Int32(TK) + Int32(1) mA = view(tA, (k_off, m_off), (Int32(TK), Int32(TM))) mB = view(tB, (n_off, k_off), (Int32(TN), Int32(TK))) op(mA, mB, mC) From cb4a938d0d7d95d953bd1881e70a4f1ba8a7f980 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 28 May 2026 19:00:35 +0200 Subject: [PATCH 24/24] compiler: whitelist __tensorops_impl_* in isintrinsic. Metal 4 tensor-ops lower to externally-defined `__tensorops_impl_*` symbols, linked from the MetalPerformancePrimitives runtime at metallib build time. GPUCompiler's upstream Metal `isintrinsic` only whitelists `air.*` and rejected them as undefined. Add a more-specific method on `MetalCompilerJob` (already strictly more specific than the upstream `CompilerJob{MetalCompilerTarget}` method thanks to our `MetalCompilerParams`) that `invoke`s the upstream check and additionally accepts the MPP prefix. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/compiler/compilation.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index bbf5150cf..21ce31907 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -10,6 +10,15 @@ GPUCompiler.method_table(::MetalCompilerJob) = method_table GPUCompiler.kernel_state_type(job::MetalCompilerJob) = KernelState +# Metal 4 tensor-ops lower to externally-defined `__tensorops_impl_*` symbols +# resolved at metallib link time against the MetalPerformancePrimitives +# runtime — not `air.*` intrinsics, so the upstream `startswith(fn, "air.")` +# check in `GPUCompiler/src/metal.jl` rejects them. +GPUCompiler.isintrinsic(job::MetalCompilerJob, fn::String) = + invoke(GPUCompiler.isintrinsic, + Tuple{CompilerJob{MetalCompilerTarget}, String}, job, fn) || + startswith(fn, "__tensorops_impl_") + function GPUCompiler.finish_module!(@nospecialize(job::MetalCompilerJob), mod::LLVM.Module, entry::LLVM.Function) entry = invoke(GPUCompiler.finish_module!,