
[Fusilli,FusilliPlugin] Scalar value support#161

Open
AaronStGeorge wants to merge 1 commit into iree-org:main from AaronStGeorge:p022-scalar-value-support

Conversation

Contributor

@AaronStGeorge AaronStGeorge commented Feb 16, 2026

Adds support for scalar tensor values. Fusilli and the plugin will now support graphs of the form `alpha * A`, `alpha * A * beta * B`, etc.

@sjain-stanford sjain-stanford linked an issue Feb 17, 2026 that may be closed by this pull request
@AaronStGeorge AaronStGeorge force-pushed the p022-scalar-value-support branch 5 times, most recently from 18c725d to 45a1a2a on February 17, 2026 at 22:58
Adds support for scalar tensor value. Fusilli and the plugin will now
support graphs of the form `alpha * A`, `alpha * A * beta * B`, etc.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@AaronStGeorge AaronStGeorge force-pushed the p022-scalar-value-support branch from 45a1a2a to 025cadc on February 17, 2026 at 23:09
@AaronStGeorge AaronStGeorge marked this pull request as ready for review February 17, 2026 23:26
Comment on lines +354 to +355
dest = fusilli::TensorAttr(
/*scalar=*/src->value_as_Float32Value()->value());
Contributor
This will clear dest's name. I think we need to set the name at the end of this path.
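The pitfall can be reproduced with a minimal stand-in (this `TensorAttr` is a mock, not the real Fusilli class; `setName` is assumed from context):

```cpp
#include <string>

// Mock of fusilli::TensorAttr, just enough to show the name-clearing bug:
// the scalar constructor leaves `name` empty, so plain reassignment wipes it.
struct TensorAttr {
  TensorAttr() = default;
  explicit TensorAttr(float s) : scalarValue(s) {} // name left empty
  TensorAttr &setName(const std::string &n) {
    name = n;
    return *this;
  }
  std::string name;
  float scalarValue = 0.0f;
};

// Save the name before reassigning and restore it at the end of the path.
TensorAttr convertScalar(const TensorAttr &dest, float srcValue) {
  std::string savedName = dest.name;
  TensorAttr out(/*scalar=*/srcValue);
  out.setName(savedName);
  return out;
}
```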

Contributor Author
Good catch! I'll fix that up.

Comment on lines +1266 to +1274
auto emitInput = [&](const std::shared_ptr<TensorAttr> &input,
const std::string &permLabel) -> std::string {
if (!input)
return "";
if (input->isScalar())
return getScalarConstantAsm(input, uniqueSSASuffix);
return getPermuteOpsAsm(input, permLabel, uniqueSSASuffix,
/*isInput=*/true);
};
Contributor
I don't think we should be changing PointwiseNode::emitNodePreAsm. Nodes shouldn't need to care whether their inputs are inlined or not. The graph already handles emitting ASM for tensor arguments, so it seems like a natural extension for it to also emit torch.vtensor.literal for inlined scalars.

We can use this in LayerNormNode::getEpsilonOpsAsm(), too.

Member
This is entirely on me. I think Aaron was initially leaning towards having the graph handle it. At the time, my concerns with that approach were: 1) inlined scalars don't appear as function arguments (they're not tracked in the variant pack), so we might need a special way to query an individual op's inputs to gather the scalars, and 2) it treats node-specific scalars as globals at the graph level (e.g., two layernorm nodes sharing the same eps).

I was wrong about 1; they do appear in the graph inputs (just not in the variant pack).

My concern with 2 is not a hard one.

Contributor Author

@AaronStGeorge AaronStGeorge Feb 18, 2026
I'll look into emitting it at the graph level. This code is the extent of my contributions to the asm emitter, so I don't have strong opinions on the architecture; happy to go with the consensus here.

Member
Yeah, let's go with what Ian suggested.

Comment on lines +178 to +194
return std::visit(
[&](auto val) -> std::string {
using T = decltype(val);
if constexpr (std::is_floating_point_v<T>) { // float, double
return std::format(
"\n{} = torch.vtensor.literal(dense<{:e}> : tensor<1x{}>) : "
"{}\n",
resultName, val, mlirType, resultType);
} else { // int64_t, int32_t
return std::format(
"\n{} = torch.vtensor.literal(dense<{}> : tensor<1x{}>) : "
"{}\n",
resultName, val, mlirType, resultType);
}
},
tensor->getScalarValue()
.value()); // std::variant<int64_t, int32_t, float, double>
Contributor
Maybe emit hex here? Then this can be simplified to a single std::format call (we'd need to supply the bitwidth), and it would avoid potential float-precision issues.



Development

Successfully merging this pull request may close these issues.

[Fusilli-plugin] Add scalar value support

3 participants
