[Fusilli,FusilliPlugin] Scalar value support #161
AaronStGeorge wants to merge 1 commit into iree-org:main
Conversation
Force-pushed from 18c725d to 45a1a2a
Adds support for scalar tensor values. Fusilli and the plugin now support graphs of the form `alpha * A`, `alpha * A * beta * B`, etc. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed from 45a1a2a to 025cadc
```cpp
dest = fusilli::TensorAttr(
    /*scalar=*/src->value_as_Float32Value()->value());
```
This will clear `dest`'s name. I think we need to set the name at the end of this path.
Good catch! I'll fix that up.
```cpp
auto emitInput = [&](const std::shared_ptr<TensorAttr> &input,
                     const std::string &permLabel) -> std::string {
  if (!input)
    return "";
  if (input->isScalar())
    return getScalarConstantAsm(input, uniqueSSASuffix);
  return getPermuteOpsAsm(input, permLabel, uniqueSSASuffix,
                          /*isInput=*/true);
};
```
I don't think we should be changing `PointwiseNode::emitNodePreAsm`. Nodes shouldn't care or need to know whether their inputs are inlined. The graph already handles emitting ASM for tensor arguments, so it seems like a natural extension for it to also emit `torch.vtensor.literal` for inlined scalars.
We can use this in `LayerNormNode::getEpsilonOpsAsm()`, too.
This is entirely on me. I think Aaron was initially leaning towards having the graph handle it. At the time my concerns with that approach were: 1) inlined scalars don't appear as function arguments (they're not tracked in the variant pack), so we might need a special way to query individual ops' inputs to gather the scalars, and 2) treating node-specific scalars as graph-level globals (e.g., two layernorm nodes sharing the same eps).
I was wrong about 1: they do appear in the graph inputs (just not in the variant pack).
My concern with 2 is not a hard one.
I'll look into emitting it at the graph level. This code is the extent of my contributions to the asm emitter, so I don't have strong opinions on the architecture; happy to go with the consensus here.
Yeah let's go with what Ian suggested.
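A rough sketch of the graph-level approach (all names here are hypothetical, not Fusilli's actual API): the graph walks its inputs once and decides per input whether to emit a function argument or an inlined `torch.vtensor.literal`, so nodes like `PointwiseNode` never special-case scalars. This sketch uses `std::ostringstream` rather than `std::format` purely to keep it dependency-light.

```cpp
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical minimal tensor: a scalar input carries an inlined value
// and is not tracked in the variant pack.
struct Tensor {
  std::string name;
  bool scalar = false;
  double value = 0.0;
};

// Graph-level emission: one pass over graph inputs, one decision point.
std::string emitGraphInputsAsm(
    const std::vector<std::shared_ptr<Tensor>> &inputs) {
  std::ostringstream out;
  for (const auto &input : inputs) {
    if (!input)
      continue;
    if (input->scalar) {
      // Inlined scalar: emit a literal instead of a function argument.
      out << "%" << input->name << " = torch.vtensor.literal(dense<"
          << input->value << "> : tensor<1xf32>) : !torch.vtensor<[1],f32>\n";
    } else {
      // Ordinary tensor: tracked in the variant pack as an argument.
      out << "// %" << input->name << " is a function argument\n";
    }
  }
  return out.str();
}
```

With a path like this, `LayerNormNode::getEpsilonOpsAsm()` could reuse the same literal emission instead of rolling its own.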
```cpp
return std::visit(
    [&](auto val) -> std::string {
      using T = decltype(val);
      if constexpr (std::is_floating_point_v<T>) { // float, double
        return std::format(
            "\n{} = torch.vtensor.literal(dense<{:e}> : tensor<1x{}>) : "
            "{}\n",
            resultName, val, mlirType, resultType);
      } else { // int64_t, int32_t
        return std::format(
            "\n{} = torch.vtensor.literal(dense<{}> : tensor<1x{}>) : "
            "{}\n",
            resultName, val, mlirType, resultType);
      }
    },
    tensor->getScalarValue()
        .value()); // std::variant<int64_t, int32_t, float, double>
```
Maybe emit hex here? Then this can be simplified to a single `std::format` call (we'd need to supply the bitwidth), and it would prevent potential float precision issues.