diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 822ababb149..cee1fa911e2 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -116,6 +116,7 @@ class Val; f(ScanOp); \ f(Merge); \ f(Partition); \ + f(Combine); \ f(Swizzle); \ f(Swizzle2D); \ f(Resize); \ diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 698a047db6f..db995ec6a93 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1054,6 +1054,106 @@ std::pair RaggedIterDomain::partition( return {component_id, ragged_id}; } +IterDomain* RaggedIterDomain::combine( + IterDomain* component, + RaggedIterDomain* ragged) { + NVF_ERROR(component != nullptr, "combine: component IterDomain is null"); + NVF_ERROR(ragged != nullptr, "combine: ragged IterDomain is null"); + + NVF_ERROR( + !component->isA(), + "combine: component must be a regular IterDomain, got RaggedIterDomain: ", + component->toString()); + + // Validate that component and ragged have compatible properties + NVF_ERROR_EQ( + component->getParallelType(), + ParallelType::Serial, + "Combining parallelized IterDomain not supported: ", + component->toString()); + + NVF_ERROR_EQ( + ragged->getParallelType(), + ParallelType::Serial, + "Combining parallelized RaggedIterDomain not supported: ", + ragged->toString()); + + NVF_ERROR_EQ( + component->getIterType(), + IterType::Iteration, + "combine: only IterType::Iteration is supported for component, got ", + component->getIterType(), + " for IterDomain: ", + component->toString()); + + NVF_ERROR_EQ( + ragged->getIterType(), + IterType::Iteration, + "combine: only IterType::Iteration is supported for ragged, got ", + ragged->getIterType(), + " for RaggedIterDomain: ", + ragged->toString()); + + // Validate component-ragged pairing when Partition definition is available + // (Option 3 of doc/dev/ragged_iter_domain_combine_design_doc.md). + // Only validate when the RaggedIterDomain has a direct Partition definition. 
+ // After propagation (e.g., set() operations), the definition may be nullptr, + // in which case we trust the user to provide the correct component. + if (ragged->definition() != nullptr && + ragged->definition()->isA()) { + auto* partition = ragged->definition()->as(); + IterDomain* expected_component = partition->component(); + + NVF_ERROR( + component == expected_component, + "combine: component mismatch. The provided component does not match ", + "the component from the Partition that created this " + "RaggedIterDomain.\n", + " Provided component: ", + component->toString(), + "\n", + " Expected component: ", + expected_component->toString()); + } + // If no Partition definition (after set, in segmented fusion, or external + // input), trust the user and proceed without validation + + // The combined extent is the sum of all extents in the ragged dimension + // For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents) + TensorView* extents_tv = ragged->extents(); + NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null"); + + // It is still assumed the extents tensor is just 1D + NVF_ERROR_EQ( + std::ssize(extents_tv->getLogicalDomain()), + 1, + "Unexpected rank of extent tensor: ", + extents_tv->toString()); + + auto container = component->container(); + auto zero = container->zeroVal(DataType::Index); + + // Create a symbolic extent for the combined IterDomain + // This represents the sum of all ragged extents, i.e., + // sum(extents_tv, {0}). We could use the sum output as the extent + // but we would need to extract the scalar value out of the 0-dim + // tensor. For now, we leave it as a symbolic Val. 
+ Val* combined_extent = + IrBuilder::createInContainer(container, DataType::Index); + + // Create the combined IterDomain with the symbolic extent + IterDomain* combined_id = IterDomainBuilder(zero, combined_extent) + .parallel_type(ParallelType::Serial) + .iter_type(IterType::Iteration) + .build(); + + // Create the Combine expression linking component + ragged -> combined + IrBuilder::createInContainer( + container, combined_id, component, ragged); + + return combined_id; +} + TensorDomain::TensorDomain( IrBuilderPasskey passkey, std::vector logical_domain, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 9d40fcfbf1e..2d1872563e5 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -499,6 +499,22 @@ class NVF_API RaggedIterDomain : public IterDomain { IterDomain* in, TensorView* extents); + //! Combine a component IterDomain with a RaggedIterDomain to flatten + //! This is the inverse of partition, creating a regular IterDomain + //! + //! \param component Component IterDomain (extent = num_components) + //! \param ragged RaggedIterDomain with variable extents per component + //! \return Regular IterDomain with extent = sum of all component extents + //! + //! This operation flattens the ragged structure back into a single dimension. + //! Example: component extent=3, ragged extents=[127, 0, 198] + //! -> output extent = 325 (= 127 + 0 + 198) + //! + //! Note: We use "combine" instead of "merge" to differentiate from the + //! regular IterDomain::merge operation which only works with regular + //! IterDomains. + static IterDomain* combine(IterDomain* component, RaggedIterDomain* ragged); + //! 
Override cloneWithoutRFactor to preserve RaggedIterDomain type IterDomain* cloneWithoutRFactor(bool map_with_original = false) override; diff --git a/csrc/ir/internal_nodes.cpp b/csrc/ir/internal_nodes.cpp index 219d1669b48..116eadad676 100644 --- a/csrc/ir/internal_nodes.cpp +++ b/csrc/ir/internal_nodes.cpp @@ -2653,6 +2653,33 @@ std::string Partition::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(Partition) +Combine::Combine( + IrBuilderPasskey passkey, + IterDomain* out, + IterDomain* component, + RaggedIterDomain* ragged) + : Expr(passkey) { + addOutput(out); + addInput(component); + addInput(ragged); +} + +std::string Combine::toString(int indent_size) const { + std::stringstream ss; + ss << "Combine: "; + ss << "component: " << component()->toString(); + ss << " + ragged: " << ragged()->toString(); + ss << " -> " << out()->toString(); + ss << "\n"; + return ss.str(); +} + +std::string Combine::toInlineString(int indent_size) const { + NVF_CHECK(false, "Combine can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Combine) + Swizzle::Swizzle( IrBuilderPasskey passkey, IterDomain* out_x, diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 9393dc3016b..863304ce1be 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -1845,6 +1845,44 @@ class NVF_API Partition : public Expr { } }; +//! Combine a component IterDomain with a RaggedIterDomain to flatten +//! This is the inverse of Partition, merging component and ragged dimensions +//! into a single regular IterDomain +class NVF_API Combine : public Expr { + public: + using Expr::Expr; + + Combine( + IrBuilderPasskey, + IterDomain* out, + IterDomain* component, + RaggedIterDomain* ragged); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "Combine"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + //! 
Output IterDomain (combined/flattened dimension) + IterDomain* out() const { + return output(0)->as(); + } + + //! Component dimension input (extent = num_components) + IterDomain* component() const { + return input(0)->as(); + } + + //! Ragged dimension input (variable extents per component) + RaggedIterDomain* ragged() const { + return input(1)->as(); + } +}; + class Swizzle : public Expr { public: using Expr::Expr; diff --git a/doc/dev/ragged_iter_domain_combine_design_doc.md b/doc/dev/ragged_iter_domain_combine_design_doc.md new file mode 100644 index 00000000000..fc05758e672 --- /dev/null +++ b/doc/dev/ragged_iter_domain_combine_design_doc.md @@ -0,0 +1,354 @@ + + +# Design Document: Component IterDomain Tracking for RaggedIterDomain + +## Problem Statement + +When calling `RaggedIterDomain::combine(component, ragged)`, we need to validate that the `component` IterDomain is the correct one that was originally paired with `ragged` during the Partition operation that created it. + +### The Challenge + +The naive approach of checking `ragged->definition()` for a Partition expression fails because: + +1. **Tensor-level operations break the definition chain**: Operations like `set()` create new TensorViews with new IterDomains +2. **IterDomains are propagated without definitions**: The new IterDomains are clones/descendants but don't have the original Partition as their definition +3. **The pairing information is lost**: After propagation, there's no explicit link between the ragged IterDomain and its paired component + +### Concrete Example + +```cpp +// tv0: [i0] - regular tensor +auto result = asNested(tv0, 0, extents); // Creates Partition expression +// result.ragged: RaggedIterDomain with Partition definition +// result.component: Component IterDomain + +auto tv1 = set(tv0); // Propagates IterDomains +// tv1 has a RaggedIterDomain, but it's a clone without Partition definition + +combine(result.component, tv1->getRaggedDomain()); // How do we validate? 
+``` + +## Design Alternatives + +### Option 1: Store Component Pointer in RaggedIterDomain + +**Approach**: Add a `component_` member variable to `RaggedIterDomain` that points to the paired component IterDomain. + +**Implementation**: +```cpp +class RaggedIterDomain : public IterDomain { + private: + TensorView* extents_ = nullptr; + IterDomain* component_ = nullptr; // NEW: paired component + + public: + IterDomain* component() const { return component_; } + void setComponent(IterDomain* component) { component_ = component; } +}; +``` + +**How It Works**: +1. When `Partition` creates a RaggedIterDomain, it sets the component pointer +2. When IterDomains are cloned (e.g., during `set()`), the component pointer is cloned/mapped too +3. In `combine()`, validate that the provided component matches `ragged->component()` + +**Pros**: +- ✅ Simple and direct solution +- ✅ Component pointer is automatically preserved during cloning +- ✅ Fast O(1) lookup - no graph traversal needed +- ✅ Follows existing pattern (similar to how `extents_` is stored) +- ✅ Self-documenting - makes the pairing explicit in the data structure + +**Cons**: +- ❌ **CRITICAL: Dependency ordering not guaranteed** - Since `component_` is not an input to RaggedIterDomain's definition, IR graph traversal (during lowering, cloning, replay) has no guarantee that the component IterDomain will be visited/cloned before the ragged IterDomain. This can lead to: + - Dangling pointers during cloning (trying to remap component before it's cloned) + - Incorrect mappings in IrCloner when component hasn't been processed yet + - Failures in topological traversal algorithms that expect dependencies to be explicit +- ❌ **Fragile during replacements** - When IterDomains are replaced (e.g., via `replaceAllUsesWith()`), the component pointer in ragged doesn't get updated automatically. Would require special-case handling throughout the codebase to maintain this hidden dependency. 
+- ❌ **Strong implicit coupling** - Creates a dependency that's not reflected in the IR graph structure, making the IR harder to reason about and maintain. Optimization passes and transformations that don't know about this hidden link could break the invariant. +- ❌ Component pointer could become stale if component IterDomain is replaced/transformed + +**Why This Is Problematic**: + +The fundamental issue is that this approach tries to store a relationship that *should* be part of the IR graph structure as an *out-of-band* pointer. nvFuser's IR infrastructure is designed around explicit dependency edges (inputs/outputs of expressions). Adding a pointer that doesn't follow these edges creates a parallel tracking mechanism that must be manually maintained across all IR operations: + +1. **IrCloner** would need to special-case the component pointer remapping, but it can't guarantee ordering +2. **replaceAllUsesWith()** and similar operations wouldn't know to update the component pointer +3. **Replay** operations that transform IterDomains wouldn't propagate the component link correctly +4. **Serialization/deserialization** would need special handling for this out-of-band pointer + +--- + +### Option 2: Traverse IR Graph to Find Original Partition + +**Approach**: Walk backward through the IterDomain definition chain to find the original Partition expression. + +**Implementation**: +```cpp +IterDomain* findOriginalComponent(RaggedIterDomain* ragged) { + // Traverse backward through set operations, clones, etc. + auto* current = ragged; + while (current != nullptr) { + if (current->definition() && current->definition()->isA<Partition>()) { + return current->definition()->as<Partition>()->component(); + } + // Follow the chain backward (e.g., through set operations) + current = getSourceIterDomain(current); + } + return nullptr; // No Partition found +} +``` + +**How It Works**: +1. Start from the given RaggedIterDomain +2. 
Traverse backward through the IR graph following definition chains +3. Find the original Partition expression +4. Extract the component from that Partition + +**Pros**: +- ✅ No additional memory overhead +- ✅ No new state to maintain +- ✅ Always finds the "true" original component by traversing the IR +- ✅ Component pointer can't become stale (computed on demand) + +**Cons**: +- ❌ **CRITICAL: Fusion segmentation breaks traversal** - When a fusion is segmented (split into multiple kernels), each segment contains only a subset of the full IR graph. A segment may contain a RaggedIterDomain that needs to be combined, but the original Partition expression that created it may be in a different segment. Traversal cannot cross segment boundaries, making it impossible to find the original component. +- ❌ **CRITICAL: External ragged tensors have no Partition** - When RaggedIterDomain support is extended in the future to accept ragged tensors from PyTorch as fusion inputs, these would arrive as RaggedIterDomains without any Partition expression in the nvFuser IR. There would be nothing to traverse back to, yet we still need to know the component for validation. +- ❌ **Unreliable chain traversal** - Even when a Partition exists in the same segment, the definition chain can be broken or complex: + - Operations like `set()` intentionally break the definition chain + - Multiple paths back through different transformations + - Split/merge operations on the path complicate tracking +- ❌ Requires IR graph traversal - O(n) where n is chain depth +- ❌ Complex implementation - need to handle all propagation patterns (set, replay, clone, etc.) +- ❌ Performance cost on every combine() call + +**Why This Is Problematic**: + +This approach assumes the Partition expression is always reachable, but there are fundamental scenarios where it isn't: + +1. **Segmented Fusions**: nvFuser segments complex fusions into multiple kernels. Each segment is scheduled and lowered independently. 
A RaggedIterDomain in segment N may have been created by a Partition in segment M, but segment boundaries are opaque - you can't traverse across them. + +2. **Future External Inputs**: When RaggedIterDomain support is extended to accept ragged tensors from PyTorch as fusion inputs, these RaggedIterDomains will have no corresponding nvFuser Partition expression. They represent already-partitioned data from outside nvFuser. + +3. **Definition Chain Breaks**: Even within a segment, operations like `set()` intentionally create new IterDomains without definitions, breaking the chain. + +The fundamental flaw is assuming component information can be recovered from the IR graph structure, when in reality the information may not exist in the graph at all. + +--- + +### Option 3: Track Component in Partition Expression Only + +**Approach**: Only validate when a direct Partition definition exists, otherwise trust the user. + +**Implementation**: +```cpp +void combine(IterDomain* component, RaggedIterDomain* ragged) { + // Only validate if we can find a Partition + if (ragged->definition() && ragged->definition()->isA<Partition>()) { + auto* partition = ragged->definition()->as<Partition>(); + NVF_ERROR(component == partition->component(), + "Component doesn't match partition"); + } + // Otherwise, no validation - trust the user + + // Proceed with combine... +} +``` + +**How It Works**: +1. Check if ragged has a Partition definition +2. If yes, validate the component +3. 
If no, skip validation and trust that the user provided the correct component + +**Pros**: +- ✅ Minimal implementation - no new infrastructure +- ✅ No memory overhead +- ✅ Simple to understand +- ✅ Validation when possible, permissive when not + +**Cons**: +- ⚠️ Validation is incomplete - only validates when Partition definition is directly available +- ⚠️ After propagation operations (set, segmentation), relies on user correctness + +**Why This Is Actually Reasonable**: + +This approach aligns with how nvFuser handles other operations: +- **Arithmetic operations** (add, mul, etc.) assume inputs have matching shapes - they don't validate +- **User responsibility**: If users call `combine(component, ragged)`, we trust they're providing the correct component +- **Validation where possible**: When we CAN validate (Partition definition exists), we do +- **Fail-fast when detectable**: Catches obvious errors early in the fusion definition +- **Pragmatic**: Acknowledges that complete validation isn't feasible given segmentation and external inputs + +The key insight is that `combine()` is a user-facing API. Users are expected to know which component pairs with which ragged domain, just as they're expected to know when tensor shapes are compatible for arithmetic operations. + +--- + +### Option 4: Store Component Pairing in TensorDomain + +**Approach**: Store component-ragged pairings in TensorDomain rather than in RaggedIterDomain itself. + +**Implementation**: +```cpp +// In TensorDomain +class TensorDomain { + private: + std::vector<IterDomain*> logical_domain_; + // Other domain vectors... 
+ + // NEW: Track ragged-component pairings for IterDomains in this TensorDomain + struct RaggedComponentPair { + RaggedIterDomain* ragged; + IterDomain* component; + }; + std::vector<RaggedComponentPair> ragged_component_pairs_; + + public: + // Get the component for a ragged IterDomain in this TensorDomain + IterDomain* getComponentFor(RaggedIterDomain* ragged) const; + + // Register a ragged-component pairing (called when creating from Partition) + void registerRaggedComponentPair(RaggedIterDomain* ragged, IterDomain* component); +}; +``` + +**How It Works**: +1. When Partition creates a TensorView with ragged and component IterDomains, register the pairing in the TensorDomain +2. The pairing is stored alongside the IterDomains themselves, ensuring both ragged and component are in `allIds()` +3. When tensor operations (like `set()`) propagate TensorDomains, they also propagate the pairing information +4. In `combine()`, look up the component from the TensorView's TensorDomain + +**Pros**: +- ✅ **Looser coupling**: The relationship is stored in TensorDomain, not in RaggedIterDomain itself +- ✅ **Follows containment**: TensorDomain already owns and manages its IterDomains, so it's natural to manage their relationships +- ✅ **Explicit in domain operations**: Operations that propagate TensorDomain can explicitly propagate pairings +- ✅ **Validates across propagation**: Works even after `set()` if the pairing is propagated correctly +- ✅ **Both IDs guaranteed present**: Since both must be in `allIds()`, dependency ordering is less problematic + +**Cons**: +- ❌ **Propagation must be explicit**: Every operation that creates/clones TensorDomain must handle pairing propagation +- ❌ **More complex than Option 3**: Requires changes to TensorDomain and all operations that manipulate it +- ❌ **Still has propagation challenges**: Operations like replay, resize, or transformations need to update pairings +- ❌ **Segmentation issues remain**: After fusion segmentation, TensorDomain in one segment 
may not have the original pairing information + +**Key Challenge**: + +The main implementation challenge is ensuring pairing propagation through all tensor operations: +- `set()`: Must copy pairings from input TensorDomain to output +- `view/reshape`: Must map pairings through transformations +- Replay operations: Must track how ragged and component are transformed +- Cloning: Must clone pairings along with IterDomains + +**Why This Is Better Than Option 1**: + +Unlike storing the pointer in RaggedIterDomain: +- TensorDomain already manages relationships between IterDomains (root→logical→allocation mappings) +- Both ragged and component are explicitly part of the domain, reducing implicit dependencies +- The coupling is at the TensorDomain level, not at the individual IterDomain level + +**Why This May Not Be Worth It**: + +While architecturally cleaner than Option 1, it's still significantly more complex than Option 3: +- Requires modifying TensorDomain and many tensor operations +- Still doesn't solve segmentation (segments may not preserve original TensorDomain) +- Adds complexity for validation that may not be critical (users can track pairings) + +If Option 3's "trust the user" approach is sufficient, Option 4's additional complexity may not be justified. + +--- + +## Analysis Summary + +### Why Options 1 & 2 Are Not Viable + +**Option 1 (Stored Pointer)**: Fundamentally flawed due to dependency ordering. The component pointer would be an out-of-band dependency not reflected in the IR graph. IR traversal algorithms follow explicit input/output edges, with no guarantee that component will be processed before ragged during cloning/lowering/replay. Violates nvFuser's design principle of explicit dependency edges. + +**Option 2 (IR Traversal)**: Fails in two critical scenarios: +1. **Fusion Segmentation**: Partition expression may be in a different segment, unreachable via traversal +2. 
**Future External Inputs**: When RaggedIterDomain support is extended to accept ragged tensors from PyTorch as fusion inputs, these will have no nvFuser Partition expression to traverse to + +These aren't edge cases - they're fundamental use cases that must be supported. + +### Why Option 3 Is The Pragmatic Choice + +**Option 3** aligns with nvFuser's design philosophy: like arithmetic operations that assume shape compatibility, `combine()` trusts users to provide correct inputs. It validates when Partition definition exists but otherwise relies on user correctness. Simple to implement, handles all use cases (propagation, segmentation, external inputs), and acknowledges that complete validation is architecturally infeasible. + +**Option 4 (TensorDomain Pairing)** is architecturally cleaner than Option 1 (looser coupling) but requires extensive changes to TensorDomain operations and still has segmentation issues. Could be a future enhancement if user errors become problematic, but Option 3's simplicity is preferred for now. + +## Recommendation + +### Proposed Solution: **Option 3 - Validate When Partition Definition Exists** + +**This is the current design choice.** We will reconsider Option 4 (TensorDomain Pairing) if it proves more appropriate based on practical experience or future requirements. + +**Rationale**: + +Option 3 is the most reasonable approach because it: + +1. **Aligns with nvFuser's design philosophy**: Like arithmetic operations that assume shape compatibility, `combine()` trusts users to provide correct inputs +2. **Provides validation where feasible**: When a Partition definition is directly accessible, we validate the component +3. **Simple and maintainable**: No complex infrastructure, no global state, no dependency ordering issues +4. **Handles all use cases**: Works for direct Partition usage, propagated domains, segmented fusions, and future external inputs +5. 
**Pragmatic**: Acknowledges that complete validation is architecturally infeasible + +**Implementation**: + +```cpp +void combine(IterDomain* component, RaggedIterDomain* ragged) { + // Basic validation (null checks, type checks, etc.) + NVF_ERROR(component != nullptr && ragged != nullptr, "Null inputs"); + NVF_ERROR(!component->isRaggedDomain(), "Component must be regular IterDomain"); + + // Validate against Partition definition if available + if (ragged->definition() && ragged->definition()->isA<Partition>()) { + auto* partition = ragged->definition()->as<Partition>(); + NVF_ERROR( + component == partition->component(), + "Component mismatch: provided ", component->toString(), + " but Partition expects ", partition->component()->toString()); + } + + // If no Partition definition (after set, in segmented fusion, or external input), + // trust the user and proceed + + // Create combined IterDomain... +} +``` + +**What This Means**: + +- ✅ Early error detection when Partition definition is available +- ✅ No architectural violations or fragile infrastructure +- ✅ Users are responsible for correct usage (like other operations) +- ✅ Works across all scenarios (propagation, segmentation, external inputs) +- ⚠️ After propagation/segmentation, incorrect usage won't be caught by validation +- ⚠️ Users must track component-ragged pairings themselves + +**Comparison to Other Operations**: + +This is consistent with how nvFuser handles other operations: +- `add(tv1, tv2)` doesn't validate that shapes match - user responsibility +- `set(tv)` doesn't validate all properties - user responsibility +- `combine(component, ragged)` doesn't always validate pairing - user responsibility + +## Implementation Notes + +1. **Testing Strategy**: + - Test validation when Partition definition exists (should catch errors) + - Test that validation is skipped after `set()` operations (should succeed with correct usage) + - Document user responsibility in API documentation + +2. 
**Future Considerations**: + - Option 4 (TensorDomain Pairing) remains a viable alternative if the current approach proves insufficient + - We will reconsider Option 4 based on practical experience, user feedback, or new requirements + - If incorrect `combine()` usage becomes a common source of bugs, we can implement Option 4's more comprehensive validation + - For now, follow the principle of trusting user-facing APIs + - The `extents_` pointer handling may also need similar considerations in the future + +3. **Documentation**: + - Clearly document that users must provide the correct component that was paired with the ragged domain + - Note that validation is best-effort and may not catch all errors + - Provide examples of correct usage patterns diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index bd3c88fe748..f8142bf4c9c 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -340,6 +340,101 @@ TEST_F(RaggedIterDomainTest, TensorViewPartition) { EXPECT_EQ(tv0->axis(0)->definition(), tv0->axis(1)->definition()); } +// Test combining component and ragged IterDomains (inverse of partition) +TEST_F(RaggedIterDomainTest, CombineBasic) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create extents tensor + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create a regular IterDomain to partition + auto orig_id = IterDomainBuilder( + fusion.zeroVal(DataType::Index), + IrBuilder::create(325L, DataType::Index)) + .build(); + + // Partition into component and ragged + auto [component_id, ragged_id] = + RaggedIterDomain::partition(orig_id, extents); + + // Verify partition worked + EXPECT_NE(component_id, nullptr); + EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(component_id->isA()); + EXPECT_TRUE(ragged_id->isA()); + + // Now combine them back + auto combined_id = RaggedIterDomain::combine(component_id, ragged_id); + + // Verify combine worked + 
EXPECT_NE(combined_id, nullptr); + EXPECT_TRUE(combined_id->isA()); + EXPECT_FALSE(combined_id->isA()); + + // Verify the combine has a definition (Combine expr) + EXPECT_NE(combined_id->definition(), nullptr); + EXPECT_TRUE(combined_id->definition()->isA()); + + // Verify the Combine expression has correct inputs + auto combine_expr = combined_id->definition()->as(); + EXPECT_EQ(combine_expr->component(), component_id); + EXPECT_EQ(combine_expr->ragged(), ragged_id); + EXPECT_EQ(combine_expr->out(), combined_id); +} + +// Test combine validation: null component +TEST_F(RaggedIterDomainTest, CombineValidationNullComponent) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + auto ragged_id = + IrBuilder::create(extents, IterType::Iteration); + + // Should fail with null component + EXPECT_THROW( + RaggedIterDomain::combine(nullptr, ragged_id), nvfuser::nvfError); +} + +// Test combine validation: null ragged +TEST_F(RaggedIterDomainTest, CombineValidationNullRagged) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto component_id = IterDomainBuilder( + fusion.zeroVal(DataType::Index), + IrBuilder::create(3L, DataType::Index)) + .build(); + + // Should fail with null ragged + EXPECT_THROW( + RaggedIterDomain::combine(component_id, nullptr), nvfuser::nvfError); +} + +// Test combine validation: component is RaggedIterDomain +TEST_F(RaggedIterDomainTest, CombineValidationComponentIsRagged) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents1 = makeSymbolicTensor(1, DataType::Index); + auto extents2 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents1); + fusion.addInput(extents2); + + auto ragged_id1 = + IrBuilder::create(extents1, IterType::Iteration); + auto ragged_id2 = + IrBuilder::create(extents2, IterType::Iteration); + + // Should fail when component is also RaggedIterDomain + EXPECT_THROW( + RaggedIterDomain::combine(ragged_id1, ragged_id2), 
nvfuser::nvfError); +} + // asNested basic functionality TEST_F(RaggedIterDomainTest, AsNestedBasic) { Fusion fusion; @@ -386,6 +481,147 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition()); } +// Test combining nested tensor back to normal tensor +TEST_F(RaggedIterDomainTest, AsNestedThenCombine) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, extents, 0); + + // Verify nested tensor has 3 dimensions: [component, ragged, original_dim1] + EXPECT_EQ(nested->nDims(), 3); + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); + EXPECT_TRUE(nested->axis(1)->isA()); + + // Get the component and ragged IterDomains + auto component_id = nested->axis(0); + auto ragged_id = nested->axis(1)->as(); + + // Combine them back into a normal IterDomain + auto combined_id = RaggedIterDomain::combine(component_id, ragged_id); + + // Verify the combined IterDomain is a regular IterDomain, not ragged + EXPECT_NE(combined_id, nullptr); + EXPECT_TRUE(combined_id->isStrictlyA()); + EXPECT_FALSE(combined_id->isA()); + + // Verify the combined IterDomain has a Combine definition + EXPECT_NE(combined_id->definition(), nullptr); + EXPECT_TRUE(combined_id->definition()->isA()); + + // Verify the Combine expression has correct inputs + auto combine_expr = combined_id->definition()->as(); + EXPECT_EQ(combine_expr->component(), component_id); + EXPECT_EQ(combine_expr->ragged(), ragged_id); + EXPECT_EQ(combine_expr->out(), combined_id); + + // Verify that the component came from the same Partition as the ragged + EXPECT_NE(component_id->definition(), nullptr); + EXPECT_TRUE(component_id->definition()->isA()); + EXPECT_EQ(component_id->definition(), ragged_id->definition()); + + auto partition_expr = 
component_id->definition()->as(); + EXPECT_EQ(partition_expr->component(), component_id); + EXPECT_EQ(partition_expr->ragged(), ragged_id); +} + +// Test combining nested tensor back to normal tensor after set operation +TEST_F(RaggedIterDomainTest, AsNestedThenSetThenCombine) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, extents, 0); + + // Insert a set operation after asNested + auto nested_copy = set(nested); + + // Verify nested_copy tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(nested_copy->nDims(), 3); + EXPECT_TRUE(nested_copy->axis(0)->isStrictlyA()); + EXPECT_TRUE(nested_copy->axis(1)->isA()); + + // Get the component and ragged IterDomains from the copy + auto component_id = nested_copy->axis(0); + auto ragged_id = nested_copy->axis(1)->as(); + + // Combine them back into a normal IterDomain. Even though + // component_id and ragged_id are not directly produced by a + // partition, this should succeed. See the next test + // (AsNestedThenSetThenCombineInvalidComponent) for an undetected mismatch. 
+ auto combined_id = RaggedIterDomain::combine(component_id, ragged_id); + + // Verify the combined IterDomain is a regular IterDomain, not ragged + EXPECT_NE(combined_id, nullptr); + EXPECT_TRUE(combined_id->isStrictlyA()); + EXPECT_FALSE(combined_id->isA()); + + // Verify the combined IterDomain has a Combine definition + EXPECT_NE(combined_id->definition(), nullptr); + EXPECT_TRUE(combined_id->definition()->isA()); + + // Verify the Combine expression has correct inputs + auto combine_expr = combined_id->definition()->as(); + EXPECT_EQ(combine_expr->component(), component_id); + EXPECT_EQ(combine_expr->ragged(), ragged_id); + EXPECT_EQ(combine_expr->out(), combined_id); +} + +// Test combining after a set operation with an invalid component (one that +// does not come from the same partition as the ragged domain). +// With Option 3 validation strategy, this does NOT throw an error +// because after set(), the RaggedIterDomain loses its Partition definition +// and validation is skipped (trusts the user) +TEST_F(RaggedIterDomainTest, AsNestedThenSetThenCombineInvalidComponent) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, extents, 0); + + // Insert a set operation after asNested + auto nested_copy = set(nested); + + // Verify nested_copy tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(nested_copy->nDims(), 3); + EXPECT_TRUE(nested_copy->axis(0)->isStrictlyA()); + EXPECT_TRUE(nested_copy->axis(1)->isA()); + + // Get the ragged IterDomain from the copy + auto ragged_id = nested_copy->axis(1)->as(); + + // Use an INVALID component: the third axis instead of the first + // This is NOT the component from the partition, it's the original second + // dimension + auto invalid_component_id = 
nested_copy->axis(2); + + // With Option 3: After set(), the RaggedIterDomain no longer has a + // Partition definition, so validation is skipped and the operation succeeds. + // The user is responsible for providing the correct component. + EXPECT_NO_THROW(RaggedIterDomain::combine(invalid_component_id, ragged_id)); +} + // asNested on different dimensions TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { Fusion fusion;