34 changes: 1 addition & 33 deletions PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp
@@ -2,6 +2,7 @@

#include <ATen/native/CPUFallback.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>

#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
#include <torch/library.h>
@@ -40,36 +41,6 @@ void wrapper_quantize_tensor_per_tensor_affine_stub(
rtensor, qtensor, scale, zero_point);
}

std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor,
at::Tensor>
wrapper__scaled_dot_product_fused_attention_overrideable(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
return at::native::openreg::_scaled_dot_product_fused_attention_overrideable(
query,
key,
value,
attn_bias,
dropout_p,
is_causal,
return_debug_mask,
scale);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
wrapper_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor& grad_out,
@@ -172,9 +143,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("abs.out", &wrapper_abs_out);
m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor);
m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice);
m.impl(
"_scaled_dot_product_fused_attention_overrideable",
&wrapper__scaled_dot_product_fused_attention_overrideable);
m.impl(
"_scaled_dot_product_fused_attention_overrideable_backward",
&wrapper_scaled_dot_product_fused_attention_overrideable_backward);
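
With this change the PrivateUse1 backend no longer registers a C++ kernel for the forward op aten::_scaled_dot_product_fused_attention_overrideable; only _fused_sdp_choice and the overrideable backward stay bound, so the forward can instead be picked up by the Inductor lowering added later in this PR. A minimal sketch to confirm the registration state, assuming the private dispatcher helper torch._C._dispatch_has_kernel_for_dispatch_key behaves as in recent PyTorch builds and that importing torch_openreg registers the backend:

import torch
import torch_openreg  # assumed entry point that registers the PrivateUse1 backend

fwd = "aten::_scaled_dot_product_fused_attention_overrideable"
bwd = "aten::_scaled_dot_product_fused_attention_overrideable_backward"

# After this diff, the forward should have no PrivateUse1 kernel while the backward still does.
print(torch._C._dispatch_has_kernel_for_dispatch_key(fwd, "PrivateUse1"))  # expected: False
print(torch._C._dispatch_has_kernel_for_dispatch_key(bwd, "PrivateUse1"))  # expected: True
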
51 changes: 2 additions & 49 deletions PyTorchSimDevice/csrc/aten/native/Extra.cpp
@@ -19,7 +19,8 @@ int64_t _fused_sdp_choice(
bool is_causal,
std::optional<double> scale,
bool enable_gqa) {
auto backend = sdp::SDPBackend::math;

auto backend = sdp::SDPBackend::overrideable;
return static_cast<int64_t>(backend);
}

@@ -29,54 +30,6 @@ void quantize_tensor_per_tensor_affine_stub(
double scale,
int64_t zero_point) {}

std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor,
at::Tensor>
_scaled_dot_product_fused_attention_overrideable(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_q = query.size(2);
const int64_t max_seqlen_kv = key.size(2);

auto opts = query.options();
auto output =
at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
auto logsumexp =
at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
auto debug_attn_mask = at::empty(
{batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
opts.dtype(at::kFloat));
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));

return std::make_tuple(
output,
logsumexp,
at::Tensor(),
at::Tensor(),
max_seqlen_q,
max_seqlen_kv,
philox_seed,
philox_offset,
debug_attn_mask);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor& grad_out,
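
The Extra.cpp change switches _fused_sdp_choice from the math backend to SDPBackend::overrideable and drops the native forward implementation. As far as this diff shows, the practical effect is that an eager scaled_dot_product_attention call on the simulated device now dispatches to aten::_scaled_dot_product_fused_attention_overrideable rather than decomposing into the math path. A rough sketch of the call that exercises this selection, with the device string "openreg" and the shapes assumed for illustration:

import torch
import torch.nn.functional as F

# (N, H, L, E) layout expected by scaled_dot_product_attention: batch, heads, sequence, head dim.
q = torch.randn(2, 8, 128, 64, device="openreg")
k = torch.randn(2, 8, 128, 64, device="openreg")
v = torch.randn(2, 8, 128, 64, device="openreg")

# With _fused_sdp_choice returning SDPBackend::overrideable, this should route to the
# overrideable fused-attention op instead of the math decomposition.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
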
7 changes: 6 additions & 1 deletion PyTorchSimDevice/torch_openreg/openreg/__init__.py
@@ -24,7 +24,7 @@ class device:

def __init__(self, device):
self.idx = torch.accelerator._get_device_index(device, optional=True)
self.prev_idx = -1
self.prev_idx = -1

def __enter__(self):
self.prev_idx = torch_openreg._C._exchangeDevice(self.idx)
@@ -64,6 +64,11 @@ def _lazy_init():
global _initialized, _tog_simulator
if is_initialized():
return

# Replace the global C++ binding with our custom dispatcher patch
# from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention
# torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention

torch_openreg._C._init()
register_interface_for_device(custom_device(), ExtensionDeviceInterface)
_initialized = True
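
The __init__.py change keeps the dispatcher monkey-patch commented out and leaves initialization to torch_openreg._C._init() plus the accelerator interface registration. For context, the device class shown above is a context manager that exchanges the active device index on entry and restores the previous one on exit; a small usage sketch, assuming the module is importable as torch_openreg.openreg and that the device string is "openreg":

import torch
import torch_openreg.openreg as openreg  # assumed import path for this package

# Temporarily switch to device index 0; the previous index is restored on exit.
with openreg.device(0):
    x = torch.ones(4, device="openreg")
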
27 changes: 26 additions & 1 deletion PyTorchSimFrontend/mlir/mlir_lowering.py
@@ -15,6 +15,7 @@
from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate
from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate
from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate
from PyTorchSimFrontend.mlir.mlir_sdpa_template import MLIRFlashSDPATemplate, flash_sdpa_args, calculate_scale
from PyTorchSimFrontend import extension_config

aten = torch.ops.aten
@@ -38,6 +39,28 @@ def tuned_bmm(mat1, mat2, *, layout=None):

return mlir_template.generate().output_node()


def tuned_flash_sdpa(
query : TensorBox,
key : TensorBox,
value : TensorBox,
attn_bias : Optional[TensorBox] = None,
dropout_p : float = 0.0,
is_causal : bool = False,
return_debug_mask : bool = False,
scale : Optional[float] = None) -> tuple:

scale = calculate_scale(query, scale)
N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value)

mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale)

# _scaled_dot_product_flash_attention has to return a 9-value tuple,
# since its backward (_scaled_dot_product_flash_attention_backward) consumes those values:
# (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
return (mlir_template.generate().output_node(), None, None, None, None, None, None, None, None)

def conv_layout(
x: TensorBox,
weight: TensorBox,
@@ -188,4 +211,6 @@ def custom_unsafe_index(x, indices):
lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()})
lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()})
if extension_config.CONFIG_USE_TIMING_POOLING:
lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template
lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template

lowerings.update({getattr(aten._scaled_dot_product_fused_attention_overrideable, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_fused_attention_overrideable.overloads()})
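
Finally, the new lowerings.update entry maps every overload of aten._scaled_dot_product_fused_attention_overrideable to tuned_flash_sdpa, so under torch.compile the op is emitted through MLIRFlashSDPATemplate instead of falling back to a decomposition. A minimal end-to-end sketch, again assuming the device string "openreg" and that the backend has been initialized:

import torch
import torch.nn.functional as F

def attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

compiled_attn = torch.compile(attn)

q = torch.randn(1, 8, 256, 64, device="openreg")
k = torch.randn(1, 8, 256, 64, device="openreg")
v = torch.randn(1, 8, 256, 64, device="openreg")

# Inductor should lower the fused-attention op through tuned_flash_sdpa,
# generating the MLIR flash-attention template for the simulator.
out = compiled_attn(q, k, v)
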