Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/fusilli/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,9 @@ Graph::layernorm(const std::shared_ptr<TensorAttr> &x,
scale->setName(layernormAttr.getName() + "_SCALE");
if (bias && bias->getName().empty())
bias->setName(layernormAttr.getName() + "_BIAS");
auto eps = layernormAttr.getEpsilon();
if (eps && eps->getName().empty())
eps->setName(layernormAttr.getName() + "_EPSILON");

FUSILLI_LOG_LABEL_ENDL("INFO: Adding LayerNorm '" << layernormAttr.getName()
<< "' to Graph");
Expand Down
75 changes: 64 additions & 11 deletions include/fusilli/support/asm_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "fusilli/node/pointwise_node.h"
#include "fusilli/support/extras.h"

#include <bit> // C++20
#include <cassert>
#include <cctype>
#include <cstddef>
Expand Down Expand Up @@ -161,6 +162,40 @@ inline std::string getPermuteOpsAsm(const std::shared_ptr<TensorAttr> &tensor,
return oss.str();
}

// Emits a scalar TensorAttr as a constant tensor literal in MLIR assembly.
// The result SSA name is the tensor's value name (e.g. %alpha).
inline std::string
getScalarConstantAsm(const std::shared_ptr<TensorAttr> &tensor) {
  assert(tensor->isScalar() && tensor->getScalarValue().has_value() &&
         "getScalarConstantAsm called with non-scalar tensor");
  const std::string ssaName = tensor->getValueNameAsm();
  const std::string elemType = kDataTypeToMlirTypeAsm.at(tensor->getDataType());
  const std::string vtensorType =
      tensor->getTensorTypeAsm(/*isValueTensor=*/true,
                               /*useLogicalDims=*/true);
  // std::visit statically dispatches over the variant alternatives
  // (int64_t, int32_t, float, double), instantiating the lambda per type.
  // The scalar payload is rendered as a zero-padded uppercase hex bit
  // pattern so the literal round-trips exactly (no float formatting loss).
  // Example output:
  //   float  1.0f → dense<0x3F800000> : tensor<1xf32>
  //   double 1.0  → dense<0x3FF0000000000000> : tensor<1xf64>
  //   int32  42   → dense<0x0000002A> : tensor<1xi32>
  //   int64  42L  → dense<0x000000000000002A> : tensor<1xi64>
  const auto emitLiteral = [&](auto val) -> std::string {
    // Reinterpret the value's bits as an unsigned integer of equal width.
    using Bits = std::conditional_t<sizeof(val) == 4, uint32_t, uint64_t>;
    const Bits bitPattern = std::bit_cast<Bits>(val);
    // Two hex digits per byte: 4-byte types → width 8, 8-byte types → 16.
    const std::size_t hexWidth = sizeof(val) * 2;
    // {:0{}X} -> zero-padded uppercase hex with a runtime-supplied width.
    return std::format(
        "\n{} = torch.vtensor.literal(dense<0x{:0{}X}> : tensor<1x{}>) "
        ": {}\n",
        ssaName, bitPattern, hexWidth, elemType, vtensorType);
  };
  return std::visit(emitLiteral, tensor->getScalarValue().value());
}

//===----------------------------------------------------------------------===//
//
// TensorAttr ASM Emitter Methods
Expand All @@ -169,11 +204,9 @@ inline std::string getPermuteOpsAsm(const std::shared_ptr<TensorAttr> &tensor,

// Emits a ranked tensor type in MLIR assembly representation.
//
// This expects ranked tensors (non-scalar) as we blanket generate a
// `!torch.vtensor` (or `!torch.tensor` if mutable) type. The caller
// is responsible to check for this. In the future we may want to extend
// this (or add new methods) for scalar types (such as `!torch.int` or
// `!torch.bool`).
// This expects ranked tensors as we blanket generate a `!torch.vtensor` (or
// `!torch.tensor` if mutable) type. The caller is responsible to check for
// this.
//
// Example:
//
Expand All @@ -198,9 +231,15 @@ inline std::string getPermuteOpsAsm(const std::shared_ptr<TensorAttr> &tensor,
// t.getTensorTypeAsm(/*isValueTensor=*/false,
// /*useLogicalDims=*/false)
// --> "!torch.tensor<[2,4,3],f32>"
//
// Scalars (dim={1}, stride={1}) also work through this path:
//
// TensorAttr s(2.0f);
// s.getTensorTypeAsm(/*isValueTensor=*/true,
// /*useLogicalDims=*/true)
// --> "!torch.vtensor<[1],f32>"
inline std::string TensorAttr::getTensorTypeAsm(bool isValueTensor,
bool useLogicalDims) const {
assert(!isScalar() && "TensorAttr::getTensorTypeAsm expects a ranked tensor");
assert(!getDim().empty() &&
"TensorAttr::getTensorTypeAsm expects non-empty dims");
assert(!getStride().empty() &&
Expand Down Expand Up @@ -326,6 +365,13 @@ module @module {{
getOperandNamesAndTypesAsm() // {1}
);

// Emit scalar constants (`torch.vtensor.literal`) for all scalar graph inputs
// at the top of the function body.
for (const auto &input : fullGraphInputsSorted_) {
if (input->isScalar())
output += getScalarConstantAsm(input);
}

return output;
}

Expand Down Expand Up @@ -971,12 +1017,19 @@ inline std::string LayerNormNode::getNormalizedShapeOpsAsm() const {
/*suffix=*/layernormAttr.getName());
}

// Get epsilon extraction op in MLIR assembly format. The scalar constant
// `torch.vtensor.literal` is emitted once at graph level
// (Graph::emitNodePreAsm). Here we extract the float value with
// `torch.aten.item` for use with `torch.aten.layer_norm` which expects
// `!torch.float`.
inline std::string LayerNormNode::getEpsilonOpsAsm() const {
  std::string suffix = layernormAttr.getName();
  auto eps = layernormAttr.getEpsilon();
  std::string tensorName = eps->getValueNameAsm();
  std::string tensorType =
      eps->getTensorTypeAsm(/*isValueTensor=*/true, /*useLogicalDims=*/true);
  // No leading whitespace before %eps_: indentation is owned by the
  // surrounding emission schema, and a leading space here produced a
  // stray indent in the generated MLIR (previous torch.constant.float
  // emission had none either).
  return std::format("%eps_{} = torch.aten.item {} : {} -> !torch.float\n",
                     suffix, tensorName, tensorType);
}

// This gets called by the recursive `emitAsmSubtree()` method to emit
Expand Down
41 changes: 36 additions & 5 deletions plugins/hipdnn-plugin/include/graph_import.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,37 @@ class GraphImport {
}

// Import new tensor.
auto fusilliTensorAttr =
fusilli::TensorAttr().setName(std::format("{}_{}", name, uid)); // C++20
fusilli::TensorAttr fusilliTensorAttr;
if (isPassByValue(hipDnnTensorAttr)) { // handle scalar tensors
switch (hipDnnTensorAttr->value_type()) {
case hipdnn_data_sdk::data_objects::TensorValue::Float32Value:
fusilliTensorAttr = fusilli::TensorAttr(
hipDnnTensorAttr->value_as_Float32Value()->value());
break;
case hipdnn_data_sdk::data_objects::TensorValue::Float64Value:
fusilliTensorAttr = fusilli::TensorAttr(
hipDnnTensorAttr->value_as_Float64Value()->value());
break;
case hipdnn_data_sdk::data_objects::TensorValue::Int32Value:
fusilliTensorAttr = fusilli::TensorAttr(
hipDnnTensorAttr->value_as_Int32Value()->value());
break;
default:
return fusilli::error(
fusilli::ErrorCode::NotImplemented,
"Unsupported scalar type in hipdnn -> fusilli graph translation.");
}
}
fusilliTensorAttr.setName(std::format("{}_{}", name, uid)); // C++20
FUSILLI_CHECK_ERROR(importAttrs(fusilliTensorAttr, hipDnnTensorAttr));
std::shared_ptr<fusilli::TensorAttr> graphInput =
fusilliGraph.tensor(fusilliTensorAttr);

// Track boundary tensor.
uidToIOTensor[uid] = graphInput;
// Scalar constants are embedded in the MLIR IR and don't need device
// buffers, so exclude them from the IO tensor map that drives variant
// pack construction at execution time.
if (!graphInput->isScalar())
uidToIOTensor[uid] = graphInput;

return ok(graphInput);
};
Expand Down Expand Up @@ -327,7 +350,15 @@ class GraphImport {
return fusilli::ok();
};

// Import all tensor attrs src -> dest.
// Returns true when the hipDNN tensor carries an inline (pass-by-value)
// scalar payload — the fusilli-side equivalent of hipDNN frontend's
// TensorAttributes::get_pass_by_value().
static bool
isPassByValue(const hipdnn_data_sdk::data_objects::TensorAttributes *src) {
  using hipdnn_data_sdk::data_objects::TensorValue;
  // A tensor with no value payload (NONE) is a regular device-buffer tensor.
  return src->value_type() != TensorValue::NONE;
}

// Import tensor attrs (dims, strides, datatype) from hipDNN to fusilli.
fusilli::ErrorObject
importAttrs(fusilli::TensorAttr &dest,
const hipdnn_data_sdk::data_objects::TensorAttributes *src) {
Expand Down
15 changes: 15 additions & 0 deletions plugins/hipdnn-plugin/test/integration/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,21 @@ add_fusilli_plugin_test(
FUSILLI_PLUGIN_PATH="../${FUSILLI_PLUGIN_RELATIVE_PATH}"
)

# Registers the scaled-matmul-with-accumulate integration test for the
# fusilli hipDNN plugin; mirrors the sibling add_fusilli_plugin_test()
# entries (same dependency set and plugin-path compile definition).
add_fusilli_plugin_test(
NAME fusilli_plugin_simple_scaled_matmul_accumulate_test
SRCS matmul/simple_scaled_matmul_accumulate.cpp
DEPS
GTest::gtest_main
hip::host
hipdnn_frontend
hipdnn_plugin_sdk
hipdnn_data_sdk
hipdnn_test_sdk
Threads::Threads
COMPILE_DEFS
FUSILLI_PLUGIN_PATH="../${FUSILLI_PLUGIN_RELATIVE_PATH}"
)

add_fusilli_plugin_test(
NAME fusilli_plugin_simple_pointwise_test
SRCS pointwise/simple_pointwise.cpp
Expand Down
Loading
Loading