Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3445,13 +3445,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));

#ifndef __EMSCRIPTEN__
// Only support square f16 matrices of size 8 or 16 for now
// Accept f16 subgroup matrix configurations (square or non-square).
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
// The shaders are already parameterized to handle any M/N/K dimensions.
bool valid_subgroup_matrix_config = false;
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
Expand Down Expand Up @@ -3789,6 +3791,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
break;
}
// Head dimensions must be divisible by subgroup matrix dimensions
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
Expand Down
34 changes: 17 additions & 17 deletions ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -369,35 +369,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
#endif
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
let inter_offset = kv_block * SG_MAT_N;
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
var acc: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(&inter_shmem, inter_offset, false, KV_TILE);

var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, 0u, false, HEAD_DIM_QK);

#ifdef KV_DIRECT
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + 0u, true, params.stride_k1);
#else
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
#endif

var t: u32 = 1u;
for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
let h0 = t * SG_MAT_K;
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h0, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h0, true, params.stride_k1);
#else
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
q_cur = q0;
k_cur = k0;

let h1 = (t + 1u) * SG_MAT_K;
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h1, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h1, true, params.stride_k1);
#else
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
q_cur = q1g;
Expand All @@ -407,11 +407,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
// handle odd tail
if (t < HEAD_DIM_QK / SG_MAT_K) {
let h = t * SG_MAT_K;
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h, true, params.stride_k1);
#else
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
q_cur = qn;
Expand Down Expand Up @@ -566,15 +566,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
head_dim_block < HEAD_DIM_V;
head_dim_block += num_subgroups * SG_MAT_N) {
// load O submatrix from shared memory
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(
&o_shmem,
head_dim_block,
false,
HEAD_DIM_V
);
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
let p_offset = kv_block * SG_MAT_N;
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(
&inter_shmem,
p_offset,
false,
Expand All @@ -585,15 +585,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
#ifdef KV_DIRECT
let v_block_row = kv_tile + kv_block * SG_MAT_N;
let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
&V,
v_global_offset,
false,
params.stride_v1
);
#else
let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
&kv_shmem,
v_block_offset + head_dim_block,
false,
Expand Down
Loading