-
Notifications
You must be signed in to change notification settings - Fork 16
[Fusilli,FusilliPlugin] Scalar value support #161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,7 @@ | |
| #include "fusilli/node/pointwise_node.h" | ||
| #include "fusilli/support/extras.h" | ||
|
|
||
| #include <bit> // C++20 | ||
| #include <cassert> | ||
| #include <cctype> | ||
| #include <cstddef> | ||
|
|
@@ -161,6 +162,40 @@ inline std::string getPermuteOpsAsm(const std::shared_ptr<TensorAttr> &tensor, | |
| return oss.str(); | ||
| } | ||
|
|
||
| // Emits a scalar TensorAttr as a constant tensor literal in MLIR assembly. | ||
| // The result SSA name is the tensor's value name (e.g. %alpha). | ||
| inline std::string | ||
| getScalarConstantAsm(const std::shared_ptr<TensorAttr> &tensor) { | ||
| assert(tensor->isScalar() && tensor->getScalarValue().has_value() && | ||
| "getScalarConstantAsm called with non-scalar tensor"); | ||
| std::string resultName = tensor->getValueNameAsm(); | ||
| std::string mlirType = kDataTypeToMlirTypeAsm.at(tensor->getDataType()); | ||
| std::string resultType = tensor->getTensorTypeAsm(/*isValueTensor=*/true, | ||
| /*useLogicalDims=*/true); | ||
| // std::visit generates a compile time switch statement executing lambda | ||
| // instantiation per variant alternative. | ||
| // Example output: | ||
| // float 1.0f → dense<0x3F800000> : tensor<1xf32> | ||
| // double 1.0 → dense<0x3FF0000000000000> : tensor<1xf64> | ||
| // int32 42 → dense<0x0000002A> : tensor<1xi32> | ||
| // int64 42L → dense<0x000000000000002A> : tensor<1xi64> | ||
| return std::visit( | ||
| [&](auto val) -> std::string { | ||
| using UInt = std::conditional_t<sizeof(val) == 4, uint32_t, uint64_t>; | ||
| // {:0{}X} -> hex format with runtime width: | ||
| // - 0 pad with zeros | ||
| // - {} runtime width (sizeof(val)*2: 4 bytes→8, 8 bytes→16) | ||
| // - X uppercase hex | ||
| return std::format( | ||
| "\n{} = torch.vtensor.literal(dense<0x{:0{}X}> : tensor<1x{}>) " | ||
| ": {}\n", | ||
| resultName, std::bit_cast<UInt>(val), sizeof(val) * 2, mlirType, | ||
| resultType); | ||
| }, | ||
| tensor->getScalarValue() | ||
| .value()); // std::variant<int64_t, int32_t, float, double> | ||
|
AaronStGeorge marked this conversation as resolved.
|
||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // TensorAttr ASM Emitter Methods | ||
|
|
@@ -169,11 +204,9 @@ inline std::string getPermuteOpsAsm(const std::shared_ptr<TensorAttr> &tensor, | |
|
|
||
| // Emits a ranked tensor type in MLIR assembly representation. | ||
| // | ||
| // This expects ranked tensors (non-scalar) as we blanket generate a | ||
| // `!torch.vtensor` (or `!torch.tensor` if mutable) type. The caller | ||
| // is responsible to check for this. In the future we may want to extend | ||
| // this (or add new methods) for scalar types (such as `!torch.int` or | ||
| // `!torch.bool`). | ||
| // This expects ranked tensors as we blanket generate a `!torch.vtensor` (or | ||
| // `!torch.tensor` if mutable) type. The caller is responsible to check for | ||
| // this. | ||
| // | ||
| // Example: | ||
| // | ||
|
|
@@ -198,9 +231,15 @@ inline std::string getPermuteOpsAsm(const std::shared_ptr<TensorAttr> &tensor, | |
| // t.getTensorTypeAsm(/*isValueTensor=*/false, | ||
| // /*useLogicalDims=*/false) | ||
| // --> "!torch.tensor<[2,4,3],f32>" | ||
| // | ||
| // Scalars (dim={1}, stride={1}) also work through this path: | ||
| // | ||
| // TensorAttr s(2.0f); | ||
| // s.getTensorTypeAsm(/*isValueTensor=*/true, | ||
| // /*useLogicalDims=*/true) | ||
| // --> "!torch.vtensor<[1],f32>" | ||
| inline std::string TensorAttr::getTensorTypeAsm(bool isValueTensor, | ||
| bool useLogicalDims) const { | ||
| assert(!isScalar() && "TensorAttr::getTensorTypeAsm expects a ranked tensor"); | ||
| assert(!getDim().empty() && | ||
| "TensorAttr::getTensorTypeAsm expects non-empty dims"); | ||
| assert(!getStride().empty() && | ||
|
|
@@ -326,6 +365,13 @@ module @module {{ | |
| getOperandNamesAndTypesAsm() // {1} | ||
| ); | ||
|
|
||
| // Emit scalar constants (`torch.vtensor.literal`) for all scalar graph inputs | ||
| // at the top of the function body. | ||
| for (const auto &input : fullGraphInputsSorted_) { | ||
| if (input->isScalar()) | ||
| output += getScalarConstantAsm(input); | ||
| } | ||
|
|
||
| return output; | ||
| } | ||
|
|
||
|
|
@@ -971,12 +1017,19 @@ inline std::string LayerNormNode::getNormalizedShapeOpsAsm() const { | |
| /*suffix=*/layernormAttr.getName()); | ||
| } | ||
|
|
||
| // Get epsilon constant op in MLIR assembly format. | ||
| // Get epsilon extraction op in MLIR assembly format. The scalar constant | ||
| // `torch.vtensor.literal` is emitted once at graph level | ||
| // (Graph::emitNodePreAsm). Here we extract the float value with | ||
| // `torch.aten.item` for use with `torch.aten.layer_norm` which expects | ||
| // `!torch.float`. | ||
| inline std::string LayerNormNode::getEpsilonOpsAsm() const { | ||
| float eps = | ||
| std::get<float>(layernormAttr.getEpsilon()->getScalarValue().value()); | ||
| return std::format("%eps_{} = torch.constant.float {:e}", | ||
| layernormAttr.getName(), eps); | ||
| std::string suffix = layernormAttr.getName(); | ||
| auto eps = layernormAttr.getEpsilon(); | ||
| std::string tensorName = eps->getValueNameAsm(); | ||
| std::string tensorType = | ||
| eps->getTensorTypeAsm(/*isValueTensor=*/true, /*useLogicalDims=*/true); | ||
| return std::format(" %eps_{} = torch.aten.item {} : {} -> !torch.float\n", | ||
| suffix, tensorName, tensorType); | ||
|
Comment on lines
+1031
to
+1032
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the whitespace before It introduces an unnecessary indent in the generated MLIR: module @module {
func.func @main(%result_: !torch.tensor<[16,128,64,32],f32>, %arg0_x: !torch.vtensor<[16,128,64,32],f32>) attributes {torch.assume_strict_symbolic_shapes} {
%layernorm_infer_EPSILON = torch.vtensor.literal(dense<0x3727C5AC> : tensor<1xf32>) : !torch.vtensor<[1],f32>
%normalized_shape_val_0_layernorm_infer = torch.constant.int 128
%normalized_shape_val_1_layernorm_infer = torch.constant.int 64
%normalized_shape_val_2_layernorm_infer = torch.constant.int 32
%normalized_shape_layernorm_infer = torch.prim.ListConstruct %normalized_shape_val_0_layernorm_infer, %normalized_shape_val_1_layernorm_infer, %normalized_shape_val_2_layernorm_infer : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%eps_layernorm_infer = torch.aten.item %layernorm_infer_EPSILON : !torch.vtensor<[1],f32> -> !torch.float
%permute_x_val_0_layernorm_infer = torch.constant.int 0
%permute_x_val_1_layernorm_infer = torch.constant.int 1
%permute_x_val_2_layernorm_infer = torch.constant.int 2
%permute_x_val_3_layernorm_infer = torch.constant.int 3
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's one of the reasons we use the raw multi-line strings with a schema, so the spacing is precise. |
||
| } | ||
|
|
||
| // This gets called by the recursive `emitAsmSubtree()` method to emit | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same: prefer to rewrite this using raw multi-line strings
R"(...)"to prevent indendation issues in the generated ASM: