Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class Val;
f(ScanOp); \
f(Merge); \
f(Partition); \
f(Combine); \
f(Swizzle); \
f(Swizzle2D); \
f(Resize); \
Expand Down
100 changes: 100 additions & 0 deletions csrc/ir/internal_base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,106 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
return {component_id, ragged_id};
}

IterDomain* RaggedIterDomain::combine(
IterDomain* component,
RaggedIterDomain* ragged) {
NVF_ERROR(component != nullptr, "combine: component IterDomain is null");
NVF_ERROR(ragged != nullptr, "combine: ragged IterDomain is null");

NVF_ERROR(
!component->isA<RaggedIterDomain>(),
"combine: component must be a regular IterDomain, got RaggedIterDomain: ",
component->toString());

// Validate that component and ragged have compatible properties
NVF_ERROR_EQ(
component->getParallelType(),
ParallelType::Serial,
"Combining parallelized IterDomain not supported: ",
component->toString());

NVF_ERROR_EQ(
ragged->getParallelType(),
ParallelType::Serial,
"Combining parallelized RaggedIterDomain not supported: ",
ragged->toString());

NVF_ERROR_EQ(
component->getIterType(),
IterType::Iteration,
"combine: only IterType::Iteration is supported for component, got ",
component->getIterType(),
" for IterDomain: ",
component->toString());

NVF_ERROR_EQ(
ragged->getIterType(),
IterType::Iteration,
"combine: only IterType::Iteration is supported for ragged, got ",
ragged->getIterType(),
" for RaggedIterDomain: ",
ragged->toString());
Comment on lines +1060 to +1095
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Missing validation that component and ragged are semantically compatible or from the same partition. Consider verifying:

  • That the component extent matches the ragged extents dimension (component extent == extents_tv logical domain size)
  • Or that both came from the same Partition operation (component->definition() == ragged->definition())

Without this, arbitrary unrelated IterDomains could be combined.


// Validate component-ragged pairing when Partition definition is available
// (Option 3 of doc/dev/ragged_iter_domain_combine_design_doc.md).
// Only validate when the RaggedIterDomain has a direct Partition definition.
// After propagation (e.g., set() operations), the definition may be nullptr,
// in which case we trust the user to provide the correct component.
if (ragged->definition() != nullptr &&
ragged->definition()->isA<Partition>()) {
auto* partition = ragged->definition()->as<Partition>();
IterDomain* expected_component = partition->component();

NVF_ERROR(
component == expected_component,
"combine: component mismatch. The provided component does not match ",
"the component from the Partition that created this "
"RaggedIterDomain.\n",
" Provided component: ",
component->toString(),
"\n",
" Expected component: ",
expected_component->toString());
}
// If no Partition definition (after set, in segmented fusion, or external
// input), trust the user and proceed without validation

// The combined extent is the sum of all extents in the ragged dimension
// For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents)
TensorView* extents_tv = ragged->extents();
NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null");

// It is still assumed the extents tensor is just 1D
NVF_ERROR_EQ(
std::ssize(extents_tv->getLogicalDomain()),
1,
"Unexpected rank of extent tensor: ",
extents_tv->toString());

auto container = component->container();
auto zero = container->zeroVal(DataType::Index);

// Create a symbolic extent for the combined IterDomain
// This represents the sum of all ragged extents, i.e.,
// sum(extents_tv, {0}). We could use the sum output as the extent
// but we would need to extract the scalar value out of the 0-dim
// tensor. For now, we leave it as a symbolic Val.
Val* combined_extent =
IrBuilder::createInContainer<Val>(container, DataType::Index);
Comment on lines +1141 to +1142
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The symbolic extent approach leaves the combined extent as an opaque Val without establishing its relationship to the actual sum. During lowering/indexing, the system needs to compute actual indices, but there's no expression connecting combined_extent to sum(extents_tv, {0}).

Check whether this symbolic extent will cause issues during:

  • Index computation when iterating over the combined dimension
  • Extent analysis passes that need to know actual sizes
  • Fusion validation that checks dimension compatibility


// Create the combined IterDomain with the symbolic extent
IterDomain* combined_id = IterDomainBuilder(zero, combined_extent)
.parallel_type(ParallelType::Serial)
.iter_type(IterType::Iteration)
.build();

// Create the Combine expression linking component + ragged -> combined
IrBuilder::createInContainer<Combine>(
container, combined_id, component, ragged);
Comment on lines +1151 to +1152
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Unlike Partition which stores the extents tensor as an attribute (line 2634 in internal_nodes.cpp), Combine doesn't store it anywhere. While the extents can be accessed via ragged->extents(), this creates an inconsistency.

If the ragged IterDomain later loses its connection to the original extents (e.g., through cloning or transformation), the Combine expression won't have a direct reference to them.

Suggested change
IrBuilder::createInContainer<Combine>(
container, combined_id, component, ragged);
// Store extents as an attribute for consistency with Partition
IrBuilder::createInContainer<Combine>(
container, combined_id, component, ragged, ragged->extents());

Note: This would require updating the Combine class signature to accept and store extents as an attribute.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


return combined_id;
}

TensorDomain::TensorDomain(
IrBuilderPasskey passkey,
std::vector<IterDomain*> logical_domain,
Expand Down
16 changes: 16 additions & 0 deletions csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,22 @@ class NVF_API RaggedIterDomain : public IterDomain {
IterDomain* in,
TensorView* extents);

//! Combine a component IterDomain with a RaggedIterDomain to flatten
//! This is the inverse of partition, creating a regular IterDomain
//!
//! \param component Component IterDomain (extent = num_components)
//! \param ragged RaggedIterDomain with variable extents per component
//! \return Regular IterDomain with extent = sum of all component extents
//!
//! This operation flattens the ragged structure back into a single dimension.
//! Example: component extent=3, ragged extents=[127, 0, 198]
//! -> output extent = 325 (= 127 + 0 + 198)
//!
//! Note: We use "combine" instead of "merge" to differentiate from the
//! regular IterDomain::merge operation which only works with regular
//! IterDomains.
static IterDomain* combine(IterDomain* component, RaggedIterDomain* ragged);

//! Override cloneWithoutRFactor to preserve RaggedIterDomain type
IterDomain* cloneWithoutRFactor(bool map_with_original = false) override;

Expand Down
27 changes: 27 additions & 0 deletions csrc/ir/internal_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2653,6 +2653,33 @@ std::string Partition::toInlineString(int indent_size) const {

NVFUSER_DEFINE_CLONE_AND_CREATE(Partition)

Combine::Combine(
IrBuilderPasskey passkey,
IterDomain* out,
IterDomain* component,
RaggedIterDomain* ragged)
: Expr(passkey) {
addOutput(out);
addInput(component);
addInput(ragged);
}

std::string Combine::toString(int indent_size) const {
std::stringstream ss;
ss << "Combine: ";
ss << "component: " << component()->toString();
ss << " + ragged: " << ragged()->toString();
ss << " -> " << out()->toString();
ss << "\n";
return ss.str();
}

std::string Combine::toInlineString(int indent_size) const {
NVF_CHECK(false, "Combine can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Combine)

Swizzle::Swizzle(
IrBuilderPasskey passkey,
IterDomain* out_x,
Expand Down
38 changes: 38 additions & 0 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,44 @@ class NVF_API Partition : public Expr {
}
};

//! Combine a component IterDomain with a RaggedIterDomain to flatten
//! This is the inverse of Partition, merging component and ragged dimensions
//! into a single regular IterDomain
class NVF_API Combine : public Expr {
public:
using Expr::Expr;

Combine(
IrBuilderPasskey,
IterDomain* out,
IterDomain* component,
RaggedIterDomain* ragged);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return "Combine";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

//! Output IterDomain (combined/flattened dimension)
IterDomain* out() const {
return output(0)->as<IterDomain>();
}

//! Component dimension input (extent = num_components)
IterDomain* component() const {
return input(0)->as<IterDomain>();
}

//! Ragged dimension input (variable extents per component)
RaggedIterDomain* ragged() const {
return input(1)->as<RaggedIterDomain>();
}
};

class Swizzle : public Expr {
public:
using Expr::Expr;
Expand Down
Loading