From 9b92f11f5aea7517093f748903c811564125b81b Mon Sep 17 00:00:00 2001 From: jung-min Date: Mon, 2 Mar 2026 07:58:59 +0000 Subject: [PATCH 1/4] [Frontend/template] add SDPA modules --- .../torch_openreg/openreg/__init__.py | 7 +- PyTorchSimFrontend/mlir/mlir_lowering.py | 25 +- PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 664 ++++++++++++++++++ PyTorchSimFrontend/mlir/mlir_template.py | 101 ++- tests/test_sdpa.py | 84 +++ 5 files changed, 878 insertions(+), 3 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/mlir_sdpa_template.py create mode 100644 tests/test_sdpa.py diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index 8d62cee3..5a0de6c3 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -24,7 +24,7 @@ class device: def __init__(self, device): self.idx = torch.accelerator._get_device_index(device, optional=True) - self.prev_idx = -1 + self.prev_idx = -1 def __enter__(self): self.prev_idx = torch_openreg._C._exchangeDevice(self.idx) @@ -64,6 +64,11 @@ def _lazy_init(): global _initialized, _tog_simulator if is_initialized(): return + + # Replace the global C++ binding with our custom dispatcher patch + from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention + torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention + torch_openreg._C._init() register_interface_for_device(custom_device(), ExtensionDeviceInterface) _initialized = True diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index ebf0c80e..e09dcf57 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -15,6 +15,7 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from 
PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.mlir.mlir_sdpa_template import MLIRFlashSDPATemplate, flash_sdpa_args from PyTorchSimFrontend import extension_config aten = torch.ops.aten @@ -38,6 +39,26 @@ def tuned_bmm(mat1, mat2, *, layout=None): return mlir_template.generate().output_node() + +def tuned_flash_sdpa( + query : TensorBox, + key : TensorBox, + value : TensorBox, + scale : float, + dropout_p : float = 0.0, + is_causal : bool = False, + return_debug_mask : bool =False) -> tuple: + + print("Enter tuned_flash_sdpa") + + N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value) + mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale) + + # _scaled_dot_product_flash_attention has to return a tuple which has 9 values + # since its backward(_scaled_dot_product_flash_attention_backward) needs that values. + # (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + return (mlir_template.generate().output_node(), None, None, None, None, None, None, None, None) + def conv_layout( x: TensorBox, weight: TensorBox, @@ -188,4 +209,6 @@ def custom_unsafe_index(x, indices): lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()}) if extension_config.CONFIG_USE_TIMING_POOLING: - lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template \ No newline at end of file + lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template 
+ +lowerings.update({getattr(aten._scaled_dot_product_flash_attention, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_flash_attention.overloads()}) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py new file mode 100644 index 00000000..b3d88cc6 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -0,0 +1,664 @@ +import math # sqrt +import sympy + +from typing import List, Optional + +import torch +from torch import empty_strided +from torch._inductor.ir import IRNode, TensorBox, FixedLayout +from torch._inductor.virtualized import V +from torch._inductor.select_algorithm import realize_inputs +from torch.backends.cuda import flash_sdp_enabled, mem_efficient_sdp_enabled + +from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel + + +def flash_sdpa_args( + query : TensorBox, + key : TensorBox, + value : TensorBox) -> list: + """ + Arg processing for flash SDPA. + Its logic is based on: + mm_args() which is in torch._inductor.kernel.mm_common.py (142 line). + """ + + # Materialize input buffers for the codegen backend. + query, key, value = realize_inputs(query, key, value) + + # query : (n, hq, l, e) + # key : (n, h, s, e) + # value : (n, h, s, ev) + # out : (n, hq, l, ev) + # n: Batch size + # hq: query's head counts, h: key and value's head counts. + # l: target sequence length and s: source sequence length. + # e: embedding dimension of the query and key and ev: embedding dimension of the value.
+ nq, hq, l, eq = query.get_size() + nk, hk, sk, ek = key.get_size() + nk, hv, sv, ev = value.get_size() + + n = V.graph.sizevars.guard_equals(nq, nk) + n = V.graph.sizevars.guard_equals(nq, nk) + + h = V.graph.sizevars.guard_equals(hk, hv) + s = V.graph.sizevars.guard_equals(sk, sv) + e = V.graph.sizevars.guard_equals(eq, ek) + + # While there are no theoretical requirements for e == ev, + # this implementation enforces e == ev for simplicity. + # Distinct notations are still maintained to ensure future compatibility and clarity. + if e != ev: + raise NotImplementedError("Flash SDPA does not support mismatched head dimensions between query and value.") + + # Flash attention does not split tiles along the head dimension (e or ev). + # Therefore, the head dimension size must be less than or equal to the number of vlanes. + vector_lane = extension_config.vpu_num_lanes + if e > vector_lane or ev > vector_lane: + raise ValueError(f"The head dimension size must be less than or equal to the number of vlanes (e: {e}, ev: {ev}, vlanes: {vector_lane}).") + + # The aten._scaled_dot_product_flash_attention kernel does not accept an explicit enable_gqa parameter. + # Instead, the Flash SDPA implementation infers GQA usage by checking if hq != hk. + # The Flash SDPA for GQA will be implemented after implementing its native version. + if hq != h : + raise NotImplementedError("Flash SDPA for GQA is not supported yet.") + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [n, hq, l, ev] + ) + + return [n, hq, h, l, s, e, ev, layout, query, key, value] + +def validate_sdpa_input( + query : torch.Tensor, + key : torch.Tensor, + value : torch.Tensor, + attn_mask : torch.Tensor = None, + dropout_p : float = 0.0, + is_casual : bool = False, + scale : float = None, + enable_gqa : bool = False) -> None: + """ + Validates input tensors and parameters for Scaled Dot Product Attention (SDPA). 
+ This function's logic can be found in: + https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp(504 line) + https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + """ + + # Tensor class, dtype, and device consistency + # Ensure all primary inputs are torch.Tensors + if not all(isinstance(t, torch.Tensor) for t in [query, key, value]): + raise TypeError( + f"Expected query, key and value to be Tensors, but got " + f"{type(query).__name__}, {type(key).__name__}, and {type(value).__name__}." + ) + + # Check for dtype mismatch + if query.dtype != key.dtype or query.dtype != value.dtype: + raise TypeError( + f"Expected query, key, and value to have the same dtype, " + f"but got {query.dtype}, {key.dtype}, and {value.dtype}." + ) + + # Check for device mismatch (e.g., mixing CPU and NPU) + if query.device != key.device or query.device != value.device: + raise ValueError( + f"Expected query, key, and value to be on the same device, " + f"but got {query.device}, {key.device}, and {value.device}." + ) + + # Shape and dimension validation + # SDPA typically expects 4D (B, H, S, D), but we check for at least 2D here + if any(t.dim() < 2 for t in [query, key, value]): + raise ValueError( + f"Expected query, key, and value to be at least 2D, " + f"but got Q:{query.dim()}D, K:{key.dim()}D, V:{value.dim()}D." 
+ ) + + # Attention mask validation + if attn_mask is not None: + if not isinstance(attn_mask, torch.Tensor): + raise TypeError(f"Expected attn_mask to be a Tensor, but got {type(attn_mask).__name__}.") + + # Dtype check: floating point masks must match query dtype; bool masks are also allowed + if attn_mask.dtype.is_floating_point: + if attn_mask.dtype != query.dtype: + raise TypeError(f"Floating point attn_mask must match query dtype ({query.dtype}), but got {attn_mask.dtype}.") + elif attn_mask.dtype != torch.bool: + raise TypeError(f"attn_mask must be floating point or bool, but got {attn_mask.dtype}.") + + # Nested tensor limitation with explicit masking + if query.is_nested or key.is_nested: + raise ValueError("Nested tensors are not supported when an explicit attn_mask is set.") + + # Dropout and causal flag validation (added) + # Dropout probability must be in the range [0, 1) + if not (0.0 <= dropout_p < 1.0): + raise ValueError(f"Expected dropout_p to be in [0, 1), but got {dropout_p}.") + + # Mutual exclusivity: cannot use both explicit mask and causal flag (added) + if is_casual and attn_mask is not None: + raise ValueError("Both attn_mask and is_casual cannot be set at the same time.") + + # Scaling factor validation (added) + if scale is not None and scale <= 0.0: + raise ValueError(f"Expected scale to be a positive number, but got {scale}.") + + # GQA (Grouped Query Attention) constraints (added) + n_head_q = query.size(1) + n_head_k = key.size(1) + n_head_v = value.size(1) + + # The aten._scaled_dot_product_flash_attention kernel does not accept an explicit enable_gqa parameter. + # Instead, the Flash SDPA implementation infers GQA usage by checking if n_head_q != n_head_k. 
+ if not enable_gqa and n_head_q != n_head_k: + raise ValueError(f"Query and Key must have the same number of heads when enable_gqa is false (Q:{n_head_q} vs K:{n_head_k}).") + + if enable_gqa: + if n_head_q == n_head_k: + raise ValueError(f"enable_gqa Query and Key ") + + if n_head_k != n_head_v: + raise ValueError(f"Key and Value must have the same number of heads (K:{n_head_k} vs V:{n_head_v}).") + + # Query heads must be an integer multiple of key heads for grouping + if n_head_q % n_head_k != 0: + raise ValueError( + f"Number of query heads ({n_head_q}) must be divisible by " + f"number of key heads ({n_head_k}) for GQA." + ) + +def convert_boolean_attn_mask(attn_mask: torch.Tensor, target_dtype: torch.dtype) -> float: + """ + Equivalent to the C++ 'convert_boolean_attn_mask' function. + Converts a boolean mask to a floating-point mask for SDPA. + """ + + if attn_mask is not None and attn_mask.dtype == torch.bool: + + new_mask = torch.zeros_like(attn_mask, dtype=target_dtype) + minus_inf = torch.finfo(target_dtype).min + new_mask.masked_fill_(attn_mask.logical_not(), minus_inf) + + return new_mask + + return attn_mask + +def calculate_scale(query: torch.Tensor, scale: float) -> float: + """ + Calculate the scaling factor based on the head dimension if scale is None + Otherwise, use the provided scale. + """ + if scale is None: + return 1.0 / math.sqrt(query.size(-1)) + else: + return scale + +def patched_scaled_dot_product_attention( + query_ : torch.Tensor, + key : torch.Tensor, + value : torch.Tensor, + dropout_p : float = 0.0, + is_casual : bool = False, + attn_mask_ : torch.Tensor = None, + scale_ : float = None, + enable_gqa : bool = None, + orig_fn = torch._C._nn.scaled_dot_product_attention) -> torch.Tensor : + """ + Custom patch for Scaled Dot Product Attention (SDPA) to intercept high-level calls. + For NPU devices, it redirects execution to specific ATen kernels based on global flags. 
+ For all devices, it maintains parity with the original dispatcher logic found in: + https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp + + This function acts as a custom override that replaces the default PyTorch SDPA implementation, + invoked via 'PyTorchSim/PyTorchSimDevice/torch_openreg/openreg/__init__.py'. + """ + + # Device-specific Dispatching: redirect to specialized kernels if on NPU + if "npu" in str(query_.device): + + validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_casual, scale_, enable_gqa) + attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype) + + # Kernel selection logic: emulate C++ dispatcher priority + # Selection priority(can be changed): flash attention > memory efficient > math (cuDNN is not supported) + aten = torch.ops.aten + scale = calculate_scale(query_, scale_) + + if flash_sdp_enabled(): + # Skip padding query, key and value for alignment. + dispatch_kwargs = { + "dropout_p" : dropout_p, + "is_causal" : is_casual, + "return_debug_mask" : False, + "scale" : scale + } + + out_lse_softmax = aten._scaled_dot_product_flash_attention( + query_, key, value, **dispatch_kwargs + ) + + return out_lse_softmax[0] + elif mem_efficient_sdp_enabled(): + # out_and_lse = aten._scaled_dot_product_efficient_attention(...) 
+ # return out_and_lse[0] + raise NotImplementedError("Memory efficient SDPA is not implemented yet.") + else: + dispatch_kwargs = { + "attn_mask" : attn_mask, + "dropout_p" : dropout_p, + "is_causal" : is_casual, + "dropout_mask" : None, + "scale": scale, + "enable_gqa" : enable_gqa + } + + out_lse_softmax = aten._scaled_dot_product_attention_math( + query_, + key, + value, + **dispatch_kwargs) + + return out_lse_softmax[0] + else: + # Fallback: Delegate to the original C++ Dispatcher for other devices + return orig_fn(query_, key, value) + +FLASH_SDPA_TEMPLATE = r""" +// SDPA kernel +// b = {{ b }} +// l = {{ l }} +// s = {{ s }} +// e = {{ e }} +// tile_l = {{ tile_l }} +// tile_s = {{ tile_s }} +// tile_e = {{ tile_e }} +// subtile_l = {{ subtile_l }} +// subtile_s = {{ subtile_s }} +// subtile_e = {{ subtile_e }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[query, key, value], outputs=[out], names_str="query, key, value, out", input_reorder=input_reorder)}} { + // Inputs + {{ kernel.def_sram_buffer("query", q_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} + + // Output + {{ kernel.def_sram_buffer("out", out_tile_desc, indent_size=2) }} + + // Intermediate buffers + {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} + + // Constants + %c0 = arith.constant 0.0 : {{ data_stype }} + %c1 = arith.constant 1.0 : {{ data_stype }} + %c_scale = arith.constant {{ scale }} : {{ data_stype }} + %c_neg_inf = arith.constant -1.0e+30 : {{ data_stype }} + + %v0_c = arith.constant dense<0.0> : vector<{{ chunk_size }}x{{ data_stype }}> + %v0_l = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> + %v0_s = arith.constant dense<0.0> 
: vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> + %v0_2x = arith.constant dense<0.0> : vector<2x{{ data_stype }}> + + %v_neg_inf_c = arith.constant dense<-1.0e+30> : vector<{{ chunk_size }}x{{ data_stype }}> + %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2x{{ data_stype }}> + + %v_scale = vector.broadcast %c_scale : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> + + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %index0 = 0 to {{ b }} { + affine.for %index3 = 0 to 1 step 1 { + affine.for %index1 = 0 to {{ l }} step {{ tile_l }} { + {{ kernel.def_dma_op("MVIN", "query", q_idx, q_tile_desc, subtile_size=[1, subtile_l, subtile_e], indent_size=8) }} + + affine.vector_store %v0_l, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> + affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + affine.vector_store %v0_2x, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + %qt_buffer2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ q_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> + %ot_buffer2D = memref.reinterpret_cast %out_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ out_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> + + affine.for %index2 = 0 to {{ s }} step {{ tile_s }} { + {{ kernel.def_dma_op("MVIN", "key", k_idx, k_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "value", v_idx, v_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10) }} + + affine.vector_store %v0_s, %mul_buffer[0, 0] : {{ 
mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> + + %k_buffer2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1> + %vt_buffer2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1> + + + // key @ query.t and scaling. + linalg.matmul + ins(%k_buffer2D, %qt_buffer2D : memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1>, memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) + outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(data_stype) }}) + + %raw_mul_vec = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + %scaled_mul_vec = arith.mulf %raw_mul_vec, %v_scale : vector<{{ tile_s }}x{{ data_stype }}> + affine.vector_store %scaled_mul_vec, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + + + // Find new max. 
+ %old_max = affine.vector_load %max_buffer[0,0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + %chunk_max_res = affine.for %index5 = 0 to {{ tile_s }} step {{ chunk_size }} iter_args(%iter_max=%v_neg_inf_c) -> (vector<{{ chunk_size }}x{{ data_stype }}>) { + %chunk_val = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> + %local_max = arith.maximumf %chunk_val, %iter_max : vector<{{ chunk_size }}x{{ data_stype }}> + affine.yield %local_max : vector<{{ chunk_size }}x{{ data_stype }}> + } + + %max_cast = vector.shape_cast %chunk_max_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> + %max_reduced_1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> + %max_shuffled = vector.shuffle %max_reduced_1, %max_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> + %max_reduced_2 = arith.maximumf %max_reduced_1, %max_shuffled : vector<2x{{ data_stype }}> + + %new_max = arith.maximumf %max_reduced_2, %old_max : vector<2x{{ data_stype }}> + affine.vector_store %new_max, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + + // Compute rescale factors: exp(old_max - new_max) + %max_diff = arith.subf %old_max, %new_max : vector<2x{{ data_stype }}> + %max_diff_scalar = vector.extract %max_diff[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + + %rescale_bcast_e = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> + %exp_rescale_e = math.exp %rescale_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> + + %rescale_bcast_2 = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<2x{{ data_stype }}> + %exp_rescale_2 = math.exp %rescale_bcast_2 : vector<2x{{ data_stype }}> + + + // Rescale previous out and sum accumulators + %old_out = 
affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + %rescaled_out = arith.mulf %exp_rescale_e, %old_out : vector<{{ tile_e }}x{{ data_stype }}> + affine.vector_store %rescaled_out, %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + + %old_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + %rescaled_sum = arith.mulf %old_sum, %exp_rescale_2 : vector<2x{{ data_stype }}> + + + // Shift scores and apply exp: exp(x - new_max) + %scaled_scores_reload = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + %new_max_scalar = vector.extract %new_max[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + %new_max_bcast = vector.broadcast %new_max_scalar : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> + + %shifted_scores = arith.subf %scaled_scores_reload, %new_max_bcast : vector<{{ tile_s }}x{{ data_stype }}> + %exp_scores = math.exp %shifted_scores : vector<{{ tile_s }}x{{ data_stype }}> + affine.vector_store %exp_scores, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + + + // accumulate current sum + %chunk_sum_res = affine.for %index5 = 0 to {{ tile_s }} step {{ chunk_size }} iter_args(%iter_sum=%v0_c) -> (vector<{{ chunk_size }}x{{ data_stype }}>) { + %chunk_exp = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> + %local_sum = arith.addf %chunk_exp, %iter_sum : vector<{{ chunk_size }}x{{ data_stype }}> + affine.yield %local_sum : vector<{{ chunk_size }}x{{ data_stype }}> + } + + %zero_2x = vector.broadcast %c0 : {{ data_stype }} to vector<2x{{ data_stype }}> + %sum_cast = vector.shape_cast %chunk_sum_res : vector<{{ chunk_size }}x{{ 
data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> + %sum_reduced_1 = vector.multi_reduction , %sum_cast, %zero_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> + %sum_shuffled = vector.shuffle %sum_reduced_1, %sum_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> + %sum_reduced_2 = arith.addf %sum_reduced_1, %sum_shuffled : vector<2x{{ data_stype }}> + + %new_sum = arith.addf %sum_reduced_2, %rescaled_sum : vector<2x{{ data_stype }}> + affine.vector_store %new_sum, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + + // value.t @ mul + linalg.matmul + { idx_map = array } + ins(%vt_buffer2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(data_stype) }}) + outs(%ot_buffer2D : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) + } + + // out @ row_sum^(-1) + %final_row_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + %one_2x = vector.broadcast %c1 : {{ data_stype }} to vector<2x{{ data_stype }}> + + %reciprocal_row_sum_2x = arith.divf %one_2x, %final_row_sum : vector<2x{{ data_stype }}> + %reciprocal_scalar = vector.extract %reciprocal_row_sum_2x[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + %reciprocal_bcast_e = vector.broadcast %reciprocal_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> + + %accumulated_out = affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + %stable_final_out = arith.mulf %accumulated_out, %reciprocal_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> + affine.vector_store %stable_final_out, %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + + {{ kernel.store_output(indent_size=8) }} + } { accumulation_loop=true } + } { outer_loop=true } + } { 
outer_loop=true } + return +} +""" + +class MLIRFlashSDPATemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, scale, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.scale = scale + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, + **kwargs): + + # Except for kernel, other arguments are usually None. + query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + + if tile_info is None: + tile_l, tile_s, tile_e, subtile_l, subtile_s, subtile_e = self.select_tile(kernel, l, s, e, n_extra_node, 0, n_prologue_node)[0] + else: + tile_l, tile_s, tile_e, subtile_l, subtile_s, subtile_e = tile_info + + TOG_latency = l if tile_l > l else tile_l + kernel.loop_size = [TOG_latency, tile_s, tile_e] + + # Select template code + # Other templates will be added according to situations. + nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] + if nr_reduction_nodes: + raise NotImplementedError("FLASH_SDPA_REDUCTION_TEMPLATE is not implemented yet.") + elif prologue_nodes: + raise NotImplementedError("FLASH_SDPA_PROLOGUE_TEMPLATE is not implemented yet.") + else: + template = FLASH_SDPA_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2", "index3": "index3"} + nr_rdim = 0 + + # Prepare tile descriptors for input and output tensors. + # Intermediate buffers (transient data) do not require DRAM settings(dram stride and dram indices) + # as they are not synchronized with external DRAM. + # DRAM and SRAM tile shapes must match. 
+ vlane_stride = 1 + + # (n, l, s, e, ev) + loop_dim = [sympy.Symbol("index0"), sympy.Symbol("index1"), sympy.Symbol("index2"), sympy.Symbol("index3")] + + + # Hardware constraint: The tile split axis is restricted. + # To accommodate this, we compute (key @ query.t) instead of (query @ key.t). + # SRAM settings + vlane_split_axis = 1 + q_tile_size = [1, tile_l, tile_e] + q_tile_stride = [0, tile_e, 1] + q_tile_desc = mlir_common.MLIRMultiDimTile(q_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + q_tile_desc.set_tile_size_stride(q_tile_size, q_tile_stride) + q_tile_desc.set_name("q_buffer") + q_tile_desc.offset = query.get_layout().offset + # DRAM settings + q_stride = q_tensor.stride() + q_idx = [loop_dim[0]*q_stride[0], loop_dim[1]*q_stride[1], loop_dim[3]*q_stride[2]] # To keep index argument order, we used index_list + + # Since we use a weight-stationary approach in the Systolic Array (SA), + # the split axis of the first operand differs from a standard linear algebra matmul. + # The first operand (key) must be split along the column axis. + # This logic aligns with the relationship between the dot product's summation direction and the hardware's accumulation direction in the SA. 
+ # SRAM settings + vlane_split_axis = 2 + k_tile_size = [1, tile_s, tile_e] + k_tile_stride = [0, 1, tile_s] + k_tile_desc = mlir_common.MLIRMultiDimTile(k_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + k_tile_desc.set_tile_size_stride(k_tile_size, k_tile_stride) + k_tile_desc.set_name("k_buffer") + k_tile_desc.offset = key.get_layout().offset + # DRAM settings + k_stride = k_tensor.stride() + k_idx = [loop_dim[0]*k_stride[0], loop_dim[2]*k_stride[1], loop_dim[3]*k_stride[2]] + + # Since we compute mul = key @ query.t, we perform out.t = (value.t @ Softmax(mul).t).t, + # which simplifies to (value.t @ Softmax(mul)) + # SRAM settings + vlane_split_axis = 1 + v_tile_size = [1, tile_s, tile_e] + v_tile_stride = [0, tile_e, 1] + v_tile_desc = mlir_common.MLIRMultiDimTile(v_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + v_tile_desc.set_tile_size_stride(v_tile_size, v_tile_stride) + v_tile_desc.set_name("v_buffer") + v_tile_desc.offset = value.get_layout().offset + # DRAM settings + v_stride = v_tensor.stride() + v_idx = [loop_dim[0]*v_stride[0], loop_dim[2]*v_stride[1], loop_dim[3]*v_stride[2]] # To keep index argument order, we used index_list + + # Output is also stored in transposed format to match the value.t @ Softmax(mul) operation. 
+ # SRAM settings + vlane_split_axis = 1 + out_tile_size = [1, tile_l, tile_e] + out_tile_stride=[0, tile_e, 1] + out_tile_desc = mlir_common.MLIRMultiDimTile(out_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + out_tile_desc.set_tile_size_stride(out_tile_size, out_tile_stride) + out_tile_desc.set_name("out_buffer") + # DRAM settings + out_stride = out.get_layout().stride[1:] + out_idx = [loop_dim[0]*out_stride[0], loop_dim[1]*out_stride[1], loop_dim[3]*out_stride[2]] + + # Intermediate buffers + + # For mul = key @ query.t + vlane_split_axis = 1 + mul_tile_size = [tile_s, tile_l] + mul_tile_stride = [tile_l, 1] + mul_tile_desc = mlir_common.MLIRMultiDimTile(mul_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + mul_tile_desc.set_tile_size_stride(mul_tile_size, mul_tile_stride) + mul_tile_desc.set_name("mul_buffer") + #FIXME. What is the offset? -> It doesn't matter at this time. + + # For storing maximum values per row + vlane_split_axis = 0 + max_size = [tile_l, 2] + max_stride = [2, 1] + max_desc = mlir_common.MLIRMultiDimTile(max_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + max_desc.set_tile_size_stride(max_size, max_stride) + max_desc.set_name("max_buffer") + + # For storing summation per row + vlane_split_axis = 0 + sum_size = [tile_l, 2] + sum_stride = [2, 1] + sum_desc = mlir_common.MLIRMultiDimTile(sum_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + sum_desc.set_tile_size_stride(sum_size, sum_stride) + sum_desc.set_name("sum_buffer") + + # For reduction + chunk_size = 16 + + kernel.render_options = dict( + KERNEL_NAME = self.name, + kernel = kernel, + b = b, + l = l, + s = s, + e = e, # Input sizes (dram) + tile_l = tile_l, + tile_s = tile_s, + tile_e = tile_e, # Tile sizes (sram) + subtile_l = subtile_l, + subtile_s = subtile_s, + subtile_e = subtile_e, # Subtile sizes (sram) + data_stype="f32", + query = query, + key = key, + value = value, + out = out, # Inputs and output (dram) + q_idx = q_idx, + 
k_idx = k_idx, + v_idx = v_idx, + out_idx = out_idx, # Strides (dram) + q_tile_desc = q_tile_desc, + k_tile_desc = k_tile_desc, + v_tile_desc = v_tile_desc, + mul_tile_desc = mul_tile_desc, + out_tile_desc = out_tile_desc, # Tile descriptions (sram) + max_desc = max_desc, + sum_desc = sum_desc, # Intermediate buffer descriptions (sram) + scale = self.scale, + chunk_size = chunk_size, + input_reorder = self.input_reorder # ETC + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "out_buffer", + dram_var = "out", + dram_idx = out_idx, + dram_tile_desc = out_tile_desc, + nr_rdim = nr_rdim, + r_dim_size = 0, + dim_aliasing = epilogue_dim_aliasing + ) + + code = self._template_from_string(template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["l"], kernel.render_options["s"], kernel.render_options["e"]], [kernel.render_options["tile_l"], kernel.render_options["tile_s"], kernel.render_options["tile_e"]]) + return code + + def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + + query = self.input_nodes[0] + key = self.input_nodes[1] + value = self.input_nodes[2] + out = self.output_node + + q_tensor = empty_strided(query.layout.size, query.layout.stride) + k_tensor = empty_strided(key.layout.size, key.layout.stride) + v_tensor = empty_strided(value.layout.size, value.layout.stride) + out_tensor = empty_strided(out.layout.size, out.layout.stride) + + # Flatten batch and head dimensions (n, h) into a single dimension (b = n*h) + q_tensor = q_tensor.view([-1, q_tensor.shape[-2], q_tensor.shape[-1]]) + k_tensor = k_tensor.view([-1, k_tensor.shape[-2], k_tensor.shape[-1]]) + v_tensor = v_tensor.view([-1, v_tensor.shape[-2], v_tensor.shape[-1]]) + out_tensor = out_tensor.view([-1, out_tensor.shape[-2], out_tensor.shape[-1]]) + + b, l, s, e, ev = q_tensor.size(0), q_tensor.size(1), k_tensor.size(1), 
k_tensor.size(2), v_tensor.size(2) + + n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 + + return query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node + + # Reuse the existing function in MLIRBMMTemplate. + def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_node): + + # FIXME: Update the method for getting tile candidates once the TestDmaFineGrained pass works correctly with Flash Attention. + # tile_candidates = kernel.flash_sdpa_mapping(l, s, e, n_extra_node=n_extra_node) + tile_candidates = [[kernel.vector_lane, kernel.vector_lane, e]] + + for idx, (tile_l, tile_s, tile_e) in enumerate(tile_candidates): + subtile_l = tile_l if (tile_l < kernel.vector_lane) or n_prologue_node else kernel.vector_lane + subtile_s = tile_s # if (tile_s < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + subtile_e = tile_e # if (tile_e < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + + tile_candidates[idx] = tile_l,tile_s,tile_e,subtile_l,subtile_s,subtile_e + + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index b864e5f2..23f5e3dc 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -387,6 +387,100 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) tile_candidates = [v for _, v in tile_candidates] return tile_candidates + + # Flash Attention requires more SRAM compared to standard GEMM. 
+ # Total buffers needed: query, key, value, out, mul, max, sum + # Tensor Shapes: + # query (tile_l, tile_e), key (tile_s, tile_e), value (tile_s, tile_e), mul (tile_s, tile_l), out(tile_l, tile_e) + # max, sum : (tile_l, 2) + def flash_sdpa_mapping(self, l, s, e, n_extra_node=0, n_prologue_node=0, pad_e=True, min_tile=False, is_conv=False): + tile_candidates = [] + + spad_size_per_lane = self.spad_info["spad_size"] + spad_size = spad_size_per_lane * self.vector_lane + + # Double buffering + max_spad_per_lane = spad_size_per_lane // 2 + max_spad_size = spad_size // 2 + + # Padding for utilization + minimum_tile_size = 8 + minimum_n_tile = self.num_cores if min_tile else 1 + l_pad_factor = self.vector_lane if l > self.vector_lane else minimum_tile_size + s_pad_factor = self.vector_lane if s > self.vector_lane else minimum_tile_size + + pad = lambda x, factor: ((x + factor - 1) // factor) * factor + l_padded = pad(l, l_pad_factor) + s_padded = pad(s, s_pad_factor) + + # Calculate the total number of vector-sized blocks + l_idx = l_padded // self.vector_lane + s_idx = s_padded // self.vector_lane + + # Generate candidates for the number of blocks per tile + l_tile_range = sympy.divisors(l_idx) if l > self.vector_lane else [1] + s_tile_range = sympy.divisors(s_idx) if s > self.vector_lane else [1] + + # Convert block count to actual tile size + maximize_i_j = 1 + max_used_spad_size = 0 + + # Flash Attention does not tile along the head dimension (e or ev). 
+ tile_e = e + + for i in l_tile_range: + tile_l = i * self.vector_lane if l > self.vector_lane else l_padded + for j in s_tile_range: + tile_s = j * self.vector_lane if s > self.vector_lane else s_padded + + # Calculate used spad size + used_spad_size = ( + tile_l * tile_e * (1 + n_prologue_node) # query + + tile_s * tile_e # key + + tile_s * tile_e # value + + tile_s * tile_l # mul + + tile_l * tile_e * (1 + n_extra_node) # out + + (tile_l * 2) * 2 # max, sum + ) * self.precision + + # Calculate used spad size per lane. + query_per_lane = tile_e * (1+n_prologue_node) + key_per_lane = tile_s + value_per_lane = tile_e + mul_per_lane = tile_s + out_per_lane = tile_e * (1 + n_extra_node) + vec_per_lane = 2 * 2 + + used_spad_per_lane = ( + query_per_lane + + key_per_lane + + value_per_lane + + mul_per_lane + + out_per_lane + + vec_per_lane + ) * self.precision + + # Add the validated candidate to the list if it passes all hardware constraints. + n_tile = math.ceil(l / max(tile_l, 128)) * math.ceil(s / max(tile_s, 128)) + check_spad_size = (used_spad_size < max_spad_size and used_spad_per_lane < max_spad_per_lane) + + if (check_spad_size + and max_used_spad_size < used_spad_size # SRAM utilization + and maximize_i_j <= tile_l * tile_s # Larger tile + and n_tile >= minimum_n_tile # Parallelism + and max(tile_s, 128) // max(tile_l, 128) < 10): # Balanced Shape + max_used_spad_size = used_spad_size + maximize_i_j = tile_l * tile_s + + if check_spad_size: + tile_candidates.append((used_spad_size, (tile_l, tile_s, tile_e))) + + # Sort by used_spad_size. + # tile_candidates[0] is the best solution we have. 
+ tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + + return tile_candidates def meta_kernel(self): kernel_arg_attributes = self.kernel_arg_attributes @@ -827,7 +921,12 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): # Prepare code block with self: - dtype = self.named_nodes[dram_name].get_layout().dtype + try: + dtype = self.named_nodes[dram_name].get_layout().dtype + except (KeyError, AttributeError, TypeError): + import torch + dtype = torch.float32 + tile_shape = tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[dtype]) buffer_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, id, forced_name=dram_name) code = f"%{tile_desc.name} = memref.get_global @{buffer_name} : {tile_shape}" diff --git a/tests/test_sdpa.py b/tests/test_sdpa.py new file mode 100644 index 00000000..9c921eb4 --- /dev/null +++ b/tests/test_sdpa.py @@ -0,0 +1,84 @@ +import sys +import math +import torch +import inspect +from typing import List +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.fx.passes.graph_drawer import FxGraphDrawer +from torch._inductor.decomposition import decompositions + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + message = f"|{name} Test Passed|" + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_scaled_dot_product_attention(device, backends="flash"): + torch.manual_seed(0) + n_batch_list = [1, 4, 8, 16] + n_head_list = [1, 4, 8, 12] + n_token_list = [128, 256, 512, 1024] + head_dim_list = [32, 64, 128] + + for n_batch in n_batch_list: + for n_head in n_head_list: + for n_token in n_token_list: + for head_dim in head_dim_list: + # 
Inputs + query = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) + key = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) + value = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) + + query = query.to(device=device) + key = key.to(device=device) + value = value.to(device=device) + + # With NPU + if backends == "flash": + backends = [SDPBackend.FLASH_ATTENTION] + elif backends == "math": + backends = [SDPBackend.MATH] + elif backends == "memory_efficient": + backends = [SDPBackend.EFFICIENT_ATTENTION] + else: + backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION] + + with sdpa_kernel(backends=backends): + opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) + out = opt_fn(query, key, value) + + out = out.to(device) + + # With CPU + device = torch.device('cpu') + query = query.to(device=device) + key = key.to(device=device) + value = value.to(device=device) + cpu_out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) + + name = f"SDPA(n_batch: {n_batch}, n_head: {n_head}, n_token: {n_token}, head_dim: {head_dim})" + test_result(name, out, cpu_out) + + print("All tests passed!") + +def clear_caches(): + import os + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache + from torch._inductor.codecache import FxGraphCache + AOTAutogradCache.clear() + torch._dynamo.reset() + os.environ["TORCHINDUCTOR_CACHE"] = "0" + FxGraphCache.clear() + +if __name__ == "__main__": + clear_caches() + + device = torch.device('npu:0') + test_scaled_dot_product_attention(device, backends="flash") + \ No newline at end of file From f615178ae581236a1b4d1018f9b458b2c552179f Mon Sep 17 00:00:00 2001 From: jung-min Date: Wed, 4 Mar 2026 07:57:47 +0000 Subject: [PATCH 2/4] [Fix] Prevent fallback to eager mode after reaching compilation limit (7) --- tests/test_sdpa.py | 31 ++++++++++--------------------- 1 file 
changed, 10 insertions(+), 21 deletions(-) diff --git a/tests/test_sdpa.py b/tests/test_sdpa.py index 9c921eb4..6ffd6f2e 100644 --- a/tests/test_sdpa.py +++ b/tests/test_sdpa.py @@ -14,6 +14,7 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("-" * len(message)) print(message) print("-" * len(message)) + pass else: print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) @@ -31,35 +32,25 @@ def test_scaled_dot_product_attention(device, backends="flash"): for n_token in n_token_list: for head_dim in head_dim_list: # Inputs + clear_caches() query = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) key = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) value = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) + # With NPU query = query.to(device=device) key = key.to(device=device) value = value.to(device=device) - # With NPU - if backends == "flash": - backends = [SDPBackend.FLASH_ATTENTION] - elif backends == "math": - backends = [SDPBackend.MATH] - elif backends == "memory_efficient": - backends = [SDPBackend.EFFICIENT_ATTENTION] - else: - backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION] - - with sdpa_kernel(backends=backends): - opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) - out = opt_fn(query, key, value) - + opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) + out = opt_fn(query, key, value) out = out.to(device) # With CPU - device = torch.device('cpu') - query = query.to(device=device) - key = key.to(device=device) - value = value.to(device=device) + cpu_device = torch.device('cpu') + query = query.to(device=cpu_device) + key = key.to(device=cpu_device) + value = value.to(device=cpu_device) cpu_out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) name = f"SDPA(n_batch: {n_batch}, n_head: {n_head}, n_token: {n_token}, head_dim: {head_dim})" @@ -76,9 +67,7 
@@ def clear_caches(): os.environ["TORCHINDUCTOR_CACHE"] = "0" FxGraphCache.clear() -if __name__ == "__main__": - clear_caches() - +if __name__ == "__main__": device = torch.device('npu:0') test_scaled_dot_product_attention(device, backends="flash") \ No newline at end of file From 8ca5d02d599d06725b90963ee44701cb50e8f444 Mon Sep 17 00:00:00 2001 From: jung-min Date: Wed, 4 Mar 2026 08:09:28 +0000 Subject: [PATCH 3/4] [FIX] Add idx_map to the first matmul for logical consistency --- PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index b3d88cc6..49c6c6bb 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -339,6 +339,7 @@ def patched_scaled_dot_product_attention( // key @ query.t and scaling. linalg.matmul + { idx_map = array } ins(%k_buffer2D, %qt_buffer2D : memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1>, memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(data_stype) }}) @@ -451,7 +452,7 @@ def render(self, prologue_nodes: Optional[List[IRNode]] = None, tile_info = None, **kwargs): - + # Except for kernel, other arguments are usually None. 
query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) From 3d9cb387b2ba27853efb983241fa4450c3174d9d Mon Sep 17 00:00:00 2001 From: jung-min Date: Thu, 5 Mar 2026 11:45:36 +0000 Subject: [PATCH 4/4] [Frontend/template] Connect SDPA template to NPU using Torch OpenReg --- PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp | 34 +--- PyTorchSimDevice/csrc/aten/native/Extra.cpp | 51 +---- .../torch_openreg/openreg/__init__.py | 4 +- PyTorchSimFrontend/mlir/mlir_lowering.py | 14 +- PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 186 +----------------- 5 files changed, 14 insertions(+), 275 deletions(-) diff --git a/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp index 04ba6d48..f048f878 100644 --- a/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp +++ b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -40,36 +41,6 @@ void wrapper_quantize_tensor_per_tensor_affine_stub( rtensor, qtensor, scale, zero_point); } -std::tuple< - at::Tensor, - at::Tensor, - at::Tensor, - at::Tensor, - c10::SymInt, - c10::SymInt, - at::Tensor, - at::Tensor, - at::Tensor> -wrapper__scaled_dot_product_fused_attention_overrideable( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const std::optional& attn_bias, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - return at::native::openreg::_scaled_dot_product_fused_attention_overrideable( - query, - key, - value, - attn_bias, - dropout_p, - is_causal, - return_debug_mask, - scale); -} - std::tuple wrapper_scaled_dot_product_fused_attention_overrideable_backward( const at::Tensor& grad_out, @@ -172,9 +143,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("abs.out", &wrapper_abs_out); m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor); 
m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice); - m.impl( - "_scaled_dot_product_fused_attention_overrideable", - &wrapper__scaled_dot_product_fused_attention_overrideable); m.impl( "_scaled_dot_product_fused_attention_overrideable_backward", &wrapper_scaled_dot_product_fused_attention_overrideable_backward); diff --git a/PyTorchSimDevice/csrc/aten/native/Extra.cpp b/PyTorchSimDevice/csrc/aten/native/Extra.cpp index 711d114c..aaf28e1a 100644 --- a/PyTorchSimDevice/csrc/aten/native/Extra.cpp +++ b/PyTorchSimDevice/csrc/aten/native/Extra.cpp @@ -19,7 +19,8 @@ int64_t _fused_sdp_choice( bool is_causal, std::optional scale, bool enable_gqa) { - auto backend = sdp::SDPBackend::math; + + auto backend = sdp::SDPBackend::overrideable; return static_cast(backend); } @@ -29,54 +30,6 @@ void quantize_tensor_per_tensor_affine_stub( double scale, int64_t zero_point) {} -std::tuple< - at::Tensor, - at::Tensor, - at::Tensor, - at::Tensor, - c10::SymInt, - c10::SymInt, - at::Tensor, - at::Tensor, - at::Tensor> -_scaled_dot_product_fused_attention_overrideable( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const std::optional& attn_bias, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_v = value.size(3); - const int64_t max_seqlen_q = query.size(2); - const int64_t max_seqlen_kv = key.size(2); - - auto opts = query.options(); - auto output = - at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts); - auto logsumexp = - at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - auto debug_attn_mask = at::empty( - {batch_size, num_heads, max_seqlen_q, max_seqlen_kv}, - opts.dtype(at::kFloat)); - auto philox_seed = at::empty({}, at::dtype(at::kLong)); - auto philox_offset = at::empty({}, at::dtype(at::kLong)); - - return std::make_tuple( - output, - logsumexp, - 
at::Tensor(), - at::Tensor(), - max_seqlen_q, - max_seqlen_kv, - philox_seed, - philox_offset, - debug_attn_mask); -} - std::tuple _scaled_dot_product_fused_attention_overrideable_backward( const at::Tensor& grad_out, diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index 5a0de6c3..9d10f90e 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -66,8 +66,8 @@ def _lazy_init(): return # Replace the global C++ binding with our custom dispatcher patch - from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention - torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention + # from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention + # torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention torch_openreg._C._init() register_interface_for_device(custom_device(), ExtensionDeviceInterface) diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index e09dcf57..a6b2478c 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -15,7 +15,7 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate -from PyTorchSimFrontend.mlir.mlir_sdpa_template import MLIRFlashSDPATemplate, flash_sdpa_args +from PyTorchSimFrontend.mlir.mlir_sdpa_template import MLIRFlashSDPATemplate, flash_sdpa_args, calculate_scale from PyTorchSimFrontend import extension_config aten = torch.ops.aten @@ -44,14 +44,16 @@ def tuned_flash_sdpa( query : TensorBox, key : TensorBox, value : TensorBox, - scale : float, + attn_bias : Optional[TensorBox] = None, dropout_p : float = 
0.0, is_causal : bool = False, - return_debug_mask : bool =False) -> tuple: + return_debug_mask : bool = False, + scale : Optional[float] = None) -> tuple: - print("Enter tuned_flash_sdpa") - + + scale = calculate_scale(query, scale) N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value) + mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale) # _scaled_dot_product_flash_attention has to return a tuple which has 9 values @@ -211,4 +213,4 @@ def custom_unsafe_index(x, indices): if extension_config.CONFIG_USE_TIMING_POOLING: lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template -lowerings.update({getattr(aten._scaled_dot_product_flash_attention, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_flash_attention.overloads()}) \ No newline at end of file +lowerings.update({getattr(aten._scaled_dot_product_fused_attention_overrideable, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_fused_attention_overrideable.overloads()}) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index 49c6c6bb..05030f27 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -73,121 +73,6 @@ def flash_sdpa_args( ) return [n, hq, h, l, s, e, ev, layout, query, key, value] - -def validate_sdpa_input( - query : torch.Tensor, - key : torch.Tensor, - value : torch.Tensor, - attn_mask : torch.Tensor = None, - dropout_p : float = 0.0, - is_casual : bool = False, - scale : float = None, - enable_gqa : bool = False) -> None: - """ - Validates input tensors and parameters for Scaled Dot Product Attention (SDPA). 
- This function's logic can be found in: - https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp(504 line) - https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - """ - - # Tensor class, dtype, and device consistency - # Ensure all primary inputs are torch.Tensors - if not all(isinstance(t, torch.Tensor) for t in [query, key, value]): - raise TypeError( - f"Expected query, key and value to be Tensors, but got " - f"{type(query).__name__}, {type(key).__name__}, and {type(value).__name__}." - ) - - # Check for dtype mismatch - if query.dtype != key.dtype or query.dtype != value.dtype: - raise TypeError( - f"Expected query, key, and value to have the same dtype, " - f"but got {query.dtype}, {key.dtype}, and {value.dtype}." - ) - - # Check for device mismatch (e.g., mixing CPU and NPU) - if query.device != key.device or query.device != value.device: - raise ValueError( - f"Expected query, key, and value to be on the same device, " - f"but got {query.device}, {key.device}, and {value.device}." - ) - - # Shape and dimension validation - # SDPA typically expects 4D (B, H, S, D), but we check for at least 2D here - if any(t.dim() < 2 for t in [query, key, value]): - raise ValueError( - f"Expected query, key, and value to be at least 2D, " - f"but got Q:{query.dim()}D, K:{key.dim()}D, V:{value.dim()}D." 
- ) - - # Attention mask validation - if attn_mask is not None: - if not isinstance(attn_mask, torch.Tensor): - raise TypeError(f"Expected attn_mask to be a Tensor, but got {type(attn_mask).__name__}.") - - # Dtype check: floating point masks must match query dtype; bool masks are also allowed - if attn_mask.dtype.is_floating_point: - if attn_mask.dtype != query.dtype: - raise TypeError(f"Floating point attn_mask must match query dtype ({query.dtype}), but got {attn_mask.dtype}.") - elif attn_mask.dtype != torch.bool: - raise TypeError(f"attn_mask must be floating point or bool, but got {attn_mask.dtype}.") - - # Nested tensor limitation with explicit masking - if query.is_nested or key.is_nested: - raise ValueError("Nested tensors are not supported when an explicit attn_mask is set.") - - # Dropout and causal flag validation (added) - # Dropout probability must be in the range [0, 1) - if not (0.0 <= dropout_p < 1.0): - raise ValueError(f"Expected dropout_p to be in [0, 1), but got {dropout_p}.") - - # Mutual exclusivity: cannot use both explicit mask and causal flag (added) - if is_casual and attn_mask is not None: - raise ValueError("Both attn_mask and is_casual cannot be set at the same time.") - - # Scaling factor validation (added) - if scale is not None and scale <= 0.0: - raise ValueError(f"Expected scale to be a positive number, but got {scale}.") - - # GQA (Grouped Query Attention) constraints (added) - n_head_q = query.size(1) - n_head_k = key.size(1) - n_head_v = value.size(1) - - # The aten._scaled_dot_product_flash_attention kernel does not accept an explicit enable_gqa parameter. - # Instead, the Flash SDPA implementation infers GQA usage by checking if n_head_q != n_head_k. 
- if not enable_gqa and n_head_q != n_head_k: - raise ValueError(f"Query and Key must have the same number of heads when enable_gqa is false (Q:{n_head_q} vs K:{n_head_k}).") - - if enable_gqa: - if n_head_q == n_head_k: - raise ValueError(f"enable_gqa Query and Key ") - - if n_head_k != n_head_v: - raise ValueError(f"Key and Value must have the same number of heads (K:{n_head_k} vs V:{n_head_v}).") - - # Query heads must be an integer multiple of key heads for grouping - if n_head_q % n_head_k != 0: - raise ValueError( - f"Number of query heads ({n_head_q}) must be divisible by " - f"number of key heads ({n_head_k}) for GQA." - ) - -def convert_boolean_attn_mask(attn_mask: torch.Tensor, target_dtype: torch.dtype) -> float: - """ - Equivalent to the C++ 'convert_boolean_attn_mask' function. - Converts a boolean mask to a floating-point mask for SDPA. - """ - - if attn_mask is not None and attn_mask.dtype == torch.bool: - - new_mask = torch.zeros_like(attn_mask, dtype=target_dtype) - minus_inf = torch.finfo(target_dtype).min - new_mask.masked_fill_(attn_mask.logical_not(), minus_inf) - - return new_mask - - return attn_mask def calculate_scale(query: torch.Tensor, scale: float) -> float: """ @@ -195,79 +80,10 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: Otherwise, use the provided scale. """ if scale is None: - return 1.0 / math.sqrt(query.size(-1)) + return 1.0 / math.sqrt(query.layout.size[-1]) else: return scale -def patched_scaled_dot_product_attention( - query_ : torch.Tensor, - key : torch.Tensor, - value : torch.Tensor, - dropout_p : float = 0.0, - is_casual : bool = False, - attn_mask_ : torch.Tensor = None, - scale_ : float = None, - enable_gqa : bool = None, - orig_fn = torch._C._nn.scaled_dot_product_attention) -> torch.Tensor : - """ - Custom patch for Scaled Dot Product Attention (SDPA) to intercept high-level calls. - For NPU devices, it redirects execution to specific ATen kernels based on global flags. 
- For all devices, it maintains parity with the original dispatcher logic found in: - https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp - - This function acts as a custom override that replaces the default PyTorch SDPA implementation, - invoked via 'PyTorchSim/PyTorchSimDevice/torch_openreg/openreg/__init__.py'. - """ - - # Device-specific Dispatching: redirect to specialized kernels if on NPU - if "npu" in str(query_.device): - - validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_casual, scale_, enable_gqa) - attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype) - - # Kernel selection logic: emulate C++ dispatcher priority - # Selection priority(can be changed): flash attention > memory efficient > math (cuDNN is not supported) - aten = torch.ops.aten - scale = calculate_scale(query_, scale_) - - if flash_sdp_enabled(): - # Skip padding query, key and value for alignment. - dispatch_kwargs = { - "dropout_p" : dropout_p, - "is_causal" : is_casual, - "return_debug_mask" : False, - "scale" : scale - } - - out_lse_softmax = aten._scaled_dot_product_flash_attention( - query_, key, value, **dispatch_kwargs - ) - - return out_lse_softmax[0] - elif mem_efficient_sdp_enabled(): - # out_and_lse = aten._scaled_dot_product_efficient_attention(...) - # return out_and_lse[0] - raise NotImplementedError("Memory efficient SDPA is not implemented yet.") - else: - dispatch_kwargs = { - "attn_mask" : attn_mask, - "dropout_p" : dropout_p, - "is_causal" : is_casual, - "dropout_mask" : None, - "scale": scale, - "enable_gqa" : enable_gqa - } - - out_lse_softmax = aten._scaled_dot_product_attention_math( - query_, - key, - value, - **dispatch_kwargs) - - return out_lse_softmax[0] - else: - # Fallback: Delegate to the original C++ Dispatcher for other devices - return orig_fn(query_, key, value) FLASH_SDPA_TEMPLATE = r""" // SDPA kernel