Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
a8768c4838f67dfad0b3c1b2518e8521c9f6440f
6be2e8902951dabfc15a0c6f0bb872742a959aa3
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
From 44a19189352d77f3e7caa43cb327238dd6243b75 Mon Sep 17 00:00:00 2001
From 2cea04474a86c5c55996b72af05c411e9eb87dd7 Mon Sep 17 00:00:00 2001
From: Garra1980 <igor.zamyatin@intel.com>
Date: Thu, 29 Jan 2026 19:10:23 +0100
Subject: [PATCH] Add-support-for-VectorAnyINTEL-capability
Date: Thu, 12 Feb 2026 22:33:24 +0100
Subject: [PATCH] Add support for VectorAnyINTEL capability

---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 +-
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 19 ++-
mlir/include/mlir/IR/CommonTypeConstraints.td | 86 +++++++++++
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 7 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 26 +++-
.../SPIRV/Transforms/SPIRVConversion.cpp | 135 +++++++++++++++---
5 files changed, 233 insertions(+), 32 deletions(-)
5 files changed, 241 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index f8093d3042c5..1390b8e03b86 100644
index 2f189c64300a..983d9e6c4dd1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4242,7 +4242,20 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
@@ -4275,7 +4275,20 @@ def SPIRV_Float8E5M2EXT : TypeAlias<F8E5M2, "Float8E5M2">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Float8E4M3EXT, SPIRV_Float8E5M2EXT]>;
-def SPIRV_Vector : VectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16],
+// Vector type is quite restrictive in SPIR-V.
+// It only allows length of 2, 3, and 4 by default and
Expand All @@ -37,7 +37,7 @@ index f8093d3042c5..1390b8e03b86 100644
[SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
@@ -4295,7 +4308,9 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
@@ -4328,7 +4341,9 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
"Matrix">;

class SPIRV_VectorOf<Type type> :
Expand Down Expand Up @@ -146,11 +146,11 @@ index a49880b81e90..b351a23b7b5b 100644
// Negative values for `n` index in reverse.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 7c3bfd72115e..c9c8e8305062 100644
index 78f33c238d41..8602b5d81140 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -186,9 +186,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
@@ -190,9 +190,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
parser.emitError(typeLoc, "SPIR-V does not allow one-element vectors");
return Type();
}
- if (t.getNumElements() > 4) {
Expand All @@ -165,7 +165,7 @@ index 7c3bfd72115e..c9c8e8305062 100644
return Type();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 342a47cdefbf..b7ccd05fff17 100644
index 63b51d1836f7..a31d96d6bbb0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -186,9 +186,10 @@ bool CompositeType::classof(Type type) {
Expand Down

This file was deleted.

36 changes: 20 additions & 16 deletions test/Integration/Dialect/XeGPU/WG/flash_attention_fwd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,16 @@ module @flash_attention attributes {gpu.container_module} {


// Initialize m, l and acc
%m_i_row_in = arith.constant {layout_result_0 = #layout_128x1} dense<0xFF800000> : vector<128x1xf32> // -inf
%l_i_row_in = arith.constant {layout_result_0 = #layout_128x1} dense<1.0> : vector<128x1xf32> // 1.0
%m_i_row_in = arith.constant {layout_result_0 = #layout_128} dense<0xFF800000> : vector<128xf32> // -inf
%l_i_row_in = arith.constant {layout_result_0 = #layout_128} dense<1.0> : vector<128xf32> // 1.0
%zero_dpas_128x16 = arith.constant {layout_result_0 = #layout_128x16} dense<0.0> : vector<128x16xf32>
%zero_128x64 = arith.constant {layout_result_0 = #out} dense<0.0> : vector<128x64xf32>
%zero_128 = arith.constant {layout_result_0 = #layout_128} dense<0.000000e+00> : vector<128xf32>
%minus_inf_128 = arith.constant {layout_result_0 = #layout_128} dense<0xFF800000> : vector<128xf32> // -inf

// Softmax scaling
// FIXME: the value 0.5 is hard-coded; it should be taken from %sm_scale instead.
%qk_scale_128 = arith.constant {layout_result_0 = #layout_128} dense<0.5> : vector<128xf32>
%qk_scale_128x1 = arith.constant {layout_result_0 = #layout_128x1} dense<0.5> : vector<128x1xf32>
%qk_scale_128x16 = arith.constant {layout_result_0 = #layout_128x16} dense<0.5> : vector<128x16xf32>

Expand All @@ -109,7 +110,7 @@ module @flash_attention attributes {gpu.container_module} {
%l_i_row = %l_i_row_in
)
-> (
vector<128x64xf32>, vector<128x1xf32>, vector<128x1xf32>
vector<128x64xf32>, vector<128xf32>, vector<128xf32>
) {
gpu.barrier

Expand Down Expand Up @@ -161,19 +162,20 @@ module @flash_attention attributes {gpu.container_module} {
%qk_out_max_t3 = vector.multi_reduction <maximumf>, %qk_out_max_t2, %minus_inf_128
{layout_result_0 = #xegpu.slice<#layout_128x16, dims = [1]>}
[1] : vector<128x16xf32> to vector<128xf32>
%qk_out_max = vector.shape_cast %qk_out_max_t3 {layout_result_0 = #layout_128x1} : vector<128xf32> to vector<128x1xf32>
// %qk_out_max = vector.shape_cast %qk_out_max_t3 {layout_result_0 = #layout_128x1} : vector<128xf32> to vector<128x1xf32>

// Scale
%qk_out_max_scaled = arith.mulf %qk_out_max, %qk_scale_128x1 {layout_result_0 = #layout_128x1} : vector<128x1xf32>
%qk_out_max_scaled = arith.mulf %qk_out_max_t3, %qk_scale_128 {layout_result_0 = #layout_128} : vector<128xf32>
// Find m_ij_row
%m_ij_row = arith.maximumf %qk_out_max_scaled, %m_i_row fastmath<fast> {layout_result_0 = #layout_128x1} : vector<128x1xf32>
%m_ij_row = arith.maximumf %qk_out_max_scaled, %m_i_row fastmath<fast> {layout_result_0 = #layout_128} : vector<128xf32>
// Scale qk_out by qk_scale
%qk_out_0_scaled = arith.mulf %qk_out_0, %qk_scale_128x16 {layout_result_0 = #layout_128x16} : vector<128x16xf32>
%qk_out_1_scaled = arith.mulf %qk_out_1, %qk_scale_128x16 {layout_result_0 = #layout_128x16} : vector<128x16xf32>
%qk_out_2_scaled = arith.mulf %qk_out_2, %qk_scale_128x16 {layout_result_0 = #layout_128x16} : vector<128x16xf32>
%qk_out_3_scaled = arith.mulf %qk_out_3, %qk_scale_128x16 {layout_result_0 = #layout_128x16} : vector<128x16xf32>
// Broadcast m_ij_row to 128x16
%m_ij_row_broadcasted = vector.broadcast %m_ij_row {layout_result_0 = #layout_128x16} : vector<128x1xf32> to vector<128x16xf32>
%m_ij_row_broadcasted0 = vector.shape_cast %m_ij_row {layout_result_0 = #layout_128x1, layout_operand_0 = #xegpu.slice<#layout_128x1, dims=[1]>} : vector<128xf32> to vector<128x1xf32>
%m_ij_row_broadcasted = vector.broadcast %m_ij_row_broadcasted0 {layout_result_0 = #layout_128x16} : vector<128x1xf32> to vector<128x16xf32>
// Center qk_out by m_ij_row
%qk_out_0_centered = arith.subf %qk_out_0_scaled, %m_ij_row_broadcasted {layout_result_0 = #layout_128x16} : vector<128x16xf32>
%qk_out_1_centered = arith.subf %qk_out_1_scaled, %m_ij_row_broadcasted {layout_result_0 = #layout_128x16} : vector<128x16xf32>
Expand All @@ -191,15 +193,16 @@ module @flash_attention attributes {gpu.container_module} {
%l_ij_row_t3 = vector.multi_reduction <add>, %l_ij_row_t2, %zero_128
{layout_result_0 = #xegpu.slice<#layout_128x16, dims = [1]>}
[1] : vector<128x16xf32> to vector<128xf32>
%l_ij_row = vector.shape_cast %l_ij_row_t3 {layout_result_0 = #layout_128x1} : vector<128xf32> to vector<128x1xf32>
// %l_ij_row = vector.shape_cast %l_ij_row_t3 {layout_result_0 = #layout_128x1} : vector<128xf32> to vector<128x1xf32>
// Compute alpha
%alpha_row_t1 = arith.subf %m_i_row, %m_ij_row {layout_result_0 = #layout_128x1} : vector<128x1xf32>
%alpha_row = math.exp %alpha_row_t1 fastmath<fast> {layout_result_0 = #layout_128x1} : vector<128x1xf32>
%alpha_row_t1 = arith.subf %m_i_row, %m_ij_row {layout_result_0 = #layout_128} : vector<128xf32>
%alpha_row = math.exp %alpha_row_t1 fastmath<fast> {layout_result_0 = #layout_128} : vector<128xf32>
// Update l_i
%l_i_row_new_t1 = arith.mulf %l_i_row, %alpha_row {layout_result_0 = #layout_128x1} : vector<128x1xf32>
%l_i_row_new = arith.addf %l_i_row_new_t1, %l_ij_row {layout_result_0 = #layout_128x1} : vector<128x1xf32>
%l_i_row_new_t1 = arith.mulf %l_i_row, %alpha_row {layout_result_0 = #layout_128} : vector<128xf32>
%l_i_row_new = arith.addf %l_i_row_new_t1, %l_ij_row_t3 {layout_result_0 = #layout_128} : vector<128xf32>
// Update acc
%alpha_row_broadcasted = vector.broadcast %alpha_row {layout_result_0 = #out} : vector<128x1xf32> to vector<128x64xf32>
%alpha_row_broadcasted0 = vector.shape_cast %alpha_row {layout_result_0 = #layout_128x1, layout_operand_0 = #xegpu.slice<#layout_128x1, dims=[1]>} : vector<128xf32> to vector<128x1xf32>
%alpha_row_broadcasted = vector.broadcast %alpha_row_broadcasted0 {layout_result_0 = #out} : vector<128x1xf32> to vector<128x64xf32>
%acc_in_updated = arith.mulf %acc_in, %alpha_row_broadcasted {layout_result_0 = #out} : vector<128x64xf32>

// Convert qk_out_tile to DPAS-A precision for P*V computation.
Expand Down Expand Up @@ -228,10 +231,11 @@ module @flash_attention attributes {gpu.container_module} {
// Compute fourth iteration update of 128x64 of P * V
%pv_out_iter3 = xegpu.dpas %qk_out_3_f16, %v_val_slice_3, %pv_out_iter2 {layout_a = #q, layout_b = #v, layout_cd = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32>

scf.yield %pv_out_iter3, %m_ij_row, %l_i_row_new : vector<128x64xf32>, vector<128x1xf32>, vector<128x1xf32>
} {layout_result_0 = #out, layout_result_1 = #layout_128x1, layout_result_2 = #layout_128x1}// end of inner loop
scf.yield %pv_out_iter3, %m_ij_row, %l_i_row_new : vector<128x64xf32>, vector<128xf32>, vector<128xf32>
} {layout_result_0 = #out, layout_result_1 = #layout_128, layout_result_2 = #layout_128}// end of inner loop
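// Divide acc output by l_i
// NOTE(review): the shape_cast below slices #layout_128x1 with dims=[0], whereas the
// other vector<128xf32> -> vector<128x1xf32> shape_casts in this loop use dims=[1];
// confirm that dims=[0] is intended here.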
%l_i_row_broadcast = vector.broadcast %result#2 {layout_result_0 = #out} : vector<128x1xf32> to vector<128x64xf32>
%l_i_row_broadcast0 = vector.shape_cast %result#2 {layout_result_0 = #layout_128x1, layout_operand_0 = #xegpu.slice<#layout_128x1, dims=[0]>} : vector<128xf32> to vector<128x1xf32>
%l_i_row_broadcast = vector.broadcast %l_i_row_broadcast0 {layout_result_0 = #out} : vector<128x1xf32> to vector<128x64xf32>
%o_val_final_t = arith.divf %result#0, %l_i_row_broadcast {layout_result_0 = #out} : vector<128x64xf32>
// Store output tile.
%o_val_final = arith.truncf %o_val_final_t {layout_result_0 = #out} : vector<128x64xf32> to vector<128x64xf16>
Expand Down
Loading