From d87e6d7b36784ff4a5135ea082c8105ff8eb2b8e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 10:30:14 -0800 Subject: [PATCH 01/31] Initial introduction of RaggedIterDomain --- CMakeLists.txt | 1 + csrc/dispatch.h | 1 + csrc/ir/graphviz.cpp | 16 +++ csrc/ir/graphviz.h | 1 + csrc/ir/internal_base_nodes.cpp | 78 ++++++++++++ csrc/ir/internal_base_nodes.h | 47 ++++++++ csrc/mutator.cpp | 25 ++++ csrc/type.cpp | 2 + csrc/type.h | 1 + tests/cpp/test_ragged_iter_domain.cpp | 166 ++++++++++++++++++++++++++ 10 files changed, 338 insertions(+) create mode 100644 tests/cpp/test_ragged_iter_domain.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5c40ab5a615..8a5170b2e8c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -991,6 +991,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_polymorphic_value.cpp ${NVFUSER_ROOT}/tests/cpp/test_predicate_elimination.cpp ${NVFUSER_ROOT}/tests/cpp/test_preseg_passes.cpp + ${NVFUSER_ROOT}/tests/cpp/test_ragged_iter_domain.cpp ${NVFUSER_ROOT}/tests/cpp/test_reduction.cpp ${NVFUSER_ROOT}/tests/cpp/test_reduction_pointwise.cpp ${NVFUSER_ROOT}/tests/cpp/test_remove_bcast_squeeze.cpp diff --git a/csrc/dispatch.h b/csrc/dispatch.h index c2f235f8aab..f5d0c6a10f9 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -62,6 +62,7 @@ class Val; #define DISPATCH_FOR_ALL_VALS(f) \ f(IterDomain); \ + f(RaggedIterDomain); \ f(TensorDomain); \ f(TensorView); \ f(NamedScalar); diff --git a/csrc/ir/graphviz.cpp b/csrc/ir/graphviz.cpp index 7cbd23f7dd3..306abe32784 100644 --- a/csrc/ir/graphviz.cpp +++ b/csrc/ir/graphviz.cpp @@ -68,6 +68,14 @@ class IrNodeLabel final : private OptInConstDispatch { label_ << ")"; } + void handle(const RaggedIterDomain* id) override { + label_ << "Ragged" << id->getIterType(); + label_ << id->getParallelType(); + label_ << "(extents="; + label_ << IrNodeLabel::gen(id->extents()); + label_ << ")"; + } + void handle(const Split* split) override { label_ << "Split(inner=" << (split->innerSplit() ? 
"true" : "false") << ", factor=" << IrNodeLabel::gen(split->factor()) << ")"; @@ -356,6 +364,14 @@ void IrGraphGenerator::handle(const IterDomain* id) { addArc(id->extent(), id, "[color=gray]"); } +void IrGraphGenerator::handle(const RaggedIterDomain* id) { + graph_def_ << " " << getid(id) << " [label=\"" << IrNodeLabel::gen(id) + << "\", shape=cds, color=orange, fontsize=10];\n"; + + // Add arc from extents tensor to the ragged dimension + addArc(id->extents(), id, "[color=orange]"); +} + void IrGraphGenerator::handle(const Val* s) { printValue(s, IrNodeLabel::gen(s, detail_level_)); } diff --git a/csrc/ir/graphviz.h b/csrc/ir/graphviz.h index 49c0991044d..788f533b608 100644 --- a/csrc/ir/graphviz.h +++ b/csrc/ir/graphviz.h @@ -80,6 +80,7 @@ class IrGraphGenerator : private OptInConstDispatch { void handle(const TensorDomain*) override; void handle(const TensorView*) override; void handle(const IterDomain*) override; + void handle(const RaggedIterDomain*) override; void handle(const Val*) override; void handle(const NamedScalar*) override; diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index bb7cdba891c..e88a6d24ded 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -787,6 +787,84 @@ void validateLoopDomain( } // namespace +RaggedIterDomain::RaggedIterDomain( + IrBuilderPasskey passkey, + TensorView* extents, + IterType iter_type, + ParallelType parallel_type) + : IterDomain( + passkey, + /*start=*/passkey.ir_container_->zeroVal(), + /*extent=*/passkey.ir_container_->oneVal(), // Placeholder + /*expanded_extent=*/nullptr, + /*stop_offset=*/nullptr, + parallel_type, + iter_type, + /*is_rfactor_domain=*/false, + /*is_padded_dimension=*/false, + /*is_clustered_blocks=*/false, + /*padded_to_size=*/std::nullopt), + extents_(extents) { + // Extents must be non-null + NVF_ERROR( + extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor"); + + // Extents must have integer dtype + NVF_ERROR_EQ( + extents_->dtype(), + DataType::Index, + "RaggedIterDomain extents must have index type, got ", + extents_->dtype()); + + // Only IterType::Iteration is supported at this moment + NVF_ERROR_EQ( + iter_type, + IterType::Iteration, + "Only IterType::Iteration is supported: ", + iter_type); +} + +RaggedIterDomain::RaggedIterDomain( + const RaggedIterDomain* src, + IrCloner* ir_cloner) + : IterDomain(src, ir_cloner), extents_(ir_cloner->clone(src->extents_)) {} + +NVFUSER_DEFINE_CLONE(RaggedIterDomain) + +bool RaggedIterDomain::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + + if (!other->isA()) { + return false; + } + + auto other_ragged = other->as(); + + // Compare parent IterDomain properties + if (!IterDomain::sameAs(other)) { + return false; + } + + // Compare extents tensor + return extents_->sameAs(other_ragged->extents_); +} + +std::string RaggedIterDomain::toString(int indent_size) const { + std::stringstream ss; + ss << "iRagged{"; + ss << "extents=" << extents_->toString(); + ss << ", iter_type=" << getIterType(); + ss << ", parallel_type=" << getParallelType(); + ss << "}"; + return ss.str(); +} + +std::string RaggedIterDomain::toInlineString(int indent_size) const { + return toString(indent_size); +} + TensorDomain::TensorDomain( IrBuilderPasskey passkey, std::vector logical_domain, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index a41466a34a7..a39bac1d00c 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -25,6 +25,8 
@@ namespace nvfuser { // Friends for direct access to split class TensorDomain; class IterDomain; +class RaggedIterDomain; +class TensorView; class ReplayTransformations; class IndexReferenceReplay; class ViewTransform; @@ -418,6 +420,51 @@ class NVF_API IterDomain : public Val { std::optional padded_to_size_ = std::nullopt; }; +//! RaggedIterDomain represents a dimension with variable extents +//! (ragged/jagged dimension). Used for PyTorch nested tensors. +//! Unlike IterDomain, the extent varies per component +//! and is stored as a TensorView rather than a single Val. +//! +//! Key properties: +//! - extents_: TensorView containing extent for each component (1D, 2D, or N-D) +//! - Uniform execution properties: ParallelType, IterType apply to all +//! components +class NVF_API RaggedIterDomain : public IterDomain { + public: + //! \param extents TensorView containing component extents (must be integer + //! type) + //! \param iter_type Iteration type (Iteration, Reduction, etc.) + //! Only Iteration is allowed ATM. + //! \param parallel_type Parallelization strategy (applies + //! uniformly) + RaggedIterDomain( + IrBuilderPasskey passkey, + TensorView* extents, + IterType iter_type = IterType::Iteration, + ParallelType parallel_type = ParallelType::Serial); + + //! Cloning constructor for IR cloning + RaggedIterDomain(const RaggedIterDomain* src, IrCloner* ir_cloner); + + NVFUSER_DECLARE_CLONE + + bool sameAs(const Statement* other) const override; + + std::string toString(int indent_size = 0) const override; + + std::string toInlineString(int indent_size = 0) const override; + + //! Accessor for the extents tensor + TensorView* extents() const { + return extents_; + } + + private: + //! Extent tensor containing all component extents + //! Can be 1D, 2D, or N-D depending on nesting structure + TensorView* extents_ = nullptr; +}; + //! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every //! logical axis in its associated tensor. TensorDomain does not directly hold //! 
the Tensor it is associated with, and in theory could be associated with
diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp
index 5d586e303ac..6106a853ba2 100644
--- a/csrc/mutator.cpp
+++ b/csrc/mutator.cpp
@@ -134,6 +134,31 @@ void OptOutMutator::mutate(IterDomain* id) {
   }
 }
 
+void OptOutMutator::mutate(RaggedIterDomain* id) {
+  // Mutate the extents TensorView
+  auto mutated_extents = maybeMutated(id->extents());
+
+  // Check if anything changed
+  if (mutated_extents->sameAs(id->extents())) {
+    return;
+  }
+
+  // Create a new RaggedIterDomain with mutated extents
+  auto new_id = IrBuilder::createInContainer<RaggedIterDomain>(
+      id->container(),
+      mutated_extents->as<TensorView>(),
+      id->getIterType(),
+      id->getParallelType());
+
+  // Register the mutation
+  registerMutation(id, new_id);
+
+  // Preserve definition if it exists
+  if (Expr* def = id->definition()) {
+    mutateExprOutputsOnly(def);
+  }
+}
+
 void OptOutMutator::mutate(TensorDomain* td) {
   bool mutated = false;
 
diff --git a/csrc/type.cpp b/csrc/type.cpp
index 02ea6a9cd5a..8d8eea0f62b 100644
--- a/csrc/type.cpp
+++ b/csrc/type.cpp
@@ -322,6 +322,8 @@ static const char* val_type2string(ValType t) {
       return "TensorDomain";
     case ValType::IterDomain:
       return "IterDomain";
+    case ValType::RaggedIterDomain:
+      return "RaggedIterDomain";
     case ValType::Others:
       return "Scalar";
     case ValType::NamedScalar:
diff --git a/csrc/type.h b/csrc/type.h
index b011976fe83..9e91909c09d 100644
--- a/csrc/type.h
+++ b/csrc/type.h
@@ -45,6 +45,7 @@ namespace nvfuser {
 enum class ValType {
   TensorDomain,
   IterDomain,
+  RaggedIterDomain,
   TensorView,
   NamedScalar,
   Predicate,
diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp
new file mode 100644
index 00000000000..7c3596152c4
--- /dev/null
+++ b/tests/cpp/test_ragged_iter_domain.cpp
@@ -0,0 +1,166 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include +#include +#include +#include + +namespace nvfuser { + +using RaggedIterDomainTest = NVFuserTest; + +// Basic construction of RaggedIterDomain +TEST_F(RaggedIterDomainTest, Construction) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a TensorView to use as extents + // This represents component sizes [3, 5, 2] + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create RaggedIterDomain + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::Serial); + + // Verify properties + EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(ragged_id->isA()); + EXPECT_EQ(ragged_id->getIterType(), IterType::Iteration); + EXPECT_EQ(ragged_id->getParallelType(), ParallelType::Serial); + EXPECT_EQ(ragged_id->extents(), extents); + EXPECT_FALSE(ragged_id->isRFactorProduct()); + + // Verify extent is not null (it's the sum of extents) + EXPECT_NE(ragged_id->extent(), nullptr); +} + +// RaggedIterDomain with parallelization +TEST_F(RaggedIterDomainTest, Parallelization) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create with TIDx parallelization + auto ragged_parallel = IrBuilder::create( + extents, IterType::Iteration, ParallelType::TIDx); + + EXPECT_EQ(ragged_parallel->getParallelType(), ParallelType::TIDx); + EXPECT_TRUE(ragged_parallel->isThreadDim()); + + // Test that parallelize method works (inherited from IterDomain) + ragged_parallel->parallelize(ParallelType::TIDy); + EXPECT_EQ(ragged_parallel->getParallelType(), ParallelType::TIDy); +} + +// sameAs comparison +TEST_F(RaggedIterDomainTest, SameAsComparison) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents1 = makeSymbolicTensor(1, DataType::Index); + auto extents2 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents1); + fusion.addInput(extents2); + + auto ragged1 = IrBuilder::create( + extents1, IterType::Iteration, ParallelType::Serial); + + auto ragged3 = IrBuilder::create( + extents2, // Different extents + IterType::Iteration, + ParallelType::Serial); + + // Same object + EXPECT_TRUE(ragged1->sameAs(ragged1)); + + // Different extents + EXPECT_FALSE(ragged1->sameAs(ragged3)); + + // RaggedIterDomain vs regular IterDomain + auto regular_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); + EXPECT_FALSE(ragged1->sameAs(regular_id)); +} + +// Printing/toString +TEST_F(RaggedIterDomainTest, Printing) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::TIDx); + + // Print it + std::string str = ragged_id->toString(); + + // Verify output contains expected elements + EXPECT_NE(str.find("iRagged"), std::string::npos); + EXPECT_NE(str.find("extents"), std::string::npos); + + // Also test toInlineString + std::string inline_str = ragged_id->toInlineString(); + EXPECT_FALSE(inline_str.empty()); +} + +// Multi-dimensional extents tensor +TEST_F(RaggedIterDomainTest, MultiDimensionalExtents) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create 2D extents tensor for nested ragged structure + auto extents_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(extents_2d); + + auto ragged_nested = IrBuilder::create( + extents_2d, 
IterType::Iteration, ParallelType::Serial); + + EXPECT_NE(ragged_nested, nullptr); + EXPECT_EQ(ragged_nested->extents(), extents_2d); +} + +// Validation - null extents should fail +TEST_F(RaggedIterDomainTest, ValidationNullExtents) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Attempt to create with null extents should throw + EXPECT_THROW( + IrBuilder::create( + nullptr, // null extents + IterType::Iteration, + ParallelType::Serial), + nvfuser::nvfError); +} + +// Validation - non-integer extents should fail +TEST_F(RaggedIterDomainTest, ValidationNonIntegerExtents) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create float extents (should fail) + auto float_extents = makeSymbolicTensor(1, DataType::Float); + fusion.addInput(float_extents); + + EXPECT_THROW( + IrBuilder::create( + float_extents, IterType::Iteration, ParallelType::Serial), + nvfuser::nvfError); +} + +} // namespace nvfuser From f16fc4d1f92bdcdf12dbb2fe723af12f921f29a5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 10:42:54 -0800 Subject: [PATCH 02/31] cleanup --- csrc/ir/internal_base_nodes.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index e88a6d24ded..b9ff9b02681 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -853,10 +853,11 @@ bool RaggedIterDomain::sameAs(const Statement* other) const { std::string RaggedIterDomain::toString(int indent_size) const { std::stringstream ss; - ss << "iRagged{"; + ss << getIterType(); + ss << getParallelType(); + ss << name(); + ss << "Ragged{"; ss << "extents=" << extents_->toString(); - ss << ", iter_type=" << getIterType(); - ss << ", parallel_type=" << getParallelType(); ss << "}"; return ss.str(); } From 23d55f15df8b041271ab202929e213b950d2e0a3 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:40:11 -0800 Subject: [PATCH 03/31] fix --- tests/cpp/test_ragged_iter_domain.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 7c3596152c4..249d31afed3 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -110,7 +110,7 @@ TEST_F(RaggedIterDomainTest, Printing) { std::string str = ragged_id->toString(); // Verify output contains expected elements - EXPECT_NE(str.find("iRagged"), std::string::npos); + EXPECT_NE(str.find("Ragged"), std::string::npos); EXPECT_NE(str.find("extents"), std::string::npos); // Also test toInlineString From 8392332ab5316fb58add8eae53e7700460e5a7dc Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:48:56 -0800 Subject: [PATCH 04/31] fix --- csrc/ir/internal_base_nodes.cpp | 30 +++++++++++++++++++++++++++++- csrc/ir/internal_base_nodes.h | 17 +++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index b9ff9b02681..abb5db26d2c 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -130,7 +130,34 @@ IterDomain::IterDomain( bool is_padded_dimension, bool is_clustered_blocks, std::optional padded_to_size) - : Val(passkey, ValType::IterDomain), + : IterDomain( + passkey, + ValType::IterDomain, + start, + extent, + expanded_extent, + stop_offset, + parallel_type, + iter_type, + is_rfactor_domain, + is_padded_dimension, + is_clustered_blocks, + padded_to_size) {} + +IterDomain::IterDomain( + 
IrBuilderPasskey passkey, + ValType vtype, + Val* start, + Val* extent, + Val* expanded_extent, + Val* stop_offset, + ParallelType parallel_type, + IterType iter_type, + bool is_rfactor_domain, + bool is_padded_dimension, + bool is_clustered_blocks, + std::optional padded_to_size) + : Val(passkey, vtype), start_(start), extent_(extent), expanded_extent_(expanded_extent), @@ -794,6 +821,7 @@ RaggedIterDomain::RaggedIterDomain( ParallelType parallel_type) : IterDomain( passkey, + ValType::RaggedIterDomain, /*start=*/passkey.ir_container_->zeroVal(), /*extent=*/passkey.ir_container_->oneVal(), // Placeholder /*expanded_extent=*/nullptr, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index a39bac1d00c..fb4fd8651ce 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -391,6 +391,23 @@ class NVF_API IterDomain : public Val { friend TensorDomain; friend ReplayTransformations; friend IndexReferenceReplay; + friend RaggedIterDomain; + + //! Protected constructor for derived classes (e.g., RaggedIterDomain) + //! that need to override the ValType + IterDomain( + IrBuilderPasskey passkey, + ValType vtype, + Val* start, + Val* extent, + Val* expanded_extent, + Val* stop_offset, + ParallelType parallel_type, + IterType iter_type, + bool is_rfactor_domain, + bool is_padded_dimension, + bool is_clustered_blocks, + std::optional padded_to_size); private: //! Valid range is defined as [start:-stop_offset] From 787dfecff93345a0d1959bd83aa764d4e2514f2a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:54:06 -0800 Subject: [PATCH 05/31] unit test --- tests/cpp/test_ragged_iter_domain.cpp | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 249d31afed3..856bbfea2aa 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -163,4 +163,35 @@ TEST_F(RaggedIterDomainTest, ValidationNonIntegerExtents) { nvfuser::nvfError); } +// ValType test - ensure RaggedIterDomain has correct ValType +TEST_F(RaggedIterDomainTest, ValType) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::Serial); + + // Verify ValType is RaggedIterDomain, not IterDomain + EXPECT_EQ(ragged_id->vtype(), ValType::RaggedIterDomain); + EXPECT_NE(ragged_id->vtype(), ValType::IterDomain); + + // Verify getValType also returns the correct type + EXPECT_TRUE(ragged_id->getValType().has_value()); + EXPECT_EQ(ragged_id->getValType().value(), ValType::RaggedIterDomain); + + // Compare with a regular IterDomain + auto regular_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); + EXPECT_EQ(regular_id->vtype(), ValType::IterDomain); + EXPECT_NE(regular_id->vtype(), ValType::RaggedIterDomain); + + // Verify they have different types + EXPECT_NE(ragged_id->vtype(), regular_id->vtype()); +} + } // namespace nvfuser From a0b40a39559affa87e93f461922e5a2aaecd8974 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:57:44 -0800 Subject: [PATCH 06/31] cleanup --- tests/cpp/test_ragged_iter_domain.cpp | 47 +++++++++------------------ 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 856bbfea2aa..7274fcbb36b 100644 --- 
a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -23,7 +23,7 @@ TEST_F(RaggedIterDomainTest, Construction) { FusionGuard fg(&fusion); // Create a TensorView to use as extents - // This represents component sizes [3, 5, 2] + // This represents component sizes such as [3, 5, 2] auto extents = makeSymbolicTensor(1, DataType::Index); fusion.addInput(extents); @@ -41,6 +41,20 @@ TEST_F(RaggedIterDomainTest, Construction) { // Verify extent is not null (it's the sum of extents) EXPECT_NE(ragged_id->extent(), nullptr); + + // Verify ValType is RaggedIterDomain, not IterDomain + EXPECT_EQ(ragged_id->vtype(), ValType::RaggedIterDomain); + EXPECT_NE(ragged_id->vtype(), ValType::IterDomain); + EXPECT_TRUE(ragged_id->getValType().has_value()); + EXPECT_EQ(ragged_id->getValType().value(), ValType::RaggedIterDomain); + + // Compare with a regular IterDomain to ensure different types + auto regular_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); + EXPECT_EQ(regular_id->vtype(), ValType::IterDomain); + EXPECT_NE(ragged_id->vtype(), regular_id->vtype()); } // RaggedIterDomain with parallelization @@ -163,35 +177,4 @@ TEST_F(RaggedIterDomainTest, ValidationNonIntegerExtents) { nvfuser::nvfError); } -// ValType test - ensure RaggedIterDomain has correct ValType -TEST_F(RaggedIterDomainTest, ValType) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto extents = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(extents); - - auto ragged_id = IrBuilder::create( - extents, IterType::Iteration, ParallelType::Serial); - - // Verify ValType is RaggedIterDomain, not IterDomain - EXPECT_EQ(ragged_id->vtype(), ValType::RaggedIterDomain); - EXPECT_NE(ragged_id->vtype(), ValType::IterDomain); - - // Verify getValType also returns the correct type - EXPECT_TRUE(ragged_id->getValType().has_value()); - EXPECT_EQ(ragged_id->getValType().value(), ValType::RaggedIterDomain); - - // Compare with a regular IterDomain - auto regular_id = - IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) - .build(); - EXPECT_EQ(regular_id->vtype(), ValType::IterDomain); - EXPECT_NE(regular_id->vtype(), ValType::RaggedIterDomain); - - // Verify they have different types - EXPECT_NE(ragged_id->vtype(), regular_id->vtype()); -} - } // namespace nvfuser From dbdd917ee07ec787949399f694cceb8742717511 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 13:05:35 -0800 Subject: [PATCH 07/31] Fix IterVisitor --- csrc/iter_visitor.cpp | 9 +++++++++ tests/cpp/test_ragged_iter_domain.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 2386a4729b2..0c8c208417e 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -69,6 +69,15 @@ class MemberStatements : public OptOutDispatch { next_stmts_.push_back(stmt->stopOffset()); } + void handle(RaggedIterDomain* stmt) final { + // Visit the standard IterDomain fields + next_stmts_.push_back(stmt->start()); + next_stmts_.push_back(stmt->extent()); + next_stmts_.push_back(stmt->stopOffset()); + // Visit the extents TensorView (ragged-specific field) + next_stmts_.push_back(stmt->extents()); + } + void handle(TensorDomain* stmt) final { for (const std::vector* dom : stmt->allDomains()) { next_stmts_.insert(next_stmts_.end(), dom->begin(), dom->end()); diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 7274fcbb36b..032a1d154c0 100644 --- 
a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -177,4 +177,28 @@ TEST_F(RaggedIterDomainTest, ValidationNonIntegerExtents) { nvfuser::nvfError); } +// IterVisitor test - ensure graph traversal visits extents field +TEST_F(RaggedIterDomainTest, IterVisitor) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create extents TensorView + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create RaggedIterDomain + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::Serial); + + // Collect all statements reachable from the RaggedIterDomain + // Use traverse_members=true to visit member fields + std::vector from_vals = {ragged_id}; + auto all_stmts = StmtSort::getStmtsTo(from_vals, /*traverse_members=*/true); + + // Verify the extents TensorView is visited (this is the critical check) + EXPECT_TRUE( + std::find(all_stmts.begin(), all_stmts.end(), extents) != all_stmts.end()) + << "IterVisitor should traverse the extents_ field of RaggedIterDomain"; +} + } // namespace nvfuser From cdbd81e46bb57610da0326345cf6ae09f68e90ac Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 13:43:28 -0800 Subject: [PATCH 08/31] cleanup --- tests/cpp/test_ragged_iter_domain.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 032a1d154c0..4002f854947 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -191,11 +191,10 @@ TEST_F(RaggedIterDomainTest, IterVisitor) { extents, IterType::Iteration, ParallelType::Serial); // Collect all statements reachable from the RaggedIterDomain - // Use traverse_members=true to visit member fields std::vector from_vals = {ragged_id}; auto all_stmts = StmtSort::getStmtsTo(from_vals, /*traverse_members=*/true); - // Verify the extents TensorView is visited (this is the critical check) + // Verify the extents TensorView is visited EXPECT_TRUE( std::find(all_stmts.begin(), all_stmts.end(), extents) != all_stmts.end()) << "IterVisitor should traverse the extents_ field of RaggedIterDomain"; From d4c8d7f72bb07ce9815d9c43aabbea6d8add1e92 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:15:03 -0800 Subject: [PATCH 09/31] WIP: partition --- csrc/ir/internal_base_nodes.cpp | 90 +++++++++++++++++++++++++++ csrc/ir/internal_base_nodes.h | 14 +++++ tests/cpp/test_ragged_iter_domain.cpp | 88 ++++++++++++++++++++++++-- 3 files changed, 188 insertions(+), 4 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index abb5db26d2c..006adad536f 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -894,6 +895,95 @@ std::string RaggedIterDomain::toInlineString(int indent_size) const { return toString(indent_size); } +std::pair RaggedIterDomain::partition( + IterDomain* in, + TensorView* offsets) { + NVF_ERROR(in != nullptr, "partition: input IterDomain is null"); + + NVF_ERROR( + !in->isA(), + "partition: input is already RaggedIterDomain, cannot partition again"); + + NVF_ERROR_EQ(in->getParallelType(), ParallelType::Serial, + "Partitioning of parallelized IterDomain not supported: ", + in->toString()); + + NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null"); + + NVF_ERROR( + offsets->dtype() == DataType::Index, + "partition: offsets must have 
Index type, got ", + offsets->dtype()); + + const auto& offsets_domain = offsets->getLogicalDomain(); + NVF_ERROR( + !offsets_domain.empty(), + "partition: offsets tensor must have at least one dimension"); + + auto container = in->container(); + + // Compute extents from offsets: extents[i] = offsets[i+1] - offsets[i] + // Slice along the last dimension of the offsets tensor + // offsets_left = offsets[..., :-1] (all but last element in last dim) + // offsets_right = offsets[..., 1:] (all but first element in last dim) + + const auto last_dim = offsets_domain.size() - 1; + auto offsets_len = offsets_domain[last_dim]->extent(); + + auto zero = container->zeroVal(DataType::Index); + auto one = container->oneVal(DataType::Index); + auto len_minus_one = sub(offsets_len, one); + + // Build slice ranges for all dimensions + // For all dimensions except the last, use full range (:) + // For the last dimension, use [:-1] for left and [1:] for right + std::vector left_ranges; + std::vector right_ranges; + + for (const auto i : arange(offsets_domain.size())) { + if (i < last_dim) { + // Full range for non-last dimensions + Slice s; + s.start = zero; + s.stop = offsets_domain[i]->extent(); + left_ranges.push_back(s); + right_ranges.push_back(s); + } else { + // Last dimension: left uses [:-1], right uses [1:] + Slice left_s; + left_s.start = zero; + left_s.stop = len_minus_one; + left_ranges.push_back(left_s); + + Slice right_s; + right_s.start = one; + right_s.stop = offsets_len; + right_ranges.push_back(right_s); + } + } + + auto offsets_left = slice(offsets, left_ranges); + auto offsets_right = slice(offsets, right_ranges); + + // Compute extents: extents = offsets_right - offsets_left + auto extents = sub(offsets_right, offsets_left); + + // Create batch IterDomain + // Batch extent = number of components = len(offsets) - 1 + auto batch_extent = len_minus_one; + auto batch_id = IterDomainBuilder(zero, batch_extent) + .parallel_type(ParallelType::Serial) + .iter_type(IterType::Iteration) + .build(); + + // Create RaggedIterDomain with computed extents + auto ragged_id = IrBuilder::create( + extents, in->getIterType()); + + // Return pair + return {batch_id, ragged_id}; +} + TensorDomain::TensorDomain( IrBuilderPasskey passkey, std::vector logical_domain, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index fb4fd8651ce..59f099ec517 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -471,6 +471,20 @@ class NVF_API RaggedIterDomain : public IterDomain { std::string toInlineString(int indent_size = 0) const override; + //! Partition an IterDomain into batch and ragged dimensions + //! Creates a batch IterDomain and a RaggedIterDomain based on offsets + //! + //! \param in Input IterDomain to partition (must be regular IterDomain) + //! \param offsets Offset tensor defining partition boundaries + //! Shape: [num_components + 1], values: [0, off1, off2, ..., total] + //! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] + //! \return Pair of (batch_id, ragged_id) + //! batch_id: IterDomain with extent = num_components + //! ragged_id: RaggedIterDomain with extents computed from offsets + static std::pair partition( + IterDomain* in, + TensorView* offsets); + //! 
Accessor for the extents tensor TensorView* extents() const { return extents_; diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 4002f854947..b537f6c95a8 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -126,10 +126,6 @@ TEST_F(RaggedIterDomainTest, Printing) { // Verify output contains expected elements EXPECT_NE(str.find("Ragged"), std::string::npos); EXPECT_NE(str.find("extents"), std::string::npos); - - // Also test toInlineString - std::string inline_str = ragged_id->toInlineString(); - EXPECT_FALSE(inline_str.empty()); } // Multi-dimensional extents tensor @@ -200,4 +196,88 @@ TEST_F(RaggedIterDomainTest, IterVisitor) { << "IterVisitor should traverse the extents_ field of RaggedIterDomain"; } +// Partition operation - basic test +TEST_F(RaggedIterDomainTest, PartitionBasic) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create input IterDomain + auto input_id = IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); + + // Create a symbolic offset tensor + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Partition the IterDomain + auto [batch_id, ragged_id] = RaggedIterDomain::partition(input_id, offsets); + + // Verify batch IterDomain + EXPECT_NE(batch_id, nullptr); + EXPECT_TRUE(batch_id->isA()); + EXPECT_FALSE(batch_id->isA()); + EXPECT_EQ(batch_id->getIterType(), IterType::Iteration); + + // Verify RaggedIterDomain + EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(ragged_id->isA()); + EXPECT_EQ(ragged_id->getIterType(), IterType::Iteration); + EXPECT_NE(ragged_id->extents(), nullptr); +} + +// Partition operation - multi-dimensional offsets +TEST_F(RaggedIterDomainTest, PartitionMultiDimensional) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto input_id = IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(100L, DataType::Index)) + .build(); + + // Create 2D offsets tensor for nested ragged structure + auto offsets_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(offsets_2d); + + // Partition should work with multi-dimensional offsets + auto [batch_id, ragged_id] = RaggedIterDomain::partition(input_id, offsets_2d); + + EXPECT_NE(batch_id, nullptr); + EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(ragged_id->isA()); + EXPECT_NE(ragged_id->extents(), nullptr); +} + +// Partition operation - validation tests +TEST_F(RaggedIterDomainTest, PartitionValidation) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto input_id = IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Test 1: Null input should fail + EXPECT_THROW(RaggedIterDomain::partition(nullptr, offsets), nvfuser::nvfError); + + // Test 2: Null offsets should fail + EXPECT_THROW(RaggedIterDomain::partition(input_id, nullptr), nvfuser::nvfError); + + // Test 3: Non-Index offsets should fail + auto float_offsets = makeSymbolicTensor(1, DataType::Float); + fusion.addInput(float_offsets); + EXPECT_THROW( + RaggedIterDomain::partition(input_id, float_offsets), nvfuser::nvfError); + + // Test 4: Cannot partition RaggedIterDomain + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::Serial); + EXPECT_THROW(RaggedIterDomain::partition(ragged_id, offsets), nvfuser::nvfError); +} + } // 
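namespace nvfuser

The partition API added in this patch is exercised by the PartitionBasic test
above. As a minimal sketch of the intended flow (the fusion setup and the
concrete offsets values such as [0, 3, 8, 10] are illustrative, not part of the
patch):

  // Partition a flat, serial IterDomain using a 1D offsets tensor of Index
  // type; per-component extents are computed as offsets[i + 1] - offsets[i].
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* offsets = makeSymbolicTensor(1, DataType::Index);
  fusion.addInput(offsets);

  IterDomain* flat =
      IterDomainBuilder(
          fusion.zeroVal(), IrBuilder::create<Val>(10L, DataType::Index))
          .build();

  // component_id iterates over the components; ragged_id carries the
  // per-component extents as a RaggedIterDomain.
  auto [component_id, ragged_id] = RaggedIterDomain::partition(flat, offsets);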
From 9575a13b09e275588df5a956c25e987c43ab Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Fri, 12 Dec 2025 18:26:22 -0800
Subject: [PATCH 10/31] Partition expr

---
 csrc/dispatch.h                       |  1 +
 csrc/ir/internal_base_nodes.cpp       | 31 ++++++----
 csrc/ir/internal_base_nodes.h         |  8 +--
 csrc/ir/internal_nodes.cpp            | 33 ++++++++++
 csrc/ir/internal_nodes.h              | 44 ++++++++++++++
 tests/cpp/test_ragged_iter_domain.cpp | 87 ++++++++++++++++---------
 6 files changed, 159 insertions(+), 45 deletions(-)

diff --git a/csrc/dispatch.h b/csrc/dispatch.h
index d14d5257fe0..822ababb149 100644
--- a/csrc/dispatch.h
+++ b/csrc/dispatch.h
@@ -115,6 +115,7 @@ class Val;
   f(TopKOp); \
   f(ScanOp); \
   f(Merge); \
+  f(Partition); \
   f(Swizzle); \
   f(Swizzle2D); \
   f(Resize); \
diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp
index 006adad536f..c95014bdadb 100644
--- a/csrc/ir/internal_base_nodes.cpp
+++ b/csrc/ir/internal_base_nodes.cpp
@@ -904,9 +904,11 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
       !in->isA<RaggedIterDomain>(),
       "partition: input is already RaggedIterDomain, cannot partition again");
 
-  NVF_ERROR_EQ(in->getParallelType(), ParallelType::Serial,
-      "Partitioning of parallelized IterDomain not supported: ",
-      in->toString());
+  NVF_ERROR_EQ(
+      in->getParallelType(),
+      ParallelType::Serial,
+      "Partitioning of parallelized IterDomain not supported: ",
+      in->toString());
 
   NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null");
 
@@ -968,20 +970,23 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
   // Compute extents: extents = offsets_right - offsets_left
   auto extents = sub(offsets_right, offsets_left);
 
-  // Create batch IterDomain
-  // Batch extent = number of components = len(offsets) - 1
-  auto batch_extent = len_minus_one;
-  auto batch_id = IterDomainBuilder(zero, batch_extent)
-                      .parallel_type(ParallelType::Serial)
-                      .iter_type(IterType::Iteration)
-                      .build();
+  // Create component IterDomain
+  // Component extent = number of components = len(offsets) - 1
+  auto component_extent = len_minus_one;
+  auto component_id = IterDomainBuilder(zero, component_extent)
+                          .parallel_type(ParallelType::Serial)
+                          .iter_type(IterType::Iteration)
+                          .build();
 
   // Create RaggedIterDomain with computed extents
-  auto ragged_id = IrBuilder::create<RaggedIterDomain>(
-      extents, in->getIterType());
+  auto ragged_id =
+      IrBuilder::create<RaggedIterDomain>(extents, in->getIterType());
+
+  // Create the Partition expr to represent this transformation
+  IrBuilder::create<Partition>(component_id, ragged_id, in, extents);
 
   // Return pair
-  return {batch_id, ragged_id};
+  return {component_id, ragged_id};
 }
 
 TensorDomain::TensorDomain(
diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h
index 59f099ec517..2a2e85d0458 100644
--- a/csrc/ir/internal_base_nodes.h
+++ b/csrc/ir/internal_base_nodes.h
@@ -471,15 +471,15 @@ class NVF_API RaggedIterDomain : public IterDomain {
 
   std::string toInlineString(int indent_size = 0) const override;
 
-  //! Partition an IterDomain into batch and ragged dimensions
-  //! Creates a batch IterDomain and a RaggedIterDomain based on offsets
+  //! Partition an IterDomain into component and ragged dimensions
+  //! Creates a component IterDomain and a RaggedIterDomain based on offsets
   //!
   //! \param in Input IterDomain to partition (must be regular IterDomain)
   //! \param offsets Offset tensor defining partition boundaries
   //!        Shape: [num_components + 1], values: [0, off1, off2, ..., total]
   //!        Extents are computed as: extents[i] = offsets[i+1] - offsets[i]
-  //! \return Pair of (batch_id, ragged_id)
-  //!         batch_id: IterDomain with extent = num_components
+  //! \return Pair of (component_id, ragged_id)
+  //!         component_id: IterDomain with extent = num_components
   //!         ragged_id: RaggedIterDomain with extents computed from offsets
   static std::pair<IterDomain*, RaggedIterDomain*> partition(
       IterDomain* in,
       TensorView* offsets);
 
diff --git a/csrc/ir/internal_nodes.cpp b/csrc/ir/internal_nodes.cpp
index 2f0f39afaa5..bdbe706a3cf 100644
--- a/csrc/ir/internal_nodes.cpp
+++ b/csrc/ir/internal_nodes.cpp
@@ -2609,6 +2609,39 @@ std::string Merge::toInlineString(int indent_size) const {
 
 NVFUSER_DEFINE_CLONE_AND_CREATE(Merge)
 
+Partition::Partition(
+    IrBuilderPasskey passkey,
+    IterDomain* component,
+    RaggedIterDomain* ragged,
+    IterDomain* in,
+    TensorView* extents)
+    : Expr(passkey) {
+  addOutput(component);
+  addOutput(ragged);
+  addInput(in);
+  // Should the extents tensor be an input rather than an attribute?
+  addAttribute(extents);
+}
+
+std::string Partition::toString(int indent_size) const {
+  std::stringstream ss;
+  ss << "Partition: ";
+  ss << in()->toString();
+  ss << " by extents " << extents()->toString();
+  ss << " -> component: ";
+  ss << component()->toString();
+  ss << ", ragged: ";
+  ss << ragged()->toString();
+  ss << "\n";
+  return ss.str();
+}
+
+std::string Partition::toInlineString(int indent_size) const {
+  NVF_CHECK(false, "Partition can not be printed inline");
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(Partition)
+
 Swizzle::Swizzle(
     IrBuilderPasskey passkey,
     IterDomain* out_x,
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index fea0f565082..9393dc3016b 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -1801,6 +1801,50 @@ class NVF_API Merge : public Expr {
   }
 };
 
+//! Partition an IterDomain into component and ragged dimensions.
+//! Creates a component IterDomain and a RaggedIterDomain based on an extents
+//! tensor. The extents tensor contains the extent for each component.
+class NVF_API Partition : public Expr {
+ public:
+  using Expr::Expr;
+
+  Partition(
+      IrBuilderPasskey,
+      IterDomain* component,
+      RaggedIterDomain* ragged,
+      IterDomain* in,
+      TensorView* extents);
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  const char* getOpString() const override {
+    return "Partition";
+  }
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+
+  //! Component dimension output (extent = num_components)
+  IterDomain* component() const {
+    return output(0)->as<IterDomain>();
+  }
+
+  //! Ragged dimension output (variable extents per component)
+  RaggedIterDomain* ragged() const {
+    return output(1)->as<RaggedIterDomain>();
+  }
+
+  //! Input IterDomain being partitioned
+  IterDomain* in() const {
+    return input(0)->as<IterDomain>();
+  }
+
+  //! Extents tensor containing extent for each component
+  TensorView* extents() const {
+    return attributeVal(0)->as<TensorView>();
+  }
+};
+
 class Swizzle : public Expr {
  public:
   using Expr::Expr;
diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp
index b537f6c95a8..a9d4c79c67e 100644
--- a/tests/cpp/test_ragged_iter_domain.cpp
+++ b/tests/cpp/test_ragged_iter_domain.cpp
@@ -32,7 +32,7 @@ TEST_F(RaggedIterDomainTest, Construction) {
       extents, IterType::Iteration, ParallelType::Serial);
 
   // Verify properties
-  EXPECT_NE(ragged_id, nullptr);
+  EXPECT_TRUE(ragged_id != nullptr);
   EXPECT_TRUE(ragged_id->isA<RaggedIterDomain>());
   EXPECT_EQ(ragged_id->getIterType(), IterType::Iteration);
   EXPECT_EQ(ragged_id->getParallelType(), ParallelType::Serial);
@@ -40,7 +40,7 @@ TEST_F(RaggedIterDomainTest, Construction) {
   EXPECT_FALSE(ragged_id->isRFactorProduct());
 
   // Verify extent is not null (it's the sum of extents)
-  EXPECT_NE(ragged_id->extent(), nullptr);
+  EXPECT_TRUE(ragged_id->extent() != nullptr);
 
   // Verify ValType is RaggedIterDomain, not IterDomain
   EXPECT_EQ(ragged_id->vtype(), ValType::RaggedIterDomain);
@@ -140,7 +140,7 @@ TEST_F(RaggedIterDomainTest, MultiDimensionalExtents) {
   auto ragged_nested = IrBuilder::create<RaggedIterDomain>(
       extents_2d, IterType::Iteration, ParallelType::Serial);
 
-  EXPECT_NE(ragged_nested, nullptr);
+  EXPECT_TRUE(ragged_nested != nullptr);
   EXPECT_EQ(ragged_nested->extents(), extents_2d);
 }
 
@@ -202,28 +202,53 @@ TEST_F(RaggedIterDomainTest, PartitionBasic) {
   FusionGuard fg(&fusion);
 
   // Create input IterDomain
-  auto input_id = IterDomainBuilder(
-      fusion.zeroVal(), IrBuilder::create<Val>(10L, DataType::Index))
-      .build();
+  auto input_id =
+      IterDomainBuilder(
+          fusion.zeroVal(), IrBuilder::create<Val>(-1, DataType::Index))
+          .build();
 
   // Create a symbolic offset tensor
   auto offsets = makeSymbolicTensor(1, DataType::Index);
   fusion.addInput(offsets);
 
   // Partition the IterDomain
-  auto [batch_id, ragged_id] = RaggedIterDomain::partition(input_id, offsets);
+  auto [component_id, ragged_id] =
+      RaggedIterDomain::partition(input_id, offsets);
 
-  // Verify batch IterDomain
-  EXPECT_NE(batch_id, nullptr);
-  EXPECT_TRUE(batch_id->isA<IterDomain>());
-  EXPECT_FALSE(batch_id->isA<RaggedIterDomain>());
-  EXPECT_EQ(batch_id->getIterType(), IterType::Iteration);
+  // Verify component IterDomain
+  EXPECT_TRUE(component_id != nullptr);
+  EXPECT_TRUE(component_id->isA<IterDomain>());
+  EXPECT_FALSE(component_id->isA<RaggedIterDomain>());
 
   // Verify RaggedIterDomain
-  EXPECT_NE(ragged_id, nullptr);
+  EXPECT_TRUE(ragged_id != nullptr);
   EXPECT_TRUE(ragged_id->isA<RaggedIterDomain>());
-  EXPECT_EQ(ragged_id->getIterType(), IterType::Iteration);
-  EXPECT_NE(ragged_id->extents(), nullptr);
+  EXPECT_TRUE(ragged_id->extents() != nullptr);
+
+  // Verify that a Partition expr was created
+  EXPECT_TRUE(component_id->definition() != nullptr);
+  EXPECT_TRUE(component_id->definition()->isA<Partition>());
+
+  // Both outputs should have the same definition (the Partition expr)
+  EXPECT_EQ(component_id->definition(), ragged_id->definition());
+
+  // Verify the Partition expr structure
+  auto partition_expr = component_id->definition()->as<Partition>();
+  EXPECT_EQ(partition_expr->component(), component_id);
+  EXPECT_EQ(partition_expr->ragged(), ragged_id);
+  EXPECT_EQ(partition_expr->in(), input_id);
+  EXPECT_EQ(partition_expr->extents(), ragged_id->extents());
+
+  // Verify the expr has correct inputs and outputs
+  EXPECT_EQ(partition_expr->inputs().size(), 1);
+  EXPECT_EQ(partition_expr->outputs().size(), 2);
+  EXPECT_EQ(partition_expr->input(0), input_id);
+  EXPECT_EQ(partition_expr->output(0), component_id);
EXPECT_EQ(partition_expr->output(1), ragged_id); + + // Test toString + std::string str = partition_expr->toString(); + EXPECT_TRUE(str.find("Partition") != std::string::npos); } // Partition operation - multi-dimensional offsets @@ -231,21 +256,23 @@ TEST_F(RaggedIterDomainTest, PartitionMultiDimensional) { Fusion fusion; FusionGuard fg(&fusion); - auto input_id = IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(100L, DataType::Index)) - .build(); + auto input_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(100L, DataType::Index)) + .build(); // Create 2D offsets tensor for nested ragged structure auto offsets_2d = makeSymbolicTensor(2, DataType::Index); fusion.addInput(offsets_2d); // Partition should work with multi-dimensional offsets - auto [batch_id, ragged_id] = RaggedIterDomain::partition(input_id, offsets_2d); + auto [component_id, ragged_id] = + RaggedIterDomain::partition(input_id, offsets_2d); - EXPECT_NE(batch_id, nullptr); - EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(component_id != nullptr); + EXPECT_TRUE(ragged_id != nullptr); EXPECT_TRUE(ragged_id->isA()); - EXPECT_NE(ragged_id->extents(), nullptr); + EXPECT_TRUE(ragged_id->extents() != nullptr); } // Partition operation - validation tests @@ -253,18 +280,21 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { Fusion fusion; FusionGuard fg(&fusion); - auto input_id = IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) - .build(); + auto input_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); auto offsets = makeSymbolicTensor(1, DataType::Index); fusion.addInput(offsets); // Test 1: Null input should fail - EXPECT_THROW(RaggedIterDomain::partition(nullptr, offsets), nvfuser::nvfError); + EXPECT_THROW( + RaggedIterDomain::partition(nullptr, offsets), nvfuser::nvfError); // Test 2: Null offsets should fail - EXPECT_THROW(RaggedIterDomain::partition(input_id, nullptr), nvfuser::nvfError); + EXPECT_THROW( + RaggedIterDomain::partition(input_id, nullptr), nvfuser::nvfError); // Test 3: Non-Index offsets should fail auto float_offsets = makeSymbolicTensor(1, DataType::Float); @@ -277,7 +307,8 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { fusion.addInput(extents); auto ragged_id = IrBuilder::create( extents, IterType::Iteration, ParallelType::Serial); - EXPECT_THROW(RaggedIterDomain::partition(ragged_id, offsets), nvfuser::nvfError); + EXPECT_THROW( + RaggedIterDomain::partition(ragged_id, offsets), nvfuser::nvfError); } } // namespace nvfuser From a054ae0c89000f4aa9f7dc2cc0497ee3b12e71dd Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 20:37:57 -0800 Subject: [PATCH 11/31] TensorView::partition --- csrc/ir/interface_nodes.h | 9 ++++++ csrc/ir/internal_base_nodes.cpp | 16 ++++++++++ csrc/ir/internal_base_nodes.h | 3 ++ csrc/tensor_view.cpp | 45 +++++++++++++++++++++++++++ tests/cpp/test_ragged_iter_domain.cpp | 34 ++++++++++++++++++++ 5 files changed, 107 insertions(+) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index e9172080640..ea0236a1a7b 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -619,6 +619,15 @@ class NVF_API TensorView : public Val { return merge(axis, axis + 1); } + // Partition "axis" into component and ragged dimensions based on offsets + // The offsets tensor defines partition boundaries where: + // Shape: [num_components + 1], values: [0, off1, off2, ..., total] + // Extents are computed as: extents[i] = offsets[i+1] - offsets[i] + 
// Returns this TensorView with the axis replaced by component and ragged dims + // e.g. partition(0, offsets) on tv[id{N}] results in: + // tv[id{num_components}, ragged_id{extents}] + TensorView* partition(int64_t axis, TensorView* offsets); + // Flatten the axis from `from` to `to` into a single axis. // Both `from` and `to` are inclusive. TensorView* flatten(int64_t from = 0, int64_t to = -1); diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index c95014bdadb..095eb208484 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1593,6 +1593,22 @@ void TensorDomain::merge(int64_t axis_o, int64_t axis_i) { loop_domain_.insert(loop_domain_.begin() + td_outer_pos, merged_id); } +// Partition "axis" into component and ragged dimensions. Follow the +// pattern of TensorDomain::split. +void TensorDomain::partition(int64_t axis, TensorView* offsets) { + NVF_ERROR(nDims() > 0, "Tried to do partition on a 0-dim domain"); + axis = wrapDim(axis); + + IterDomain* id = this->axis(axis); + + auto [component_id, ragged_id] = RaggedIterDomain::partition(id, offsets); + + // Remove the original axis and insert component and ragged dimensions + loop_domain_.erase(loop_domain_.begin() + axis); + loop_domain_.insert(loop_domain_.begin() + axis, ragged_id); + loop_domain_.insert(loop_domain_.begin() + axis, component_id); +} + // Reorder axes according to map[old_pos] = new_pos void TensorDomain::reorder( const std::unordered_map& old2new_) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 2a2e85d0458..5888f71b989 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -790,6 +790,9 @@ class NVF_API TensorDomain : public Val { // axis is by default placed at original position axis_o void merge(int64_t axis_o, int64_t axis_i); + // Partition axis into component and ragged dimensions based on offsets + void partition(int64_t axis, TensorView* offsets); + // Reorder axes according to map[old_pos] = new_pos void reorder(const std::unordered_map& old2new); diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 7e0d39c02bd..eabe2004f0c 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -561,6 +561,51 @@ TensorView* TensorView::merge(int64_t axis_o, int64_t axis_i) { return this; } +// Partition "axis" into component and ragged dimensions based on +// offsets. Follow the pattern of TensorView::split. +TensorView* TensorView::partition(int64_t axis, TensorView* offsets) { + NVF_ERROR( + nDims() > 0, + "Tried to do partition on a 0-dim TensorView. ", + "Tensor: ", + toString()); + + axis = wrapDim(axis); + + NVF_CHECK( + axis >= getMaxComputePosition(), + "Cannot partition axis within compute at position. Axis = ", + axis, + " computePosition = ", + getMaxComputePosition(), + ". Tensor: ", + toString()); + + NVF_CHECK( + axis >= getMaybeMaxProducerPosition(), + "Cannot partition axis within max producer position. Axis = ", + axis, + " maxProducerPosition = ", + getMaybeMaxProducerPosition(), + ". Tensor: ", + toString()); + + NVF_CHECK( + this->axis(axis)->getParallelType() == ParallelType::Serial, + "Partitioning an axis (", + this->axis(axis)->toString(), + ") of non-Serial parallel type is not supported at this time." 
+ " Parallelization strategy must be set after calling partition: ", + toString()); + + if (offsets->dtype() != DataType::Index) { + offsets = castOp(DataType::Index, offsets); + } + + domain()->partition(axis, offsets); + return this; +} + TensorView* TensorView::resize( int64_t axis, Val* left_expansion, diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index a9d4c79c67e..3571d3c11a2 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -311,4 +311,38 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { RaggedIterDomain::partition(ragged_id, offsets), nvfuser::nvfError); } +// TensorView::partition operation +TEST_F(RaggedIterDomainTest, TensorViewPartition) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 2D TensorView + auto tv0 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(tv0); + + // Create offsets tensor + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Partition the first axis + tv0->partition(0, offsets); + + // Verify the tensor now has 3 dimensions: [component, ragged, original_dim1] + EXPECT_EQ(tv0->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(tv0->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(tv0->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(tv0->axis(2)->isA()); + + // Verify both partition outputs have the same definition + EXPECT_TRUE(tv0->axis(0)->definition() != nullptr); + EXPECT_TRUE(tv0->axis(0)->definition()->isA()); + EXPECT_EQ(tv0->axis(0)->definition(), tv0->axis(1)->definition()); +} + } // namespace nvfuser From 69dbe0fd19c374a6cb3db0a0956a4c308cd1f9aa Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 21:05:58 -0800 Subject: [PATCH 12/31] cleanup --- csrc/ir/internal_base_nodes.cpp | 63 ++++++++++----------------- tests/cpp/test_ragged_iter_domain.cpp | 32 +++----------- 2 files changed, 30 insertions(+), 65 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 095eb208484..51f855dcf80 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -912,60 +912,43 @@ std::pair RaggedIterDomain::partition( NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null"); - NVF_ERROR( - offsets->dtype() == DataType::Index, + NVF_ERROR_EQ( + offsets->dtype(), + DataType::Index, "partition: offsets must have Index type, got ", offsets->dtype()); const auto& offsets_domain = offsets->getLogicalDomain(); - NVF_ERROR( - !offsets_domain.empty(), - "partition: offsets tensor must have at least one dimension"); + NVF_ERROR_EQ( + offsets_domain.size(), + 1, + "partition: offsets tensor must be 1D, got ", + offsets_domain.size(), + "D tensor. 
Multi-dimensional offsets not yet supported."); auto container = in->container(); // Compute extents from offsets: extents[i] = offsets[i+1] - offsets[i] - // Slice along the last dimension of the offsets tensor - // offsets_left = offsets[..., :-1] (all but last element in last dim) - // offsets_right = offsets[..., 1:] (all but first element in last dim) + // offsets_left = offsets[:-1] (all but last element) + // offsets_right = offsets[1:] (all but first element) - const auto last_dim = offsets_domain.size() - 1; - auto offsets_len = offsets_domain[last_dim]->extent(); + auto offsets_len = offsets_domain[0]->extent(); auto zero = container->zeroVal(DataType::Index); auto one = container->oneVal(DataType::Index); auto len_minus_one = sub(offsets_len, one); - // Build slice ranges for all dimensions - // For all dimensions except the last, use full range (:) - // For the last dimension, use [:-1] for left and [1:] for right - std::vector left_ranges; - std::vector right_ranges; - - for (const auto i : arange(offsets_domain.size())) { - if (i < last_dim) { - // Full range for non-last dimensions - Slice s; - s.start = zero; - s.stop = offsets_domain[i]->extent(); - left_ranges.push_back(s); - right_ranges.push_back(s); - } else { - // Last dimension: left uses [:-1], right uses [1:] - Slice left_s; - left_s.start = zero; - left_s.stop = len_minus_one; - left_ranges.push_back(left_s); - - Slice right_s; - right_s.start = one; - right_s.stop = offsets_len; - right_ranges.push_back(right_s); - } - } - - auto offsets_left = slice(offsets, left_ranges); - auto offsets_right = slice(offsets, right_ranges); + // Slice offsets[:-1] + Slice left_slice; + left_slice.start = zero; + left_slice.stop = len_minus_one; + auto offsets_left = slice(offsets, {left_slice}); + + // Slice offsets[1:] + Slice right_slice; + right_slice.start = one; + right_slice.stop = offsets_len; + auto offsets_right = slice(offsets, {right_slice}); // Compute extents: extents = offsets_right - offsets_left auto extents = sub(offsets_right, offsets_left); diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 3571d3c11a2..c04e1f00e31 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -251,30 +251,6 @@ TEST_F(RaggedIterDomainTest, PartitionBasic) { EXPECT_TRUE(str.find("Partition") != std::string::npos); } -// Partition operation - multi-dimensional offsets -TEST_F(RaggedIterDomainTest, PartitionMultiDimensional) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto input_id = - IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(100L, DataType::Index)) - .build(); - - // Create 2D offsets tensor for nested ragged structure - auto offsets_2d = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(offsets_2d); - - // Partition should work with multi-dimensional offsets - auto [component_id, ragged_id] = - RaggedIterDomain::partition(input_id, offsets_2d); - - EXPECT_TRUE(component_id != nullptr); - EXPECT_TRUE(ragged_id != nullptr); - EXPECT_TRUE(ragged_id->isA()); - EXPECT_TRUE(ragged_id->extents() != nullptr); -} - // Partition operation - validation tests TEST_F(RaggedIterDomainTest, PartitionValidation) { Fusion fusion; @@ -302,7 +278,13 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { EXPECT_THROW( RaggedIterDomain::partition(input_id, float_offsets), nvfuser::nvfError); - // Test 4: Cannot partition RaggedIterDomain + // Test 4: Multi-dimensional offsets should fail + auto offsets_2d = makeSymbolicTensor(2, 
DataType::Index); + fusion.addInput(offsets_2d); + EXPECT_THROW( + RaggedIterDomain::partition(input_id, offsets_2d), nvfuser::nvfError); + + // Test 5: Cannot partition RaggedIterDomain auto extents = makeSymbolicTensor(1, DataType::Index); fusion.addInput(extents); auto ragged_id = IrBuilder::create( From 2348dde73b40b6d306bf9324334176a217616e5c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 21:44:44 -0800 Subject: [PATCH 13/31] cleanup --- csrc/ir/internal_base_nodes.cpp | 8 ++++++++ csrc/ir/internal_base_nodes.h | 4 +++- tests/cpp/test_ragged_iter_domain.cpp | 11 ++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 4187ef96e2a..cc068fbfca2 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -910,6 +910,14 @@ std::pair RaggedIterDomain::partition( "Partitioning of parallelized IterDomain not supported: ", in->toString()); + NVF_ERROR_EQ( + in->getIterType(), + IterType::Iteration, + "partition: only IterType::Iteration is supported, got ", + in->getIterType(), + " for IterDomain: ", + in->toString()); + NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null"); NVF_ERROR_EQ( diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 421c056be82..0187c408bd7 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -480,12 +480,14 @@ class NVF_API RaggedIterDomain : public IterDomain { //! Creates a component IterDomain and a RaggedIterDomain based on offsets //! //! \param in Input IterDomain to partition (must be regular IterDomain) - //! \param offsets Offset tensor defining partition boundaries + //! \param offsets Offset tensor defining partition boundaries (must be 1D) //! Shape: [num_components + 1], values: [0, off1, off2, ..., total] //! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] //! \return Pair of (component_id, ragged_id) //! component_id: IterDomain with extent = num_components //! ragged_id: RaggedIterDomain with extents computed from offsets + //! + //! 
TODO: Support multi-dimensional offsets for nested ragged structures static std::pair partition( IterDomain* in, TensorView* offsets); diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index b343636a56f..8d16615bd64 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -288,7 +288,16 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { EXPECT_THROW( RaggedIterDomain::partition(input_id, offsets_2d), nvfuser::nvfError); - // Test 5: Cannot partition RaggedIterDomain + // Test 5: Non-Iteration IterType should fail + auto reduction_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .iter_type(IterType::Reduction) + .build(); + EXPECT_THROW( + RaggedIterDomain::partition(reduction_id, offsets), nvfuser::nvfError); + + // Test 6: Cannot partition RaggedIterDomain auto extents = makeSymbolicTensor(1, DataType::Index); fusion.addInput(extents); auto ragged_id = IrBuilder::create( From 7090b9c2bedacfd19f0cbf25f370829e4ae0f45b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 23:39:07 -0800 Subject: [PATCH 14/31] WIP: asNested --- csrc/ops/alias.cpp | 65 ++++++++++ csrc/ops/alias.h | 23 ++++ tests/cpp/test_ragged_iter_domain.cpp | 171 ++++++++++++++++++++++++++ 3 files changed, 259 insertions(+) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 45c0f52a603..cfe2cbc5ad1 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1268,4 +1268,69 @@ TensorView* repeat( return out_tv; } +TensorView* asNested( + TensorView* data, + TensorView* offsets, + int64_t ragged_dim) { + // Basic null checks + NVF_ERROR(data != nullptr, "asNested: data tensor is null"); + NVF_ERROR(offsets != nullptr, "asNested: offsets tensor is null"); + + // Only 1D offset tensors are currently supported + NVF_CHECK( + offsets->nDims() == 1, + "asNested currently only supports 1D offset tensors, got ", + offsets->nDims(), + "D"); + + // Get the logical domain of the input, excluding reductions + auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); + + // Clone the logical domain to create the root domain for output + std::vector root_domain; + root_domain.reserve(inp_logical.size()); + for (auto* id : inp_logical) { + root_domain.push_back(id->cloneWithoutRFactor()); + } + + // Partition the specified dimension in root domain + // This replaces one IterDomain with (component_id, ragged_id) + auto [component_id, ragged_id] = + RaggedIterDomain::partition(root_domain.at(ragged_dim), offsets); + + // Build the logical domain: replace ragged_dim with component and ragged + std::vector logical_domain; + logical_domain.reserve(root_domain.size() + 1); // One extra for the split + + for (const auto i : arange(root_domain.size())) { + if (static_cast(i) == ragged_dim) { + // Replace with component and ragged dimensions + logical_domain.push_back(component_id); + logical_domain.push_back(ragged_id); + } else { + logical_domain.push_back(root_domain.at(i)); + } + } + + // Create the output TensorView with the partitioned structure + auto* out = IrBuilder::create( + IrBuilder::create( + root_domain, + logical_domain, + logical_domain, + TensorDomain::getContiguityFilledWith(logical_domain, true)), + data->getDataType().value()); + + // Create a Partition expression to represent this transformation + // The Partition Expr outputs the component_id and ragged_id, and sets up + // the definitions for those IterDomains + IrBuilder::create(component_id, ragged_id, 
root_domain.at(ragged_dim), offsets); + + // Set the output TensorView's definition - this should be done via LoadStoreOp + // since we're creating an alias view + IrBuilder::create(LoadStoreOpType::Set, out, data); + + return out; +} + } // namespace nvfuser diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index 951b98b2a12..feea2924daa 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -156,6 +156,29 @@ NVF_API TensorView* slice( const std::vector& starts, const std::vector& stops); +//! Create a nested tensor view from a data tensor and offsets. +//! This is a convenience wrapper around TensorView::partition(). +//! +//! The function partitions the specified dimension of the data tensor into +//! a component dimension and a ragged dimension based on the provided offsets. +//! +//! \param data Input tensor to be converted to nested representation +//! \param offsets Offset tensor defining partition boundaries +//! Shape: [num_components + 1], values: [0, off1, off2, ..., total] +//! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] +//! \param ragged_dim Dimension to partition into nested structure (0-indexed) +//! \return TensorView with a RaggedIterDomain at the specified dimension +//! +//! Example: +//! data shape: [10, ...] +//! offsets: [0, 3, 8, 10] +//! ragged_dim: 0 +//! Result: nested tensor with 3 components of sizes [3, 5, 2] +NVF_API TensorView* asNested( + TensorView* data, + TensorView* offsets, + int64_t ragged_dim); + // Splits `in`'s dimension `dim` into `chunks` chunks. All but the last chunk // will be of size `ceil(dim_size/chunks)`. Unlike `torch.chunk` which returns // only positive-size chunks and therefore may return fewer than `chunks` of diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 8d16615bd64..8f97c099026 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -340,4 +340,175 @@ TEST_F(RaggedIterDomainTest, TensorViewPartition) { EXPECT_EQ(tv0->axis(0)->definition(), tv0->axis(1)->definition()); } +// asNested basic functionality +TEST_F(RaggedIterDomainTest, AsNestedBasic) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 2D TensorView [10, 20] + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + // Create offsets tensor [num_components + 1] + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, offsets, 0); + + // Verify the output is a new TensorView + EXPECT_NE(nested, nullptr); + EXPECT_NE(nested, data); + EXPECT_TRUE(nested->isA()); + + // Verify nested tensor has 3 dimensions: [component, ragged, original_dim1] + EXPECT_EQ(nested->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(nested->axis(0)->isA()); + EXPECT_FALSE(nested->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(nested->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(nested->axis(2)->isA()); + EXPECT_FALSE(nested->axis(2)->isA()); + + // Verify the definition exists (LoadStoreOp for aliasing) + EXPECT_TRUE(nested->definition() != nullptr); + EXPECT_TRUE(nested->definition()->isA()); + + // Verify the component and ragged IterDomains have Partition as their definition + EXPECT_TRUE(nested->axis(0)->definition() != nullptr); + EXPECT_TRUE(nested->axis(0)->definition()->isA()); + EXPECT_EQ(nested->axis(0)->definition(), 
nested->axis(1)->definition()); +} + +// asNested on different dimensions +TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 3D TensorView [10, 20, 30] + auto data = makeSymbolicTensor(3, DataType::Float); + fusion.addInput(data); + + // Create offsets tensor + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Partition dimension 1 (middle dimension) + auto nested = asNested(data, offsets, 1); + + // Verify dimensions: [dim0, component, ragged, dim2] + EXPECT_EQ(nested->nDims(), 4); + + // First axis is original dim0 + EXPECT_TRUE(nested->axis(0)->isA()); + EXPECT_FALSE(nested->axis(0)->isA()); + + // Second axis is component + EXPECT_TRUE(nested->axis(1)->isA()); + EXPECT_FALSE(nested->axis(1)->isA()); + + // Third axis is ragged + EXPECT_TRUE(nested->axis(2)->isA()); + + // Fourth axis is original dim2 + EXPECT_TRUE(nested->axis(3)->isA()); + EXPECT_FALSE(nested->axis(3)->isA()); +} + +// asNested with 1D tensor +TEST_F(RaggedIterDomainTest, AsNested1DTensor) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 1D TensorView [10] + auto data = makeSymbolicTensor(1, DataType::Float); + fusion.addInput(data); + + // Create offsets tensor + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor from the only dimension + auto nested = asNested(data, offsets, 0); + + // Verify dimensions: [component, ragged] + EXPECT_EQ(nested->nDims(), 2); + + // First axis is component + EXPECT_TRUE(nested->axis(0)->isA()); + EXPECT_FALSE(nested->axis(0)->isA()); + + // Second axis is ragged + EXPECT_TRUE(nested->axis(1)->isA()); +} + +// asNested validation - null data +TEST_F(RaggedIterDomainTest, AsNestedValidationNullData) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Null data should throw + EXPECT_THROW(asNested(nullptr, offsets, 0), nvfuser::nvfError); +} + +// asNested validation - null offsets +TEST_F(RaggedIterDomainTest, AsNestedValidationNullOffsets) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + // Null offsets should throw + EXPECT_THROW(asNested(data, nullptr, 0), nvfuser::nvfError); +} + +// asNested validation - multi-dimensional offsets (not yet supported) +TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + // 2D offsets should fail (only 1D supported currently) + auto offsets_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(offsets_2d); + + EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); +} + +// asNested preserves data type +TEST_F(RaggedIterDomainTest, AsNestedPreservesDataType) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Test with different data types + auto data_float = makeSymbolicTensor(2, DataType::Float); + auto data_double = makeSymbolicTensor(2, DataType::Double); + auto data_int = makeSymbolicTensor(2, DataType::Int); + fusion.addInput(data_float); + fusion.addInput(data_double); + fusion.addInput(data_int); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + auto nested_float = asNested(data_float, offsets, 0); + auto nested_double = asNested(data_double, offsets, 0); + auto nested_int = asNested(data_int, 
offsets, 0); + + EXPECT_EQ(nested_float->dtype(), DataType::Float); + EXPECT_EQ(nested_double->dtype(), DataType::Double); + EXPECT_EQ(nested_int->dtype(), DataType::Int); +} + } // namespace nvfuser From b07e285ab40e39da8cd4c701ca6045625aff47f4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 13 Dec 2025 00:04:39 -0800 Subject: [PATCH 15/31] cleanup --- csrc/ops/alias.cpp | 14 +++---- tests/cpp/test_ragged_iter_domain.cpp | 53 ++++++--------------------- 2 files changed, 19 insertions(+), 48 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index cfe2cbc5ad1..84dbd99b589 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1277,11 +1277,10 @@ TensorView* asNested( NVF_ERROR(offsets != nullptr, "asNested: offsets tensor is null"); // Only 1D offset tensors are currently supported - NVF_CHECK( - offsets->nDims() == 1, - "asNested currently only supports 1D offset tensors, got ", + NVF_ERROR_EQ( offsets->nDims(), - "D"); + 1, + "asNested currently only supports 1D offset tensors"); // Get the logical domain of the input, excluding reductions auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); @@ -1324,10 +1323,11 @@ TensorView* asNested( // Create a Partition expression to represent this transformation // The Partition Expr outputs the component_id and ragged_id, and sets up // the definitions for those IterDomains - IrBuilder::create(component_id, ragged_id, root_domain.at(ragged_dim), offsets); + IrBuilder::create( + component_id, ragged_id, root_domain.at(ragged_dim), offsets); - // Set the output TensorView's definition - this should be done via LoadStoreOp - // since we're creating an alias view + // Set the output TensorView's definition - this should be done via + // LoadStoreOp since we're creating an alias view IrBuilder::create(LoadStoreOpType::Set, out, data); return out; diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 8f97c099026..f7eaac14c2e 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -345,19 +345,19 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { Fusion fusion; FusionGuard fg(&fusion); - // Create a 2D TensorView [10, 20] auto data = makeSymbolicTensor(2, DataType::Float); fusion.addInput(data); - // Create offsets tensor [num_components + 1] auto offsets = makeSymbolicTensor(1, DataType::Index); fusion.addInput(offsets); // Create nested tensor from dimension 0 auto nested = asNested(data, offsets, 0); + fusion.addOutput(nested); + // Verify the output is a new TensorView - EXPECT_NE(nested, nullptr); + EXPECT_TRUE(nested != nullptr); EXPECT_NE(nested, data); EXPECT_TRUE(nested->isA()); @@ -365,21 +365,21 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { EXPECT_EQ(nested->nDims(), 3); // First axis should be a regular IterDomain (component) - EXPECT_TRUE(nested->axis(0)->isA()); + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); EXPECT_FALSE(nested->axis(0)->isA()); // Second axis should be a RaggedIterDomain EXPECT_TRUE(nested->axis(1)->isA()); // Third axis should be the original second dimension - EXPECT_TRUE(nested->axis(2)->isA()); - EXPECT_FALSE(nested->axis(2)->isA()); + EXPECT_TRUE(nested->axis(2)->isStrictlyA()); // Verify the definition exists (LoadStoreOp for aliasing) EXPECT_TRUE(nested->definition() != nullptr); EXPECT_TRUE(nested->definition()->isA()); - // Verify the component and ragged IterDomains have Partition as their definition + // Verify the component and ragged IterDomains have Partition as their + // 
definition EXPECT_TRUE(nested->axis(0)->definition() != nullptr); EXPECT_TRUE(nested->axis(0)->definition()->isA()); EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition()); @@ -390,11 +390,9 @@ TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { Fusion fusion; FusionGuard fg(&fusion); - // Create a 3D TensorView [10, 20, 30] auto data = makeSymbolicTensor(3, DataType::Float); fusion.addInput(data); - // Create offsets tensor auto offsets = makeSymbolicTensor(1, DataType::Index); fusion.addInput(offsets); @@ -405,19 +403,16 @@ TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { EXPECT_EQ(nested->nDims(), 4); // First axis is original dim0 - EXPECT_TRUE(nested->axis(0)->isA()); - EXPECT_FALSE(nested->axis(0)->isA()); + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); // Second axis is component - EXPECT_TRUE(nested->axis(1)->isA()); - EXPECT_FALSE(nested->axis(1)->isA()); + EXPECT_TRUE(nested->axis(1)->isStrictlyA()); // Third axis is ragged EXPECT_TRUE(nested->axis(2)->isA()); // Fourth axis is original dim2 EXPECT_TRUE(nested->axis(3)->isA()); - EXPECT_FALSE(nested->axis(3)->isA()); } // asNested with 1D tensor @@ -436,12 +431,13 @@ TEST_F(RaggedIterDomainTest, AsNested1DTensor) { // Create nested tensor from the only dimension auto nested = asNested(data, offsets, 0); + fusion.addOutput(nested); + // Verify dimensions: [component, ragged] EXPECT_EQ(nested->nDims(), 2); // First axis is component - EXPECT_TRUE(nested->axis(0)->isA()); - EXPECT_FALSE(nested->axis(0)->isA()); + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); // Second axis is ragged EXPECT_TRUE(nested->axis(1)->isA()); @@ -486,29 +482,4 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); } -// asNested preserves data type -TEST_F(RaggedIterDomainTest, AsNestedPreservesDataType) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Test with different data types - auto data_float = makeSymbolicTensor(2, DataType::Float); - auto data_double = makeSymbolicTensor(2, DataType::Double); - auto data_int = makeSymbolicTensor(2, DataType::Int); - fusion.addInput(data_float); - fusion.addInput(data_double); - fusion.addInput(data_int); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - auto nested_float = asNested(data_float, offsets, 0); - auto nested_double = asNested(data_double, offsets, 0); - auto nested_int = asNested(data_int, offsets, 0); - - EXPECT_EQ(nested_float->dtype(), DataType::Float); - EXPECT_EQ(nested_double->dtype(), DataType::Double); - EXPECT_EQ(nested_int->dtype(), DataType::Int); -} - } // namespace nvfuser From a2c504baa3d0360c92d5bb50a65e3011c237f630 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 15 Dec 2025 13:58:52 -0800 Subject: [PATCH 16/31] asNested --- csrc/ops/alias.cpp | 12 +++--------- csrc/ops/alias.h | 45 ++++++++++++++++++++++----------------------- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 84dbd99b589..30c3406633e 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1272,7 +1272,6 @@ TensorView* asNested( TensorView* data, TensorView* offsets, int64_t ragged_dim) { - // Basic null checks NVF_ERROR(data != nullptr, "asNested: data tensor is null"); NVF_ERROR(offsets != nullptr, "asNested: offsets tensor is null"); @@ -1320,14 +1319,9 @@ TensorView* asNested( TensorDomain::getContiguityFilledWith(logical_domain, true)), data->getDataType().value()); - // Create a Partition 
expression to represent this transformation - // The Partition Expr outputs the component_id and ragged_id, and sets up - // the definitions for those IterDomains - IrBuilder::create( - component_id, ragged_id, root_domain.at(ragged_dim), offsets); - - // Set the output TensorView's definition - this should be done via - // LoadStoreOp since we're creating an alias view + // For now, just use LoadStoreOp to represent the nesting + // operation. Does it make more sense to have a specific TensorView + // op like ReshapeOp? IrBuilder::create(LoadStoreOpType::Set, out, data); return out; diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index feea2924daa..f3bf769dd71 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -156,29 +156,6 @@ NVF_API TensorView* slice( const std::vector& starts, const std::vector& stops); -//! Create a nested tensor view from a data tensor and offsets. -//! This is a convenience wrapper around TensorView::partition(). -//! -//! The function partitions the specified dimension of the data tensor into -//! a component dimension and a ragged dimension based on the provided offsets. -//! -//! \param data Input tensor to be converted to nested representation -//! \param offsets Offset tensor defining partition boundaries -//! Shape: [num_components + 1], values: [0, off1, off2, ..., total] -//! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] -//! \param ragged_dim Dimension to partition into nested structure (0-indexed) -//! \return TensorView with a RaggedIterDomain at the specified dimension -//! -//! Example: -//! data shape: [10, ...] -//! offsets: [0, 3, 8, 10] -//! ragged_dim: 0 -//! Result: nested tensor with 3 components of sizes [3, 5, 2] -NVF_API TensorView* asNested( - TensorView* data, - TensorView* offsets, - int64_t ragged_dim); - // Splits `in`'s dimension `dim` into `chunks` chunks. All but the last chunk // will be of size `ceil(dim_size/chunks)`. Unlike `torch.chunk` which returns // only positive-size chunks and therefore may return fewer than `chunks` of @@ -220,4 +197,26 @@ NVF_API TensorView* repeat( TensorView* inp, const std::vector& repeat_times); +//! Create a nested tensor view from a data tensor and offsets. +//! +//! The function partitions the specified dimension of the data tensor into +//! a component dimension and a ragged dimension based on the provided offsets. +//! +//! \param data Input tensor to be converted to nested representation +//! \param offsets Offset tensor defining partition boundaries +//! Shape: [num_components + 1], values: [0, off1, off2, ..., total] +//! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] +//! \param ragged_dim Dimension to partition into nested structure +//! \return TensorView with a RaggedIterDomain at the specified dimension +//! +//! Example: +//! data shape: [10, ...] +//! offsets: [0, 3, 8, 10] +//! ragged_dim: 0 +//! Result: nested tensor with 3 components. [3, [3, 5, 2], ...] 
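+//!
+//! Note: as with RaggedIterDomain::partition, only 1D offset tensors are
+//! currently supported; multi-dimensional offsets for nested ragged
+//! structures are not yet implemented.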
+NVF_API TensorView* asNested( + TensorView* data, + TensorView* offsets, + int64_t ragged_dim); + } // namespace nvfuser From b1d8cf40a0a2fbe9725b68bf08f2e6ea55b9981e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 15 Dec 2025 14:11:21 -0800 Subject: [PATCH 17/31] warpdim --- csrc/ops/alias.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 30c3406633e..e32aa5e6b9c 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1291,6 +1291,8 @@ TensorView* asNested( root_domain.push_back(id->cloneWithoutRFactor()); } + ragged_dim = wrapDim(ragged_dim, std::ssize(inp_logical)); + // Partition the specified dimension in root domain // This replaces one IterDomain with (component_id, ragged_id) auto [component_id, ragged_id] = From 201c1480ac75dec581570276937d7d1f00513e28 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 11:31:27 -0800 Subject: [PATCH 18/31] Make sure RaggedIterDomain is propagated to output tensors --- csrc/ir/internal_base_nodes.cpp | 7 +++ csrc/ir/internal_base_nodes.h | 2 + csrc/ops/utils.cpp | 29 +++++++++++ tests/cpp/test_ragged_iter_domain.cpp | 71 ++++++++++++++++----------- 4 files changed, 80 insertions(+), 29 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index cc068fbfca2..9d9984a3d11 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1496,6 +1496,13 @@ bool TensorDomain::hasVectorize() const { }); } +bool TensorDomain::hasRaggedIterDomain() const { + return std::any_of( + logical().begin(), logical().end(), [](IterDomain* logical_id) { + return logical_id->isA(); + }); +} + std::optional TensorDomain::getReductionAxis() const { auto it = std::find_if( loop_domain_.begin(), loop_domain_.end(), [](const auto& id) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 0187c408bd7..cfade4ebcba 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -642,6 +642,8 @@ class NVF_API TensorDomain : public Val { bool hasSymbolicAxis() const; + bool hasRaggedIterDomain() const; + std::optional getReductionAxis() const; // The input logical domain. 
The root domain of a consumer should equal the diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 2db0b424d55..80aa95c1fe2 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -315,6 +315,28 @@ std::vector mapLinearOpIterDomains( return mapping; } +RaggedIterDomain* newOutputRaggedIterDomain( + const std::vector& input_ids, + const std::optional force_iter_type) { + NVF_ERROR( + std::ranges::all_of( + input_ids, + [](IterDomain* input_id) { + return input_id->isA(); + }), + "All input iter domains must be RaggedIterDomain"); + + NVF_ERROR(!input_ids.empty()); + RaggedIterDomain* ref_input_id = input_ids.front()->as(); + + NVF_ERROR(!force_iter_type.has_value(), "forced iter type not considered"); + + return IrBuilder::create( + ref_input_id->extents(), + ref_input_id->getIterType(), + ref_input_id->getParallelType()); +} + // Adding these pragmas since gcc-12.2.1 // incorrectly reports a warning with the use of evaluate #if defined(__GNUC__) && !defined(__clang__) @@ -324,6 +346,13 @@ std::vector mapLinearOpIterDomains( IterDomain* newOutputIterDomain( const std::vector& input_ids, const std::optional force_iter_type) { + NVF_ERROR(!input_ids.empty()); + + // If any input ID is a RaggedIterDomain, the output should also be ragged + if (input_ids.front()->isA()) { + return newOutputRaggedIterDomain(input_ids, force_iter_type); + } + // For the start and stop offsets, take the maximum of input axes. // For now, the offsets of both start and stop are always integer // constant, so we can statically compute them. It is unclear diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index f7eaac14c2e..3bd7127f78d 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -354,35 +354,48 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { // Create nested tensor from dimension 0 auto nested = asNested(data, offsets, 0); - fusion.addOutput(nested); - - // Verify the output is a new TensorView - EXPECT_TRUE(nested != nullptr); - EXPECT_NE(nested, data); - EXPECT_TRUE(nested->isA()); - - // Verify nested tensor has 3 dimensions: [component, ragged, original_dim1] - EXPECT_EQ(nested->nDims(), 3); - - // First axis should be a regular IterDomain (component) - EXPECT_TRUE(nested->axis(0)->isStrictlyA()); - EXPECT_FALSE(nested->axis(0)->isA()); - - // Second axis should be a RaggedIterDomain - EXPECT_TRUE(nested->axis(1)->isA()); - - // Third axis should be the original second dimension - EXPECT_TRUE(nested->axis(2)->isStrictlyA()); - - // Verify the definition exists (LoadStoreOp for aliasing) - EXPECT_TRUE(nested->definition() != nullptr); - EXPECT_TRUE(nested->definition()->isA()); - - // Verify the component and ragged IterDomains have Partition as their - // definition - EXPECT_TRUE(nested->axis(0)->definition() != nullptr); - EXPECT_TRUE(nested->axis(0)->definition()->isA()); - EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition()); + // This should still be a nested tensor + auto copy_of_nested = set(nested); + + fusion.addOutput(copy_of_nested); + + for (auto nested_tv : {nested, copy_of_nested}) { + // Verify the output is a new TensorView + EXPECT_TRUE(nested_tv != nullptr); + EXPECT_NE(nested_tv, data); + EXPECT_TRUE(nested_tv->isA()); + + // Verify nested_tv tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(nested_tv->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(nested_tv->axis(0)->isStrictlyA()); + 
EXPECT_FALSE(nested_tv->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(nested_tv->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(nested_tv->axis(2)->isStrictlyA()); + + if (nested_tv == nested) { + // Verify the definition exists (LoadStoreOp for aliasing) + EXPECT_TRUE(nested_tv->definition() != nullptr); + EXPECT_TRUE(nested_tv->definition()->isA()); + + // Verify the component and ragged IterDomains have Partition as their + // definition + EXPECT_TRUE(nested_tv->axis(0)->definition() != nullptr); + EXPECT_TRUE(nested_tv->axis(0)->definition()->isA()); + EXPECT_EQ( + nested_tv->axis(0)->definition(), nested_tv->axis(1)->definition()); + } else { + // The copy of the original nested tensor does not inherit the Partition + // op + EXPECT_TRUE(nested_tv->axis(0)->definition() == nullptr); + } + } } // asNested on different dimensions From 9e0b161b9adf62d62172f78a60184c9fd8ae4327 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 13:37:32 -0800 Subject: [PATCH 19/31] Extend ops to be aware with RaggediterDomain --- csrc/ir/internal_base_nodes.cpp | 14 ++ csrc/ir/internal_base_nodes.h | 5 +- csrc/ops/alias.cpp | 3 +- csrc/ops/utils.cpp | 17 +- csrc/ops/utils.h | 6 + tests/cpp/test_ragged_iter_domain.cpp | 298 ++++++++++++++++++++++++++ 6 files changed, 334 insertions(+), 9 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 9d9984a3d11..3b7c31a89e9 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -895,6 +895,20 @@ std::string RaggedIterDomain::toString(int indent_size) const { return toInlineString(indent_size); } +IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) { + // Create a new RaggedIterDomain with the same extents and properties + auto cloned = IrBuilder::create( + extents_, getIterType(), getParallelType()); + + // Optionally map the clone with the original in the Exact graph + if (map_with_original) { + // TODO: Implement mapping if needed + NVF_THROW("Not implemented"); + } + + return cloned; +} + std::pair RaggedIterDomain::partition( IterDomain* in, TensorView* offsets) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index cfade4ebcba..d56d4d21470 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -122,7 +122,7 @@ class NVF_API IterDomain : public Val { //! //! When map_with_original is true, the clone of the original is //! mapped in the Exact graph. - IterDomain* cloneWithoutRFactor(bool map_with_original = false); + virtual IterDomain* cloneWithoutRFactor(bool map_with_original = false); //! Clone a vector domains static std::vector clone( @@ -492,6 +492,9 @@ class NVF_API RaggedIterDomain : public IterDomain { IterDomain* in, TensorView* offsets); + //! Override cloneWithoutRFactor to preserve RaggedIterDomain type + IterDomain* cloneWithoutRFactor(bool map_with_original = false) override; + private: //! Extent tensor containing all component extents //! 
Can be 1D, 2D, or N-D depending on nesting structure diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index e32aa5e6b9c..870a1c186a3 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1038,8 +1038,7 @@ TensorView* broadcast( .iter_type(IterType::Broadcast) .build()); } else { - out_domain.push_back( - IterDomainBuilder(inp_domain[iinp]).resetSchedulingParams().build()); + out_domain.push_back(inp_domain[iinp]->cloneWithoutRFactor()); iinp++; } ibdim++; diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 80aa95c1fe2..2623f4c68be 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -316,8 +316,7 @@ std::vector mapLinearOpIterDomains( } RaggedIterDomain* newOutputRaggedIterDomain( - const std::vector& input_ids, - const std::optional force_iter_type) { + const std::vector& input_ids) { NVF_ERROR( std::ranges::all_of( input_ids, @@ -329,8 +328,6 @@ RaggedIterDomain* newOutputRaggedIterDomain( NVF_ERROR(!input_ids.empty()); RaggedIterDomain* ref_input_id = input_ids.front()->as(); - NVF_ERROR(!force_iter_type.has_value(), "forced iter type not considered"); - return IrBuilder::create( ref_input_id->extents(), ref_input_id->getIterType(), @@ -349,8 +346,16 @@ IterDomain* newOutputIterDomain( NVF_ERROR(!input_ids.empty()); // If any input ID is a RaggedIterDomain, the output should also be ragged - if (input_ids.front()->isA()) { - return newOutputRaggedIterDomain(input_ids, force_iter_type); + bool has_ragged = + std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) { + return id->isA(); + }); + + if (has_ragged) { + NVF_ERROR( + !force_iter_type.has_value(), + "force_iter_type not supported for RaggedIterDomain"); + return newOutputRaggedIterDomain(input_ids); } // For the start and stop offsets, take the maximum of input axes. diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index 44a98242a4d..3ceadc4aa6a 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -92,6 +92,12 @@ std::vector mapLinearOpIterDomains( size_t out_size, bool k_bcast); +// Creates an output RaggedIterDomain from input RaggedIterDomains at the same +// dimension position. All inputs must be RaggedIterDomain. Uses the extents, +// IterType, and ParallelType from the first input. +RaggedIterDomain* newOutputRaggedIterDomain( + const std::vector& input_ids); + // Takes a vector of aligned input iterdomains to create the output iterdomain. // This is used if the input iterdomains are not trivially mapped to the output // iterdomains. For eg: MatmulOp. 
If given, the forced_iter_type argument will diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 3bd7127f78d..ac3c3ef1f35 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -495,4 +495,302 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); } +// Test binary operations with nested tensors +TEST_F(RaggedIterDomainTest, BinaryOpWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create two 2D input tensors + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensors from both inputs + auto nested1 = asNested(data1, offsets, 0); + auto nested2 = asNested(data2, offsets, 0); + + // Perform binary operation: add + auto result = add(nested1, nested2); + + fusion.addOutput(result); + + // Verify the result has 3 dimensions: [component, ragged, original_dim1] + EXPECT_EQ(result->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_FALSE(result->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(result->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(result->axis(2)->isStrictlyA()); +} + +// Test binary operation with mixed inputs (one ragged, one not) - should error +TEST_F(RaggedIterDomainTest, BinaryOpMixedInputsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor from first input only + auto nested1 = asNested(data1, offsets, 0); + + // Try to add nested tensor with non-nested tensor + // This should fail because one is ragged and one is not + EXPECT_THROW(add(nested1, data2), nvfuser::nvfError); +} + +// Test binary operation with different offsets +TEST_F(RaggedIterDomainTest, BinaryOpDifferentRaggedStructures) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets1 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets1); + + auto offsets2 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets2); + + // Create nested tensors with different offset tensors + auto nested1 = asNested(data1, offsets1, 0); + auto nested2 = asNested(data2, offsets2, 0); + + // This would be an error if, for example, the values of the offset + // tensors are not equivalent, but, like normal tensors, we assume + // that is indeed the case. 
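+  // The output RaggedIterDomain simply reuses the extents of the first
+  // input (see newOutputRaggedIterDomain in csrc/ops/utils.cpp), so no
+  // comparison of the two offset tensors is attempted here.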
+ auto result = add(nested1, nested2); + fusion.addOutput(result); + + EXPECT_TRUE(result->axis(1)->isA()); +} + +// Test unary operations with nested tensors +TEST_F(RaggedIterDomainTest, UnaryOpWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor + auto nested = asNested(data, offsets, 0); + + // Perform unary operation: neg + auto result = neg(nested); + + fusion.addOutput(result); + + // Verify the result preserves RaggedIterDomain structure + EXPECT_EQ(result->nDims(), 3); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_TRUE(result->axis(2)->isStrictlyA()); +} + +// Test broadcast with nested tensors +TEST_F(RaggedIterDomainTest, BroadcastWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + auto result = broadcast(nested, {false, false, false, true}); + + fusion.addOutput(result); + + // Result should be: [component, ragged, dim1, broadcast_dim] + EXPECT_EQ(result->nDims(), 4); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_TRUE(result->axis(2)->isStrictlyA()); + EXPECT_TRUE(result->axis(3)->isStrictlyA()); + EXPECT_TRUE(result->axis(3)->isBroadcast()); +} + +// Test squeeze on non-ragged dimension +TEST_F(RaggedIterDomainTest, SqueezeNonRaggedDim) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // First broadcast to add a dimension: [component, ragged, dim1, 1] + auto broadcasted = broadcast(nested, {false, false, false, true}); + + // Then squeeze the broadcast dimension (dimension index 3) + auto result = squeeze(broadcasted, {3}); + + fusion.addOutput(result); + + // Result should be: [component, ragged, dim1] + EXPECT_EQ(result->nDims(), 3); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_TRUE(result->axis(2)->isStrictlyA()); +} + +// Test unsqueeze with nested tensors +TEST_F(RaggedIterDomainTest, UnsqueezeWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Unsqueeze to add dimension at the end + auto result = unsqueeze(nested, -1); + + fusion.addOutput(result); + + // Result should be: [component, ragged, dim1, 1] + EXPECT_EQ(result->nDims(), 4); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_TRUE(result->axis(2)->isStrictlyA()); + EXPECT_TRUE(result->axis(3)->isStrictlyA()); +} + +// Test permute/transpose with nested tensors +TEST_F(RaggedIterDomainTest, PermuteWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, 
DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Permute dimensions: swap ragged and dim1 + auto result = permute(nested, {0, 2, 1}); + + fusion.addOutput(result); + + // Result should be: [component, dim1, ragged] + EXPECT_EQ(result->nDims(), 3); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isStrictlyA()); + EXPECT_TRUE(result->axis(2)->isA()); +} + +// Test reduction on non-ragged dimension +TEST_F(RaggedIterDomainTest, ReductionOnNonRaggedDim) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Reduce along the last dimension (non-ragged) + auto result = sum(nested, {2}); + + fusion.addOutput(result); + + // Result should be: [component, ragged] + // Debug: print the result structure + std::cout << "ReductionOnNonRaggedDim result dimensions: " << result->nDims() + << std::endl; + for (auto i : c10::irange(result->nDims())) { + std::cout << " axis " << i << ": " + << (result->axis(i)->isA() ? "RaggedIterDomain" + : "IterDomain") + << std::endl; + } + + EXPECT_EQ(result->nDims(), 2); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); +} + +// Test reduction on ragged dimension +TEST_F(RaggedIterDomainTest, ReductionOnRaggedDim) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Reduce along the ragged dimension (axis 1) + auto result = sum(nested, {1}); + + fusion.addOutput(result); + + // Result should be: [component, dim1] + // Both should be regular IterDomains (ragged dimension is reduced away) + // Debug: print the result structure + std::cout << "ReductionOnRaggedDim result dimensions: " << result->nDims() + << std::endl; + for (auto i : c10::irange(result->nDims())) { + std::cout << " axis " << i << ": " + << (result->axis(i)->isA() ? "RaggedIterDomain" + : "IterDomain") + << std::endl; + } + + EXPECT_EQ(result->nDims(), 2); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_FALSE(result->axis(0)->isA()); + EXPECT_TRUE(result->axis(1)->isStrictlyA()); + EXPECT_FALSE(result->axis(1)->isA()); +} + } // namespace nvfuser From 60a2dd51b3e5e5321abc4cffa1b6f58c34c12cb3 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 14:15:04 -0800 Subject: [PATCH 20/31] RaggedIterDomain and reduction --- csrc/ops/arith.cpp | 26 +++++++--- tests/cpp/test_ragged_iter_domain.cpp | 68 ++++++++++++++------------- 2 files changed, 55 insertions(+), 39 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 10e4f4007b8..f6dceef7f1d 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -1209,6 +1209,14 @@ TensorView* newForReduction( " of tensor ", tv); } + NVF_CHECK( + !id->isA(), + "Cannot reduce a RaggedIterDomain. Reduction of ragged dimensions is " + "not supported. 
" + "Tried to reduce ID = ", + id, + " of tensor ", + tv); new_id = IterDomainBuilder(id) // If the domain is being reduced, but it's coming in as an // expanded extent, we need to realize the expand. @@ -1217,12 +1225,18 @@ TensorView* newForReduction( .iter_type(IterType::Reduction) .build(); } else { - new_id = IterDomainBuilder(id) - .extent(id->extent()) - .resetSchedulingParams() - .parallel_type(id->getParallelType()) - .iter_type(id->getIterType()) - .build(); + // For non-reduced dimensions, preserve RaggedIterDomain if present + if (id->isA()) { + // Cast away const since cloneWithoutRFactor is not const + new_id = const_cast(id)->cloneWithoutRFactor(); + } else { + new_id = IterDomainBuilder(id) + .extent(id->extent()) + .resetSchedulingParams() + .parallel_type(id->getParallelType()) + .iter_type(id->getIterType()) + .build(); + } } new_domain.push_back(new_id); } diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index ac3c3ef1f35..2fa9cf04d2d 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -740,23 +740,17 @@ TEST_F(RaggedIterDomainTest, ReductionOnNonRaggedDim) { fusion.addOutput(result); // Result should be: [component, ragged] - // Debug: print the result structure - std::cout << "ReductionOnNonRaggedDim result dimensions: " << result->nDims() - << std::endl; - for (auto i : c10::irange(result->nDims())) { - std::cout << " axis " << i << ": " - << (result->axis(i)->isA() ? "RaggedIterDomain" - : "IterDomain") - << std::endl; - } + // Get non-reduction dimensions + auto non_reduction_domain = + TensorDomain::noReductions(result->getLogicalDomain()); - EXPECT_EQ(result->nDims(), 2); - EXPECT_TRUE(result->axis(0)->isStrictlyA()); - EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_EQ(non_reduction_domain.size(), 2); + EXPECT_TRUE(non_reduction_domain[0]->isStrictlyA()); + EXPECT_TRUE(non_reduction_domain[1]->isA()); } -// Test reduction on ragged dimension -TEST_F(RaggedIterDomainTest, ReductionOnRaggedDim) { +// Test reduction on ragged dimension - should error +TEST_F(RaggedIterDomainTest, ReductionOnRaggedDimError) { Fusion fusion; FusionGuard fg(&fusion); @@ -769,28 +763,36 @@ TEST_F(RaggedIterDomainTest, ReductionOnRaggedDim) { // Create nested tensor: [component, ragged, dim1] auto nested = asNested(data, offsets, 0); - // Reduce along the ragged dimension (axis 1) - auto result = sum(nested, {1}); + // Try to reduce along the ragged dimension (axis 1) + // This should throw an error because reducing RaggedIterDomain is not allowed + EXPECT_THROW(sum(nested, {1}), nvfuser::nvfError); +} - fusion.addOutput(result); +// Test reduction on component dimension - should error (TODO) +TEST_F(RaggedIterDomainTest, ReductionOnComponentDimError) { + GTEST_SKIP() << "TODO: Implement validation to prevent reduction of " + "component dimension. " + << "Currently there is no explicit marking of which IterDomains " + "are component dimensions, " + << "so this validation cannot be implemented yet."; - // Result should be: [component, dim1] - // Both should be regular IterDomains (ragged dimension is reduced away) - // Debug: print the result structure - std::cout << "ReductionOnRaggedDim result dimensions: " << result->nDims() - << std::endl; - for (auto i : c10::irange(result->nDims())) { - std::cout << " axis " << i << ": " - << (result->axis(i)->isA() ? 
"RaggedIterDomain" - : "IterDomain") - << std::endl; - } + Fusion fusion; + FusionGuard fg(&fusion); - EXPECT_EQ(result->nDims(), 2); - EXPECT_TRUE(result->axis(0)->isStrictlyA()); - EXPECT_FALSE(result->axis(0)->isA()); - EXPECT_TRUE(result->axis(1)->isStrictlyA()); - EXPECT_FALSE(result->axis(1)->isA()); + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to reduce along the component dimension (axis 0) + // This should throw an error because reducing component dimensions is not + // allowed The component dimension defines the batch structure of the ragged + // tensor, and reducing it would destroy the ragged structure + EXPECT_THROW(sum(nested, {0}), nvfuser::nvfError); } } // namespace nvfuser From 566d63de9021698a38032f1f47af4ac178673204 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 19:23:55 -0800 Subject: [PATCH 21/31] WIP --- csrc/ir/internal_base_nodes.cpp | 90 ++++++++++++++++++++++++++++++++- csrc/ir/internal_base_nodes.h | 6 +++ csrc/ops/alias.cpp | 50 +++++++++++++++++- csrc/ops/indexing.cpp | 50 ++++++++++++++++++ 4 files changed, 193 insertions(+), 3 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 3b7c31a89e9..8897ac94e8e 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -48,7 +48,11 @@ IterDomainBuilder::IterDomainBuilder(const IterDomain* id) is_rfactor_domain_(id->isRFactorProduct()), is_padded_dimension_(id->hasPaddingToMultipleOfWarp()), is_clustered_dimension_(id->isClusteredBlockDim()), - padded_to_size_(id->getMaybeSizeAfterPadding()) {} + padded_to_size_(id->getMaybeSizeAfterPadding()) { + if (id->isA()) { + ragged_extents_ = id->as()->extents(); + } +} IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() { parallel_type_ = ParallelType::Serial; @@ -116,7 +120,13 @@ IterDomain* IterDomainBuilder::build() const { NVF_ERROR( start_ != nullptr && extent_ != nullptr, "Start and extent are required to build an iter domain."); - return IrBuilder::createInContainer(start_->container(), *this); + + if (ragged_extents_ != nullptr) { + return IrBuilder::createInContainer( + start_->container(), *this); + } else { + return IrBuilder::createInContainer(start_->container(), *this); + } } IterDomain::IterDomain( @@ -604,6 +614,11 @@ IterDomain* IterDomain::resize( "Non-zero stop offset not considered: ", in->toString()); + NVF_CHECK( + !in->isA(), + "Resizing RaggedIterDomain is not supported: ", + in->toString()); + // The overall extent is (in_extent + left_expansion + // right_expansion). 
This can be simplified for a slice op as // the right expansion should look like (slice_end_offset - @@ -815,6 +830,77 @@ void validateLoopDomain( } // namespace +RaggedIterDomain::RaggedIterDomain( + IrBuilderPasskey passkey, + const IterDomainBuilder& args) + : IterDomain( + passkey, + ValType::RaggedIterDomain, + args.start_, + args.extent_, + args.expanded_extent_, + args.stop_offset_, + args.parallel_type_, + args.iter_type_, + args.is_rfactor_domain_, + args.is_padded_dimension_, + args.is_clustered_dimension_, + args.padded_to_size_), + extents_(args.ragged_extents_) { + // Extents must be non-null + NVF_ERROR( + extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor"); + + // Extents must have integer dtype + NVF_ERROR_EQ( + extents_->dtype(), + DataType::Index, + "RaggedIterDomain extents must have index type, got ", + extents_->dtype()); + + // Only IterType::Iteration is supported at this moment + NVF_ERROR_EQ( + iter_type_, + IterType::Iteration, + "Only IterType::Iteration is supported: ", + iter_type_); + + // RaggedIterDomain has specific requirements on member values + NVF_ERROR( + start_->isZeroInt(), + "RaggedIterDomain start must be zero, got: ", + start_->toInlineString()); + + NVF_ERROR( + extent_->isOneInt(), + "RaggedIterDomain extent must be one (placeholder), got: ", + extent_->toInlineString()); + + NVF_ERROR( + expanded_extent_ == nullptr, + "RaggedIterDomain does not support expanded_extent"); + + NVF_ERROR( + stop_offset_ == nullptr || stop_offset_->isZeroInt(), + "RaggedIterDomain stop_offset must be nullptr or zero, got: ", + stop_offset_ ? stop_offset_->toInlineString() : "nullptr"); + + NVF_ERROR( + !is_rfactor_domain_, "RaggedIterDomain does not support rfactor domains"); + + NVF_ERROR( + !is_padded_dimension_, + "RaggedIterDomain does not support padded dimensions"); + + NVF_ERROR( + !is_clustered_dimension_, + "RaggedIterDomain does not support clustered dimensions"); + + NVF_ERROR( + !padded_to_size_.has_value(), + "RaggedIterDomain does not support padded_to_size"); +} + RaggedIterDomain::RaggedIterDomain( IrBuilderPasskey passkey, TensorView* extents, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index d56d4d21470..1f3e6658e01 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -61,6 +61,7 @@ class IterDomainBuilder { IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain); IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension); IterDomainBuilder& padded_to_size(std::optional _padded_to_size); + IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); IterDomain* build() const; @@ -79,6 +80,9 @@ class IterDomainBuilder { bool is_padded_dimension_ = false; bool is_clustered_dimension_ = false; std::optional padded_to_size_ = std::nullopt; + + // For RaggedIterDomain: stores the extents tensor + TensorView* ragged_extents_ = nullptr; }; //! Simply a representation of an annotated 1D iterable from start to extent. @@ -448,6 +452,8 @@ class NVF_API IterDomain : public Val { //! components class NVF_API RaggedIterDomain : public IterDomain { public: + RaggedIterDomain(IrBuilderPasskey passkey, const IterDomainBuilder& args); + //! \param extents TensorView containing component extents (must be integer //! type) //! \param iter_type Iteration type (Iteration, Reduction, etc.) 
diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 870a1c186a3..16298a04d37 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -42,6 +42,13 @@ TensorView* segment_set(TensorView* tv) { TensorView* view(TensorView* x, DataType dtype) { NVF_ERROR(x != nullptr, "Input is invalid."); + + NVF_CHECK( + !x->domain()->hasRaggedIterDomain(), + "View operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + x->toString()); + if (x->getDataType() == dtype) { return x; } @@ -142,6 +149,12 @@ TensorView* reshape(TensorView* inp_tv, const std::vector& new_sizes) { "Unsupported input tensor to reshape as its axes may be partial: ", inp_tv->toString()); + NVF_CHECK( + !inp_tv->domain()->hasRaggedIterDomain(), + "Reshape operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + inp_tv->toString()); + auto static_reshape_output = tryStaticReshape(inp_tv, inp_dom, new_sizes); if (static_reshape_output) { return static_reshape_output; @@ -239,6 +252,12 @@ TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) { end_dim); NVF_CHECK(start_dim <= end_dim, "start_dim must be <= end_dim"); + NVF_CHECK( + !x->domain()->hasRaggedIterDomain(), + "Flatten operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + x->toString()); + if (start_dim == end_dim) { return x; } @@ -518,6 +537,11 @@ TensorView* pad( const std::vector& pad_widths, Val* value, std::optional iter_type_opt) { + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Padding a tensor with a RaggedIterDomain not supported: ", + inp->toString()); + DataType dt = inp->getDataType().value(); if (!value) { // Create a zero of the appropriate type @@ -623,6 +647,13 @@ TensorView* cat( std::optional iter_type_opt, bool manual_padding) { NVF_CHECK(!inputs.empty(), "No input tensor given"); + NVF_CHECK( + std::ranges ::none_of( + inputs, + [](TensorView* inp_tv) { + return inp_tv->domain()->hasRaggedIterDomain(); + }), + "Concatenating a tensor with a RaggedIterDomain not supported"); const auto dtype = inputs.at(0)->getDataType().value(); @@ -783,7 +814,12 @@ TensorView* slice( NVF_CHECK_EQ( ndims, std::ssize(ranges), - "The range vector must have the same number of Slice descriptors.") + "The range vector must have the same number of Slice descriptors."); + + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Slicing a tensor with a RaggedIterDomain not supported: ", + inp->toString()); ExpressionEvaluator expr_eval; @@ -1058,6 +1094,12 @@ TensorView* broadcast( TensorView* expand(TensorView* inp, const std::vector& expanded_sizes) { auto inp_domain = TensorDomain::noReductions(inp->getLogicalDomain()); + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Expand operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + inp->toString()); + NVF_CHECK_GE(expanded_sizes.size(), inp_domain.size()); inp = ops::maybe_broadcast_inner_to_rank(inp, expanded_sizes.size()); @@ -1180,6 +1222,12 @@ TensorView* expand_as(TensorView* inp, TensorView* other) { TensorView* repeat( TensorView* inp_tv, const std::vector& repeat_times) { + NVF_CHECK( + !inp_tv->domain()->hasRaggedIterDomain(), + "Repeat operation is not supported for tensors with RaggedIterDomain. 
" + "Input tensor: ", + inp_tv->toString()); + const auto ndims = TensorDomain::noReductions(inp_tv->getLogicalDomain()).size(); diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp index fb2f1b3feda..a28ca67f72b 100644 --- a/csrc/ops/indexing.cpp +++ b/csrc/ops/indexing.cpp @@ -24,6 +24,12 @@ TensorView* select(TensorView* tv, int64_t dim, Val* index) { auto dom = TensorDomain::noReductions(tv->getLogicalDomain()); NVF_CHECK(!dom.empty(), "select can not be applied to 0d tensor."); + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "Select operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + tv->toString()); + std::vector new_root; new_root.reserve(dom.size() - 1); dim = wrapDim(dim, (int64_t)dom.size()); @@ -46,6 +52,18 @@ TensorView* indexSelect( TensorView* lookup_tv, int64_t dim, TensorView* index_tv) { + NVF_CHECK( + !lookup_tv->domain()->hasRaggedIterDomain(), + "IndexSelect operation is not supported for tensors with RaggedIterDomain. " + "Input tensor (lookup_tv): ", + lookup_tv->toString()); + + NVF_CHECK( + !index_tv->domain()->hasRaggedIterDomain(), + "IndexSelect operation is not supported for tensors with RaggedIterDomain. " + "Index tensor (index_tv): ", + index_tv->toString()); + DataType dtype = lookup_tv->getDataType().value(); NVF_CHECK( dtype != DataType::Null, "Invalid datatype provided for new value."); @@ -131,6 +149,18 @@ TensorView* indexPutAccumulate( // torch.gather TensorView* gather(TensorView* inp, int64_t dim, TensorView* index) { + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Gather operation is not supported for tensors with RaggedIterDomain. " + "Input tensor (inp): ", + inp->toString()); + + NVF_CHECK( + !index->domain()->hasRaggedIterDomain(), + "Gather operation is not supported for tensors with RaggedIterDomain. " + "Index tensor (index): ", + index->toString()); + auto inp_domain = TensorDomain::noReductions(inp->getLogicalDomain()); auto idx_domain = TensorDomain::noReductions(index->getLogicalDomain()); NVF_CHECK( @@ -168,6 +198,26 @@ TensorView* scatter( TensorView* index, Val* src, std::optional accumulate_op) { + NVF_CHECK( + !self->domain()->hasRaggedIterDomain(), + "Scatter operation is not supported for tensors with RaggedIterDomain. " + "Input tensor (self): ", + self->toString()); + + NVF_CHECK( + !index->domain()->hasRaggedIterDomain(), + "Scatter operation is not supported for tensors with RaggedIterDomain. " + "Index tensor (index): ", + index->toString()); + + if (src->isA()) { + NVF_CHECK( + !src->as()->domain()->hasRaggedIterDomain(), + "Scatter operation is not supported for tensors with RaggedIterDomain. 
" + "Source tensor (src): ", + src->toString()); + } + auto self_dom = TensorDomain::noReductions(self->getLogicalDomain()); auto idx_dom = TensorDomain::noReductions(index->getLogicalDomain()); From 144b206c988c824ab1bfbcf926d53e9bcc0c85f5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 21:25:37 -0800 Subject: [PATCH 22/31] WIP --- csrc/ops/arith.cpp | 75 ++++++ tests/cpp/test_ragged_iter_domain.cpp | 336 ++++++++++++++++++++++++++ 2 files changed, 411 insertions(+) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index f6dceef7f1d..ce228a1d281 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -276,6 +276,12 @@ TensorView* randn_like(TensorView* tv, Val* philox_seed, Val* philox_offset) { isFloatingPointType(tv->dtype()), "input must have floating point type, but got ", tv->dtype()); + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "randn_like operation is not supported for tensors with " + "RaggedIterDomain. " + "Input tensor: ", + tv->toString()); // Create a new output TV manually so that we carry over IterTypes, instead // of inferring them from the shape as we would if we used randn(). TensorView* out = ops::newOutputTV({tv}, tv->dtype()); @@ -303,6 +309,11 @@ TensorView* rand_like(TensorView* tv, Val* philox_seed, Val* philox_offset) { isFloatingPointType(tv->dtype()), "input must have floating point type, but got ", tv->dtype()); + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "rand_like operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + tv->toString()); // Create a new output TV manually so that we carry over IterTypes, instead // of inferring them from the shape as we would if we used rand(). TensorView* out = ops::newOutputTV({tv}, tv->dtype()); @@ -339,6 +350,11 @@ TensorView* full( } TensorView* full_like(TensorView* tv, Val* fill_value, DataType dtype) { + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "full_like operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + tv->toString()); fill_value = maybeCastOp(dtype, fill_value); TensorView* out = ops::newOutputTV({tv}, dtype); IrBuilder::create(out, fill_value); @@ -1575,6 +1591,31 @@ WelfordResult WelfordRaw( TensorView* init_avg, TensorView* init_var, Val* init_N) { + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "WelfordRaw operation is not supported for tensors with " + "RaggedIterDomain. " + "Input tensor (tv): ", + tv->toString()); + + if (init_avg != nullptr) { + NVF_CHECK( + !init_avg->domain()->hasRaggedIterDomain(), + "WelfordRaw operation is not supported for tensors with " + "RaggedIterDomain. " + "Initial average tensor (init_avg): ", + init_avg->toString()); + } + + if (init_var != nullptr) { + NVF_CHECK( + !init_var->domain()->hasRaggedIterDomain(), + "WelfordRaw operation is not supported for tensors with " + "RaggedIterDomain. " + "Initial variance tensor (init_var): ", + init_var->toString()); + } + NVF_CHECK( TensorDomain::sameAs(tv->getLogicalDomain(), tv->getLoopDomain()), "Reducing a tensor once it's gone under transformations is not permitted " @@ -1645,6 +1686,28 @@ WelfordResult Welford( TensorView* init_avg, TensorView* init_var, Val* init_N) { + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "Welford operation is not supported for tensors with RaggedIterDomain. " + "Input tensor (tv): ", + tv->toString()); + + if (init_avg != nullptr) { + NVF_CHECK( + !init_avg->domain()->hasRaggedIterDomain(), + "Welford operation is not supported for tensors with RaggedIterDomain. 
" + "Initial average tensor (init_avg): ", + init_avg->toString()); + } + + if (init_var != nullptr) { + NVF_CHECK( + !init_var->domain()->hasRaggedIterDomain(), + "Welford operation is not supported for tensors with RaggedIterDomain. " + "Initial variance tensor (init_var): ", + init_var->toString()); + } + NVF_CHECK( TensorDomain::sameAs(tv->getLogicalDomain(), tv->getLoopDomain()), "Reducing a tensor once it's gone under transformations is not permitted " @@ -1991,6 +2054,12 @@ TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) { // sum_to operator TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { + NVF_CHECK( + !in->domain()->hasRaggedIterDomain(), + "sum_to operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + in->toString()); + const auto& logical = TensorDomain::noReductions(in->getLogicalDomain()); NVF_CHECK( @@ -2038,6 +2107,12 @@ TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { } TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { + NVF_CHECK( + !in->domain()->hasRaggedIterDomain(), + "sum_to operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + in->toString()); + const auto& logical = TensorDomain::noReductions(in->getLogicalDomain()); NVF_CHECK( diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 2fa9cf04d2d..a8e7deb4be6 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -768,6 +768,342 @@ TEST_F(RaggedIterDomainTest, ReductionOnRaggedDimError) { EXPECT_THROW(sum(nested, {1}), nvfuser::nvfError); } +// Test reshape with nested tensors - should error +TEST_F(RaggedIterDomainTest, ReshapeWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to reshape - this should throw an error because reshape is not + // supported for tensors with RaggedIterDomain + std::vector new_shape = { + IrBuilder::create(-1L, DataType::Index), nested->axis(2)->extent()}; + EXPECT_THROW(reshape(nested, new_shape), nvfuser::nvfError); +} + +// Test flatten with nested tensors - should error +TEST_F(RaggedIterDomainTest, FlattenWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to flatten - this should throw an error because flatten is not + // supported for tensors with RaggedIterDomain + EXPECT_THROW(flatten(nested, 0, 2), nvfuser::nvfError); +} + +// Test transpose with nested tensors +TEST_F(RaggedIterDomainTest, TransposeWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Transpose ragged and dim1 dimensions + auto result = transpose(nested, 1, 2); + + fusion.addOutput(result); + + // Expected: [component, dim1, ragged] + 
// Should preserve RaggedIterDomain type + auto non_reduction_domain = + TensorDomain::noReductions(result->getLogicalDomain()); + + EXPECT_EQ(non_reduction_domain.size(), 3); + EXPECT_TRUE(non_reduction_domain[0]->isStrictlyA()); + EXPECT_TRUE(non_reduction_domain[1]->isStrictlyA()); + EXPECT_TRUE(non_reduction_domain[2]->isA()); +} + +// Test slice on ragged dimension - should error +TEST_F(RaggedIterDomainTest, SliceRaggedDimensionError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to slice the ragged dimension (axis 1) + // This should error because resize on RaggedIterDomain is not allowed + EXPECT_THROW( + slice( + nested, + {{fusion.zeroVal(), fusion.oneVal()}, + {fusion.zeroVal(), fusion.oneVal()}, + {fusion.zeroVal(), nested->axis(2)->extent()}}), + nvfuser::nvfError); +} + +// Test cat on ragged dimension - should error +TEST_F(RaggedIterDomainTest, CatRaggedDimensionError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensors with same structure + auto nested1 = asNested(data1, offsets, 0); + auto nested2 = asNested(data2, offsets, 0); + + // Try to concatenate along ragged dimension (axis 1) + // This should error because cat would need to resize RaggedIterDomain + EXPECT_THROW(cat({nested1, nested2}, 1), nvfuser::nvfError); +} + +// Test cat on non-ragged dimension - currently also errors +TEST_F(RaggedIterDomainTest, CatNonRaggedDimensionError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensors with same structure + auto nested1 = asNested(data1, offsets, 0); + auto nested2 = asNested(data2, offsets, 0); + + // Try to concatenate along non-ragged dimension (axis 2) + // Currently cat rejects all tensors with RaggedIterDomain for safety + // In the future, this could be supported if concatenating along non-ragged + // dimensions + EXPECT_THROW(cat({nested1, nested2}, 2), nvfuser::nvfError); +} + +// Test expand with nested tensors - should error +TEST_F(RaggedIterDomainTest, ExpandWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to expand a broadcast dimension - should error + auto broadcasted = broadcast(nested, {false, false, false, true}); + EXPECT_THROW( + expand( + broadcasted, + {nested->axis(0)->extent(), + nested->axis(1)->extent(), + nested->axis(2)->extent(), + IrBuilder::create(5L, DataType::Index)}), + nvfuser::nvfError); +} + +// Test pad on ragged dimension - should error +TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) { + Fusion fusion; + 
FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to pad the ragged dimension (axis 1) + // This should error because pad uses resize on RaggedIterDomain + std::vector pad_widths = { + fusion.zeroVal(), + fusion.zeroVal(), // component: no padding + fusion.oneVal(), + fusion.oneVal(), // ragged: PADDING - should error + fusion.zeroVal(), + fusion.zeroVal() // dim1: no padding + }; + + EXPECT_THROW(pad(nested, pad_widths), nvfuser::nvfError); +} + +// Test select with nested tensors - should error +TEST_F(RaggedIterDomainTest, SelectWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to select from a non-ragged dimension - should error + EXPECT_THROW(select(nested, 0, fusion.zeroVal()), nvfuser::nvfError); +} + +// Test gather with nested tensors - should error +TEST_F(RaggedIterDomainTest, GatherWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + auto index = makeSymbolicTensor(3, DataType::Index); + fusion.addInput(index); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to gather from nested tensor - should error + EXPECT_THROW(gather(nested, 2, index), nvfuser::nvfError); +} + +// Test view operations with nested tensors - should error +TEST_F(RaggedIterDomainTest, ViewWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to change dtype via view - should error + EXPECT_THROW(view(nested, DataType::Half), nvfuser::nvfError); +} + +// Test select (indexing) with nested tensors - should error +TEST_F(RaggedIterDomainTest, SelectIndexingWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to select from component dimension - should error + EXPECT_THROW(select(nested, 0, fusion.zeroVal()), nvfuser::nvfError); +} + +// Test index_select with nested tensors - should error +TEST_F(RaggedIterDomainTest, IndexSelectWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + auto indices = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(indices); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + 
+ // Try to index select from non-ragged dimension - should error + EXPECT_THROW(indexSelect(nested, 2, indices), nvfuser::nvfError); +} + +// Test scatter with nested tensors - should error +TEST_F(RaggedIterDomainTest, ScatterWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + auto src = makeSymbolicTensor(3, DataType::Float); + fusion.addInput(src); + + auto indices = makeSymbolicTensor(3, DataType::Index); + fusion.addInput(indices); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to scatter into nested tensor - should error + EXPECT_THROW(scatter(nested, 2, indices, src), nvfuser::nvfError); +} + +// Test repeat with nested tensors - should error +TEST_F(RaggedIterDomainTest, RepeatWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to repeat along non-ragged dimension - should error + std::vector repeats = {1, 1, 2}; + EXPECT_THROW(repeat(nested, repeats), nvfuser::nvfError); +} + // Test reduction on component dimension - should error (TODO) TEST_F(RaggedIterDomainTest, ReductionOnComponentDimError) { GTEST_SKIP() << "TODO: Implement validation to prevent reduction of " From e2efe752bf4e77049509d2da1b9a9027c9798a45 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 22:47:29 -0800 Subject: [PATCH 23/31] cleanup --- csrc/ir/internal_base_nodes.cpp | 5 - csrc/ops/alias.cpp | 23 +- csrc/ops/arith.cpp | 93 +------ csrc/ops/utils.cpp | 3 + tests/cpp/test_ragged_iter_domain.cpp | 359 +++++++------------------- 5 files changed, 117 insertions(+), 366 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 8897ac94e8e..15336e269ee 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -614,11 +614,6 @@ IterDomain* IterDomain::resize( "Non-zero stop offset not considered: ", in->toString()); - NVF_CHECK( - !in->isA(), - "Resizing RaggedIterDomain is not supported: ", - in->toString()); - // The overall extent is (in_extent + left_expansion + // right_expansion). This can be simplified for a slice op as // the right expansion should look like (slice_end_offset - diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 16298a04d37..ea87ae1d73c 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -42,13 +42,6 @@ TensorView* segment_set(TensorView* tv) { TensorView* view(TensorView* x, DataType dtype) { NVF_ERROR(x != nullptr, "Input is invalid."); - - NVF_CHECK( - !x->domain()->hasRaggedIterDomain(), - "View operation is not supported for tensors with RaggedIterDomain. 
" - "Input tensor: ", - x->toString()); - if (x->getDataType() == dtype) { return x; } @@ -538,8 +531,8 @@ TensorView* pad( Val* value, std::optional iter_type_opt) { NVF_CHECK( - !inp->domain()->hasRaggedIterDomain(), - "Padding a tensor with a RaggedIterDomain not supported: ", + inp->domain()->hasRaggedIterDomain(), + "Padding a tensor with RaggedIterDomain not supported: ", inp->toString()); DataType dt = inp->getDataType().value(); @@ -647,13 +640,14 @@ TensorView* cat( std::optional iter_type_opt, bool manual_padding) { NVF_CHECK(!inputs.empty(), "No input tensor given"); + NVF_CHECK( - std::ranges ::none_of( + std::ranges::none_of( inputs, [](TensorView* inp_tv) { return inp_tv->domain()->hasRaggedIterDomain(); }), - "Concatenating a tensor with a RaggedIterDomain not supported"); + "Concat with a tensor with RaggedIterDomain not supported"); const auto dtype = inputs.at(0)->getDataType().value(); @@ -818,7 +812,7 @@ TensorView* slice( NVF_CHECK( !inp->domain()->hasRaggedIterDomain(), - "Slicing a tensor with a RaggedIterDomain not supported: ", + "Slicing a tensor with RaggedIterDomain not supported: ", inp->toString()); ExpressionEvaluator expr_eval; @@ -1328,6 +1322,11 @@ TensorView* asNested( 1, "asNested currently only supports 1D offset tensors"); + NVF_CHECK( + !data->domain()->hasRaggedIterDomain(), + "Multiple level of nesting is not supported: ", + data->toString()); + // Get the logical domain of the input, excluding reductions auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index ce228a1d281..89d6dfc3a43 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -276,12 +276,6 @@ TensorView* randn_like(TensorView* tv, Val* philox_seed, Val* philox_offset) { isFloatingPointType(tv->dtype()), "input must have floating point type, but got ", tv->dtype()); - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "randn_like operation is not supported for tensors with " - "RaggedIterDomain. " - "Input tensor: ", - tv->toString()); // Create a new output TV manually so that we carry over IterTypes, instead // of inferring them from the shape as we would if we used randn(). TensorView* out = ops::newOutputTV({tv}, tv->dtype()); @@ -309,11 +303,6 @@ TensorView* rand_like(TensorView* tv, Val* philox_seed, Val* philox_offset) { isFloatingPointType(tv->dtype()), "input must have floating point type, but got ", tv->dtype()); - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "rand_like operation is not supported for tensors with RaggedIterDomain. " - "Input tensor: ", - tv->toString()); // Create a new output TV manually so that we carry over IterTypes, instead // of inferring them from the shape as we would if we used rand(). TensorView* out = ops::newOutputTV({tv}, tv->dtype()); @@ -350,11 +339,6 @@ TensorView* full( } TensorView* full_like(TensorView* tv, Val* fill_value, DataType dtype) { - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "full_like operation is not supported for tensors with RaggedIterDomain. 
" - "Input tensor: ", - tv->toString()); fill_value = maybeCastOp(dtype, fill_value); TensorView* out = ops::newOutputTV({tv}, dtype); IrBuilder::create(out, fill_value); @@ -1241,18 +1225,12 @@ TensorView* newForReduction( .iter_type(IterType::Reduction) .build(); } else { - // For non-reduced dimensions, preserve RaggedIterDomain if present - if (id->isA()) { - // Cast away const since cloneWithoutRFactor is not const - new_id = const_cast(id)->cloneWithoutRFactor(); - } else { - new_id = IterDomainBuilder(id) - .extent(id->extent()) - .resetSchedulingParams() - .parallel_type(id->getParallelType()) - .iter_type(id->getIterType()) - .build(); - } + new_id = IterDomainBuilder(id) + .extent(id->extent()) + .resetSchedulingParams() + .parallel_type(id->getParallelType()) + .iter_type(id->getIterType()) + .build(); } new_domain.push_back(new_id); } @@ -1591,31 +1569,6 @@ WelfordResult WelfordRaw( TensorView* init_avg, TensorView* init_var, Val* init_N) { - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "WelfordRaw operation is not supported for tensors with " - "RaggedIterDomain. " - "Input tensor (tv): ", - tv->toString()); - - if (init_avg != nullptr) { - NVF_CHECK( - !init_avg->domain()->hasRaggedIterDomain(), - "WelfordRaw operation is not supported for tensors with " - "RaggedIterDomain. " - "Initial average tensor (init_avg): ", - init_avg->toString()); - } - - if (init_var != nullptr) { - NVF_CHECK( - !init_var->domain()->hasRaggedIterDomain(), - "WelfordRaw operation is not supported for tensors with " - "RaggedIterDomain. " - "Initial variance tensor (init_var): ", - init_var->toString()); - } - NVF_CHECK( TensorDomain::sameAs(tv->getLogicalDomain(), tv->getLoopDomain()), "Reducing a tensor once it's gone under transformations is not permitted " @@ -1686,28 +1639,6 @@ WelfordResult Welford( TensorView* init_avg, TensorView* init_var, Val* init_N) { - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "Welford operation is not supported for tensors with RaggedIterDomain. " - "Input tensor (tv): ", - tv->toString()); - - if (init_avg != nullptr) { - NVF_CHECK( - !init_avg->domain()->hasRaggedIterDomain(), - "Welford operation is not supported for tensors with RaggedIterDomain. " - "Initial average tensor (init_avg): ", - init_avg->toString()); - } - - if (init_var != nullptr) { - NVF_CHECK( - !init_var->domain()->hasRaggedIterDomain(), - "Welford operation is not supported for tensors with RaggedIterDomain. " - "Initial variance tensor (init_var): ", - init_var->toString()); - } - NVF_CHECK( TensorDomain::sameAs(tv->getLogicalDomain(), tv->getLoopDomain()), "Reducing a tensor once it's gone under transformations is not permitted " @@ -2054,12 +1985,6 @@ TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) { // sum_to operator TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { - NVF_CHECK( - !in->domain()->hasRaggedIterDomain(), - "sum_to operation is not supported for tensors with RaggedIterDomain. " - "Input tensor: ", - in->toString()); - const auto& logical = TensorDomain::noReductions(in->getLogicalDomain()); NVF_CHECK( @@ -2107,12 +2032,6 @@ TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { } TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { - NVF_CHECK( - !in->domain()->hasRaggedIterDomain(), - "sum_to operation is not supported for tensors with RaggedIterDomain. 
" - "Input tensor: ", - in->toString()); - const auto& logical = TensorDomain::noReductions(in->getLogicalDomain()); NVF_CHECK( diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 2623f4c68be..be50385528c 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -326,6 +326,9 @@ RaggedIterDomain* newOutputRaggedIterDomain( "All input iter domains must be RaggedIterDomain"); NVF_ERROR(!input_ids.empty()); + + // Just using the first ragged ID as all input IDs are assumed to be + // equivalent RaggedIterDomain* ref_input_id = input_ids.front()->as(); return IrBuilder::create( diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index a8e7deb4be6..b8406303c01 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -354,48 +354,36 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { // Create nested tensor from dimension 0 auto nested = asNested(data, offsets, 0); - // This should still be a nested tensor - auto copy_of_nested = set(nested); + fusion.addOutput(nested); - fusion.addOutput(copy_of_nested); + // Verify the output is a new TensorView + EXPECT_TRUE(nested != nullptr); + EXPECT_NE(nested, data); + EXPECT_TRUE(nested->isA()); + + // Verify nested tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(nested->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); + EXPECT_FALSE(nested->axis(0)->isA()); - for (auto nested_tv : {nested, copy_of_nested}) { - // Verify the output is a new TensorView - EXPECT_TRUE(nested_tv != nullptr); - EXPECT_NE(nested_tv, data); - EXPECT_TRUE(nested_tv->isA()); - - // Verify nested_tv tensor has 3 dimensions: [component, ragged, - // original_dim1] - EXPECT_EQ(nested_tv->nDims(), 3); - - // First axis should be a regular IterDomain (component) - EXPECT_TRUE(nested_tv->axis(0)->isStrictlyA()); - EXPECT_FALSE(nested_tv->axis(0)->isA()); - - // Second axis should be a RaggedIterDomain - EXPECT_TRUE(nested_tv->axis(1)->isA()); - - // Third axis should be the original second dimension - EXPECT_TRUE(nested_tv->axis(2)->isStrictlyA()); - - if (nested_tv == nested) { - // Verify the definition exists (LoadStoreOp for aliasing) - EXPECT_TRUE(nested_tv->definition() != nullptr); - EXPECT_TRUE(nested_tv->definition()->isA()); - - // Verify the component and ragged IterDomains have Partition as their - // definition - EXPECT_TRUE(nested_tv->axis(0)->definition() != nullptr); - EXPECT_TRUE(nested_tv->axis(0)->definition()->isA()); - EXPECT_EQ( - nested_tv->axis(0)->definition(), nested_tv->axis(1)->definition()); - } else { - // The copy of the original nested tensor does not inherit the Partition - // op - EXPECT_TRUE(nested_tv->axis(0)->definition() == nullptr); - } - } + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(nested->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(nested->axis(2)->isStrictlyA()); + + // Verify the definition exists (LoadStoreOp for aliasing) + EXPECT_TRUE(nested->definition() != nullptr); + EXPECT_TRUE(nested->definition()->isA()); + + // Verify the component and ragged IterDomains have Partition as their + // definition + EXPECT_TRUE(nested->axis(0)->definition() != nullptr); + EXPECT_TRUE(nested->axis(0)->definition()->isA()); + EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition()); } // asNested on different dimensions @@ -495,6 +483,48 @@ TEST_F(RaggedIterDomainTest, 
AsNestedValidationMultiDimOffsets) { EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); } +TEST_F(RaggedIterDomainTest, LoadStoreWithNestedTensor) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, offsets, 0); + + // This should still be a nested tensor + auto copy_of_nested = set(nested); + + fusion.addOutput(copy_of_nested); + + // Verify the output is a new TensorView + EXPECT_TRUE(copy_of_nested != nullptr); + EXPECT_NE(copy_of_nested, data); + EXPECT_TRUE(copy_of_nested->isA()); + + // Verify copy_of_nested tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(copy_of_nested->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(copy_of_nested->axis(0)->isStrictlyA()); + EXPECT_FALSE(copy_of_nested->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(copy_of_nested->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(copy_of_nested->axis(2)->isStrictlyA()); + + // The copy of the original copy_of_nested tensor does not inherit the + // Partition op + EXPECT_TRUE(copy_of_nested->axis(0)->definition() == nullptr); +} + // Test binary operations with nested tensors TEST_F(RaggedIterDomainTest, BinaryOpWithNestedTensors) { Fusion fusion; @@ -514,7 +544,7 @@ TEST_F(RaggedIterDomainTest, BinaryOpWithNestedTensors) { auto nested1 = asNested(data1, offsets, 0); auto nested2 = asNested(data2, offsets, 0); - // Perform binary operation: add + // Perform binary operation. The result should be a nested tensor auto result = add(nested1, nested2); fusion.addOutput(result); @@ -577,8 +607,8 @@ TEST_F(RaggedIterDomainTest, BinaryOpDifferentRaggedStructures) { auto nested2 = asNested(data2, offsets2, 0); // This would be an error if, for example, the values of the offset - // tensors are not equivalent, but, like normal tensors, we assume - // that is indeed the case. + // tensors are not equivalent, but, like binary ops with normal + // tensors, we assume their shapes are indeed compatible auto result = add(nested1, nested2); fusion.addOutput(result); @@ -768,8 +798,14 @@ TEST_F(RaggedIterDomainTest, ReductionOnRaggedDimError) { EXPECT_THROW(sum(nested, {1}), nvfuser::nvfError); } -// Test reshape with nested tensors - should error -TEST_F(RaggedIterDomainTest, ReshapeWithNestedTensorsError) { +// Test reduction on component dimension - should error (TODO) +TEST_F(RaggedIterDomainTest, ReductionOnComponentDimError) { + GTEST_SKIP() << "TODO: Implement validation to prevent reduction of " + "component dimension. 
" + << "Currently there is no explicit marking of which IterDomains " + "are component dimensions, " + << "so this validation cannot be implemented yet."; + Fusion fusion; FusionGuard fg(&fusion); @@ -782,15 +818,15 @@ TEST_F(RaggedIterDomainTest, ReshapeWithNestedTensorsError) { // Create nested tensor: [component, ragged, dim1] auto nested = asNested(data, offsets, 0); - // Try to reshape - this should throw an error because reshape is not - // supported for tensors with RaggedIterDomain - std::vector new_shape = { - IrBuilder::create(-1L, DataType::Index), nested->axis(2)->extent()}; - EXPECT_THROW(reshape(nested, new_shape), nvfuser::nvfError); + // Try to reduce along the component dimension (axis 0) + // This should throw an error because reducing component dimensions is not + // allowed The component dimension defines the batch structure of the ragged + // tensor, and reducing it would destroy the ragged structure + EXPECT_THROW(sum(nested, {0}), nvfuser::nvfError); } -// Test flatten with nested tensors - should error -TEST_F(RaggedIterDomainTest, FlattenWithNestedTensorsError) { +// Test reshape with nested tensors - should error +TEST_F(RaggedIterDomainTest, ReshapeWithNestedTensorsError) { Fusion fusion; FusionGuard fg(&fusion); @@ -803,13 +839,15 @@ TEST_F(RaggedIterDomainTest, FlattenWithNestedTensorsError) { // Create nested tensor: [component, ragged, dim1] auto nested = asNested(data, offsets, 0); - // Try to flatten - this should throw an error because flatten is not + // Try to reshape - this should throw an error because reshape is not // supported for tensors with RaggedIterDomain - EXPECT_THROW(flatten(nested, 0, 2), nvfuser::nvfError); + std::vector new_shape = { + IrBuilder::create(-1L, DataType::Index), nested->axis(2)->extent()}; + EXPECT_THROW(reshape(nested, new_shape), nvfuser::nvfError); } -// Test transpose with nested tensors -TEST_F(RaggedIterDomainTest, TransposeWithNestedTensors) { +// Test flatten with nested tensors - should error +TEST_F(RaggedIterDomainTest, FlattenWithNestedTensorsError) { Fusion fusion; FusionGuard fg(&fusion); @@ -822,20 +860,9 @@ TEST_F(RaggedIterDomainTest, TransposeWithNestedTensors) { // Create nested tensor: [component, ragged, dim1] auto nested = asNested(data, offsets, 0); - // Transpose ragged and dim1 dimensions - auto result = transpose(nested, 1, 2); - - fusion.addOutput(result); - - // Expected: [component, dim1, ragged] - // Should preserve RaggedIterDomain type - auto non_reduction_domain = - TensorDomain::noReductions(result->getLogicalDomain()); - - EXPECT_EQ(non_reduction_domain.size(), 3); - EXPECT_TRUE(non_reduction_domain[0]->isStrictlyA()); - EXPECT_TRUE(non_reduction_domain[1]->isStrictlyA()); - EXPECT_TRUE(non_reduction_domain[2]->isA()); + // Try to flatten - this should throw an error because flatten is not + // supported for tensors with RaggedIterDomain + EXPECT_THROW(flatten(nested, 0, 2), nvfuser::nvfError); } // Test slice on ragged dimension - should error @@ -911,32 +938,6 @@ TEST_F(RaggedIterDomainTest, CatNonRaggedDimensionError) { EXPECT_THROW(cat({nested1, nested2}, 2), nvfuser::nvfError); } -// Test expand with nested tensors - should error -TEST_F(RaggedIterDomainTest, ExpandWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, 
offsets, 0); - - // Try to expand a broadcast dimension - should error - auto broadcasted = broadcast(nested, {false, false, false, true}); - EXPECT_THROW( - expand( - broadcasted, - {nested->axis(0)->extent(), - nested->axis(1)->extent(), - nested->axis(2)->extent(), - IrBuilder::create(5L, DataType::Index)}), - nvfuser::nvfError); -} - // Test pad on ragged dimension - should error TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) { Fusion fusion; @@ -965,170 +966,4 @@ TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) { EXPECT_THROW(pad(nested, pad_widths), nvfuser::nvfError); } -// Test select with nested tensors - should error -TEST_F(RaggedIterDomainTest, SelectWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to select from a non-ragged dimension - should error - EXPECT_THROW(select(nested, 0, fusion.zeroVal()), nvfuser::nvfError); -} - -// Test gather with nested tensors - should error -TEST_F(RaggedIterDomainTest, GatherWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - auto index = makeSymbolicTensor(3, DataType::Index); - fusion.addInput(index); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to gather from nested tensor - should error - EXPECT_THROW(gather(nested, 2, index), nvfuser::nvfError); -} - -// Test view operations with nested tensors - should error -TEST_F(RaggedIterDomainTest, ViewWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to change dtype via view - should error - EXPECT_THROW(view(nested, DataType::Half), nvfuser::nvfError); -} - -// Test select (indexing) with nested tensors - should error -TEST_F(RaggedIterDomainTest, SelectIndexingWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to select from component dimension - should error - EXPECT_THROW(select(nested, 0, fusion.zeroVal()), nvfuser::nvfError); -} - -// Test index_select with nested tensors - should error -TEST_F(RaggedIterDomainTest, IndexSelectWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - auto indices = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(indices); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to index select from non-ragged dimension - should error - EXPECT_THROW(indexSelect(nested, 2, 
indices), nvfuser::nvfError); -} - -// Test scatter with nested tensors - should error -TEST_F(RaggedIterDomainTest, ScatterWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - auto src = makeSymbolicTensor(3, DataType::Float); - fusion.addInput(src); - - auto indices = makeSymbolicTensor(3, DataType::Index); - fusion.addInput(indices); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to scatter into nested tensor - should error - EXPECT_THROW(scatter(nested, 2, indices, src), nvfuser::nvfError); -} - -// Test repeat with nested tensors - should error -TEST_F(RaggedIterDomainTest, RepeatWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to repeat along non-ragged dimension - should error - std::vector repeats = {1, 1, 2}; - EXPECT_THROW(repeat(nested, repeats), nvfuser::nvfError); -} - -// Test reduction on component dimension - should error (TODO) -TEST_F(RaggedIterDomainTest, ReductionOnComponentDimError) { - GTEST_SKIP() << "TODO: Implement validation to prevent reduction of " - "component dimension. " - << "Currently there is no explicit marking of which IterDomains " - "are component dimensions, " - << "so this validation cannot be implemented yet."; - - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to reduce along the component dimension (axis 0) - // This should throw an error because reducing component dimensions is not - // allowed The component dimension defines the batch structure of the ragged - // tensor, and reducing it would destroy the ragged structure - EXPECT_THROW(sum(nested, {0}), nvfuser::nvfError); -} - } // namespace nvfuser From 0b68d6b4517cafb745aa00c9cd81bd8111720a22 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 23:01:33 -0800 Subject: [PATCH 24/31] cleanup --- csrc/ir/internal_base_nodes.cpp | 46 +++++++++++++++++++-------------- csrc/ir/internal_base_nodes.h | 22 ++++++++-------- csrc/ops/alias.cpp | 2 +- csrc/ops/indexing.cpp | 6 +++-- 4 files changed, 42 insertions(+), 34 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 15336e269ee..7ffb35739a8 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -30,8 +30,8 @@ namespace nvfuser { -IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) - : start_(_start), extent_(_extent) { +IterDomainBuilder::IterDomainBuilder(Val* start, Val* extent) + : start_(start), extent_(extent) { NVF_ERROR( start_ != nullptr && extent_ != nullptr, "Start and extent are required to build an iter domain."); @@ -67,52 +67,58 @@ IterDomainBuilder& IterDomainBuilder::resetRfactor() { return is_rfactor_domain(false); } -IterDomainBuilder& IterDomainBuilder::start(Val* _start) { - start_ = _start; +IterDomainBuilder& 
IterDomainBuilder::start(Val* start) { + start_ = start; return *this; } -IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) { - extent_ = _extent; +IterDomainBuilder& IterDomainBuilder::extent(Val* extent) { + extent_ = extent; return *this; } -IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) { - expanded_extent_ = _expanded_extent; +IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* expanded_extent) { + expanded_extent_ = expanded_extent; return *this; } -IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) { - stop_offset_ = _stop_offset; +IterDomainBuilder& IterDomainBuilder::stop_offset(Val* stop_offset) { + stop_offset_ = stop_offset; return *this; } IterDomainBuilder& IterDomainBuilder::parallel_type( - ParallelType _parallel_type) { - parallel_type_ = _parallel_type; + ParallelType parallel_type) { + parallel_type_ = parallel_type; return *this; } -IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) { - iter_type_ = _iter_type; +IterDomainBuilder& IterDomainBuilder::iter_type(IterType iter_type) { + iter_type_ = iter_type; return *this; } IterDomainBuilder& IterDomainBuilder::is_rfactor_domain( - bool _is_rfactor_domain) { - is_rfactor_domain_ = _is_rfactor_domain; + bool is_rfactor_domain) { + is_rfactor_domain_ = is_rfactor_domain; return *this; } IterDomainBuilder& IterDomainBuilder::is_padded_dimension( - bool _is_padded_dimension) { - is_padded_dimension_ = _is_padded_dimension; + bool is_padded_dimension) { + is_padded_dimension_ = is_padded_dimension; return *this; } IterDomainBuilder& IterDomainBuilder::padded_to_size( - std::optional _padded_to_size) { - padded_to_size_ = _padded_to_size; + std::optional padded_to_size) { + padded_to_size_ = padded_to_size; + return *this; +} + +IterDomainBuilder& IterDomainBuilder::ragged_extents( + TensorView* ragged_extents) { + ragged_extents_ = ragged_extents; return *this; } diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 1f3e6658e01..84a2e7686be 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -40,7 +40,7 @@ struct AnalyzeViewResult; class IterDomainBuilder { public: // Match legacy constructor - IterDomainBuilder(Val* _start, Val* _extent); + IterDomainBuilder(Val* start, Val* extent); // Grab all the parameters from id to set the IterDomainBuilder IterDomainBuilder(const IterDomain* id); @@ -52,16 +52,16 @@ class IterDomainBuilder { // Resets is_rfactor_domain IterDomainBuilder& resetRfactor(); - IterDomainBuilder& start(Val* _start); - IterDomainBuilder& extent(Val* _extent); - IterDomainBuilder& expanded_extent(Val* _expanded_extent); - IterDomainBuilder& stop_offset(Val* _stop_offset); - IterDomainBuilder& parallel_type(ParallelType _parallel_type); - IterDomainBuilder& iter_type(IterType _iter_type); - IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain); - IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension); - IterDomainBuilder& padded_to_size(std::optional _padded_to_size); - IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); + IterDomainBuilder& start(Val* start); + IterDomainBuilder& extent(Val* extent); + IterDomainBuilder& expanded_extent(Val* expanded_extent); + IterDomainBuilder& stop_offset(Val* stop_offset); + IterDomainBuilder& parallel_type(ParallelType parallel_type); + IterDomainBuilder& iter_type(IterType iter_type); + IterDomainBuilder& is_rfactor_domain(bool is_rfactor_domain); + IterDomainBuilder& is_padded_dimension(bool 
is_padded_dimension); + IterDomainBuilder& padded_to_size(std::optional padded_to_size); + IterDomainBuilder& ragged_extents(TensorView* ragged_extents); IterDomain* build() const; diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index ea87ae1d73c..4a3609f8b28 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -531,7 +531,7 @@ TensorView* pad( Val* value, std::optional iter_type_opt) { NVF_CHECK( - inp->domain()->hasRaggedIterDomain(), + !inp->domain()->hasRaggedIterDomain(), "Padding a tensor with RaggedIterDomain not supported: ", inp->toString()); diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp index a28ca67f72b..b8975c9ac66 100644 --- a/csrc/ops/indexing.cpp +++ b/csrc/ops/indexing.cpp @@ -54,13 +54,15 @@ TensorView* indexSelect( TensorView* index_tv) { NVF_CHECK( !lookup_tv->domain()->hasRaggedIterDomain(), - "IndexSelect operation is not supported for tensors with RaggedIterDomain. " + "IndexSelect operation is not supported for tensors with " + "RaggedIterDomain. " "Input tensor (lookup_tv): ", lookup_tv->toString()); NVF_CHECK( !index_tv->domain()->hasRaggedIterDomain(), - "IndexSelect operation is not supported for tensors with RaggedIterDomain. " + "IndexSelect operation is not supported for tensors with " + "RaggedIterDomain. " "Index tensor (index_tv): ", index_tv->toString()); From 8a73bb2f76729c9b5973365a20a8c7829f4bb1bd Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Dec 2025 10:00:40 -0800 Subject: [PATCH 25/31] cleanup --- csrc/ir/interface_nodes.h | 11 +++--- csrc/ir/internal_base_nodes.cpp | 54 ++++++++------------------- csrc/ir/internal_base_nodes.h | 19 +++++----- csrc/ir/internal_nodes.cpp | 10 ++++- csrc/tensor_view.cpp | 10 ++--- tests/cpp/test_ragged_iter_domain.cpp | 50 ++++++++++++------------- 6 files changed, 69 insertions(+), 85 deletions(-) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index ea0236a1a7b..0128c937ed4 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -619,14 +619,13 @@ class NVF_API TensorView : public Val { return merge(axis, axis + 1); } - // Partition "axis" into component and ragged dimensions based on offsets - // The offsets tensor defines partition boundaries where: - // Shape: [num_components + 1], values: [0, off1, off2, ..., total] - // Extents are computed as: extents[i] = offsets[i+1] - offsets[i] + // Partition "axis" into component and ragged dimensions based on extents + // The extents tensor directly specifies the size of each component: + // Shape: [num_components], values: [extent0, extent1, ..., extent(n-1)] // Returns this TensorView with the axis replaced by component and ragged dims - // e.g. partition(0, offsets) on tv[id{N}] results in: + // e.g. partition(0, extents) on tv[id{N}] results in: // tv[id{num_components}, ragged_id{extents}] - TensorView* partition(int64_t axis, TensorView* offsets); + TensorView* partition(int64_t axis, TensorView* extents); // Flatten the axis from `from` to `to` into a single axis. // Both `from` and `to` are inclusive. 
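Since TensorView::partition (and asNested) now take per-component extents rather than boundary offsets, a caller that only has an offsets tensor can derive the extents inside the fusion. The following is an illustrative sketch, not part of the patch; it reuses the slice/sub pattern that the removed in-IR offsets handling in the next hunk relied on, and makeSymbolicTensor is the test helper used by this series.

  // Illustrative sketch only, not part of this patch.
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv = makeSymbolicTensor(2, DataType::Float);
  fusion.addInput(tv);

  // offsets = [0, off1, ..., total] has length num_components + 1.
  TensorView* offsets = makeSymbolicTensor(1, DataType::Index);
  fusion.addInput(offsets);

  Val* len = offsets->getLogicalDomain().at(0)->extent();
  Val* len_minus_one = sub(len, fusion.oneVal());

  Slice left;                    // offsets[:-1]
  left.start = fusion.zeroVal();
  left.stop = len_minus_one;

  Slice right;                   // offsets[1:]
  right.start = fusion.oneVal();
  right.stop = len;

  // extents[i] = offsets[i + 1] - offsets[i]
  TensorView* extents = sub(slice(offsets, {right}), slice(offsets, {left}));

  // tv: [id{N}, id{M}] -> [id{num_components}, ragged_id{extents}, id{M}]
  tv->partition(0, extents);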
diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index cc068fbfca2..292c00a8740 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -897,7 +897,7 @@ std::string RaggedIterDomain::toString(int indent_size) const { std::pair RaggedIterDomain::partition( IterDomain* in, - TensorView* offsets) { + TensorView* extents) { NVF_ERROR(in != nullptr, "partition: input IterDomain is null"); NVF_ERROR( @@ -918,52 +918,28 @@ std::pair RaggedIterDomain::partition( " for IterDomain: ", in->toString()); - NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null"); + NVF_ERROR(extents != nullptr, "partition: extents tensor is null"); NVF_ERROR_EQ( - offsets->dtype(), + extents->dtype(), DataType::Index, - "partition: offsets must have Index type, got ", - offsets->dtype()); + "partition: extents must have Index type, got ", + extents->dtype()); - const auto& offsets_domain = offsets->getLogicalDomain(); + const auto& extents_domain = extents->getLogicalDomain(); NVF_ERROR_EQ( - offsets_domain.size(), + extents_domain.size(), 1, - "partition: offsets tensor must be 1D, got ", - offsets_domain.size(), - "D tensor. Multi-dimensional offsets not yet supported."); + "partition: extents tensor must be 1D, got ", + extents_domain.size(), + "D tensor. Multi-dimensional extents not yet supported."); auto container = in->container(); - // Compute extents from offsets: extents[i] = offsets[i+1] - offsets[i] - // offsets_left = offsets[:-1] (all but last element) - // offsets_right = offsets[1:] (all but first element) - - auto offsets_len = offsets_domain[0]->extent(); - - auto zero = container->zeroVal(DataType::Index); - auto one = container->oneVal(DataType::Index); - auto len_minus_one = sub(offsets_len, one); - - // Slice offsets[:-1] - Slice left_slice; - left_slice.start = zero; - left_slice.stop = len_minus_one; - auto offsets_left = slice(offsets, {left_slice}); - - // Slice offsets[1:] - Slice right_slice; - right_slice.start = one; - right_slice.stop = offsets_len; - auto offsets_right = slice(offsets, {right_slice}); - - // Compute extents: extents = offsets_right - offsets_left - auto extents = sub(offsets_right, offsets_left); - // Create component IterDomain - // Component extent = number of components = len(offsets) - 1 - auto component_extent = len_minus_one; + // Component extent = number of components = length of extents tensor + auto zero = container->zeroVal(DataType::Index); + auto component_extent = extents_domain.at(0)->extent(); auto component_id = IterDomainBuilder(zero, component_extent) .parallel_type(ParallelType::Serial) .iter_type(IterType::Iteration) @@ -1583,13 +1559,13 @@ void TensorDomain::merge(int64_t axis_o, int64_t axis_i) { // Partition "axis" into component and ragged dimensions. Follow the // pattern of TensorDomain::split. 
-void TensorDomain::partition(int64_t axis, TensorView* offsets) { +void TensorDomain::partition(int64_t axis, TensorView* extents) { NVF_ERROR(nDims() > 0, "Tried to do partition on a 0-dim domain"); axis = wrapDim(axis); IterDomain* id = this->axis(axis); - auto [component_id, ragged_id] = RaggedIterDomain::partition(id, offsets); + auto [component_id, ragged_id] = RaggedIterDomain::partition(id, extents); // Remove the original axis and insert component and ragged dimensions loop_domain_.erase(loop_domain_.begin() + axis); diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 0187c408bd7..c5fad115ba3 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -477,20 +477,21 @@ class NVF_API RaggedIterDomain : public IterDomain { } //! Partition an IterDomain into component and ragged dimensions - //! Creates a component IterDomain and a RaggedIterDomain based on offsets + //! Creates a component IterDomain and a RaggedIterDomain based on extents //! //! \param in Input IterDomain to partition (must be regular IterDomain) - //! \param offsets Offset tensor defining partition boundaries (must be 1D) - //! Shape: [num_components + 1], values: [0, off1, off2, ..., total] - //! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] + //! \param extents Extents tensor defining the size of each component (must be + //! 1D) + //! Shape: [num_components], values: [extent0, extent1, ..., + //! extent(n-1)] //! \return Pair of (component_id, ragged_id) //! component_id: IterDomain with extent = num_components - //! ragged_id: RaggedIterDomain with extents computed from offsets + //! ragged_id: RaggedIterDomain with the provided extents //! - //! TODO: Support multi-dimensional offsets for nested ragged structures + //! TODO: Support multi-dimensional extents for nested ragged structures static std::pair partition( IterDomain* in, - TensorView* offsets); + TensorView* extents); private: //! Extent tensor containing all component extents @@ -792,8 +793,8 @@ class NVF_API TensorDomain : public Val { // axis is by default placed at original position axis_o void merge(int64_t axis_o, int64_t axis_i); - // Partition axis into component and ragged dimensions based on offsets - void partition(int64_t axis, TensorView* offsets); + // Partition axis into component and ragged dimensions based on extents + void partition(int64_t axis, TensorView* extents); // Reorder axes according to map[old_pos] = new_pos void reorder(const std::unordered_map& old2new); diff --git a/csrc/ir/internal_nodes.cpp b/csrc/ir/internal_nodes.cpp index 6ebc9271b02..219d1669b48 100644 --- a/csrc/ir/internal_nodes.cpp +++ b/csrc/ir/internal_nodes.cpp @@ -2622,7 +2622,15 @@ Partition::Partition( addOutput(component); addOutput(ragged); addInput(in); - // Should the extents tensor be an input rather than an attribute? + // Note: extents is held as an attribute rather than an input, + // despite it's a TensorView. Inputs and outputs in the existing + // IterDomain exprs are always IterDomains. Intuitively, they + // transform input iteration spaces into output iteration spaces in + // some way. Since the extents tensor itself is not transformed in the + // Partition expr, it doesn't seem to be considered as an input. Note that in + // Split, the split factor is an attribute. However, that said, none + // of the existing exprs has tensors as attributes, which makes this + // choice less certain with possible implications. 
addAttribute(extents); } diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index eabe2004f0c..fc10d3db3a9 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -562,8 +562,8 @@ TensorView* TensorView::merge(int64_t axis_o, int64_t axis_i) { } // Partition "axis" into component and ragged dimensions based on -// offsets. Follow the pattern of TensorView::split. -TensorView* TensorView::partition(int64_t axis, TensorView* offsets) { +// extents. Follow the pattern of TensorView::split. +TensorView* TensorView::partition(int64_t axis, TensorView* extents) { NVF_ERROR( nDims() > 0, "Tried to do partition on a 0-dim TensorView. ", @@ -598,11 +598,11 @@ TensorView* TensorView::partition(int64_t axis, TensorView* offsets) { " Parallelization strategy must be set after calling partition: ", toString()); - if (offsets->dtype() != DataType::Index) { - offsets = castOp(DataType::Index, offsets); + if (extents->dtype() != DataType::Index) { + extents = castOp(DataType::Index, extents); } - domain()->partition(axis, offsets); + domain()->partition(axis, extents); return this; } diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 8d16615bd64..ecbbf10d2f2 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -211,13 +211,13 @@ TEST_F(RaggedIterDomainTest, PartitionBasic) { fusion.zeroVal(), IrBuilder::create(-1, DataType::Index)) .build(); - // Create a symbolic offset tensor - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + // Create a symbolic extents tensor + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Partition the IterDomain auto [component_id, ragged_id] = - RaggedIterDomain::partition(input_id, offsets); + RaggedIterDomain::partition(input_id, extents); // Verify component IterDomain EXPECT_TRUE(component_id != nullptr); @@ -265,28 +265,28 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) .build(); - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Test 1: Null input should fail EXPECT_THROW( - RaggedIterDomain::partition(nullptr, offsets), nvfuser::nvfError); + RaggedIterDomain::partition(nullptr, extents), nvfuser::nvfError); - // Test 2: Null offsets should fail + // Test 2: Null extents should fail EXPECT_THROW( RaggedIterDomain::partition(input_id, nullptr), nvfuser::nvfError); - // Test 3: Non-Index offsets should fail - auto float_offsets = makeSymbolicTensor(1, DataType::Float); - fusion.addInput(float_offsets); + // Test 3: Non-Index extents should fail + auto float_extents = makeSymbolicTensor(1, DataType::Float); + fusion.addInput(float_extents); EXPECT_THROW( - RaggedIterDomain::partition(input_id, float_offsets), nvfuser::nvfError); + RaggedIterDomain::partition(input_id, float_extents), nvfuser::nvfError); - // Test 4: Multi-dimensional offsets should fail - auto offsets_2d = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(offsets_2d); + // Test 4: Multi-dimensional extents should fail + auto extents_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(extents_2d); EXPECT_THROW( - RaggedIterDomain::partition(input_id, offsets_2d), nvfuser::nvfError); + RaggedIterDomain::partition(input_id, extents_2d), nvfuser::nvfError); // Test 5: Non-Iteration IterType should fail auto reduction_id = @@ -295,15 
+295,15 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { .iter_type(IterType::Reduction) .build(); EXPECT_THROW( - RaggedIterDomain::partition(reduction_id, offsets), nvfuser::nvfError); + RaggedIterDomain::partition(reduction_id, extents), nvfuser::nvfError); // Test 6: Cannot partition RaggedIterDomain - auto extents = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(extents); + auto extents2 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents2); auto ragged_id = IrBuilder::create( - extents, IterType::Iteration, ParallelType::Serial); + extents2, IterType::Iteration, ParallelType::Serial); EXPECT_THROW( - RaggedIterDomain::partition(ragged_id, offsets), nvfuser::nvfError); + RaggedIterDomain::partition(ragged_id, extents), nvfuser::nvfError); } // TensorView::partition operation @@ -315,12 +315,12 @@ TEST_F(RaggedIterDomainTest, TensorViewPartition) { auto tv0 = makeSymbolicTensor(2, DataType::Float); fusion.addInput(tv0); - // Create offsets tensor - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + // Create extents tensor + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Partition the first axis - tv0->partition(0, offsets); + tv0->partition(0, extents); // Verify the tensor now has 3 dimensions: [component, ragged, original_dim1] EXPECT_EQ(tv0->nDims(), 3); From f215f079f1fa8ccbcf425ddf8470f8cf42cf3566 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Dec 2025 14:01:20 -0800 Subject: [PATCH 26/31] Use extents as a parameter --- csrc/ops/alias.cpp | 12 ++++---- csrc/ops/alias.h | 13 ++++---- tests/cpp/test_ragged_iter_domain.cpp | 44 +++++++++++++-------------- 3 files changed, 34 insertions(+), 35 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index e32aa5e6b9c..16e71e7ecdd 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1270,16 +1270,16 @@ TensorView* repeat( TensorView* asNested( TensorView* data, - TensorView* offsets, + TensorView* extents, int64_t ragged_dim) { NVF_ERROR(data != nullptr, "asNested: data tensor is null"); - NVF_ERROR(offsets != nullptr, "asNested: offsets tensor is null"); + NVF_ERROR(extents != nullptr, "asNested: extents tensor is null"); - // Only 1D offset tensors are currently supported + // Only 1D extents tensors are currently supported NVF_ERROR_EQ( - offsets->nDims(), + extents->nDims(), 1, - "asNested currently only supports 1D offset tensors"); + "asNested currently only supports 1D extents tensors"); // Get the logical domain of the input, excluding reductions auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); @@ -1296,7 +1296,7 @@ TensorView* asNested( // Partition the specified dimension in root domain // This replaces one IterDomain with (component_id, ragged_id) auto [component_id, ragged_id] = - RaggedIterDomain::partition(root_domain.at(ragged_dim), offsets); + RaggedIterDomain::partition(root_domain.at(ragged_dim), extents); // Build the logical domain: replace ragged_dim with component and ragged std::vector logical_domain; diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index f3bf769dd71..5963e99df66 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -197,26 +197,25 @@ NVF_API TensorView* repeat( TensorView* inp, const std::vector& repeat_times); -//! Create a nested tensor view from a data tensor and offsets. +//! Create a nested tensor view from a data tensor and extents. //! //! The function partitions the specified dimension of the data tensor into -//! 
a component dimension and a ragged dimension based on the provided offsets. +//! a component dimension and a ragged dimension based on the provided extents. //! //! \param data Input tensor to be converted to nested representation -//! \param offsets Offset tensor defining partition boundaries -//! Shape: [num_components + 1], values: [0, off1, off2, ..., total] -//! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] +//! \param extents Extents tensor defining the size of each component +//! Shape: [num_components], values: [extent0, extent1, ..., extent(n-1)] //! \param ragged_dim Dimension to partition into nested structure //! \return TensorView with a RaggedIterDomain at the specified dimension //! //! Example: //! data shape: [10, ...] -//! offsets: [0, 3, 8, 10] +//! extents: [3, 5, 2] //! ragged_dim: 0 //! Result: nested tensor with 3 components. [3, [3, 5, 2], ...] NVF_API TensorView* asNested( TensorView* data, - TensorView* offsets, + TensorView* extents, int64_t ragged_dim); } // namespace nvfuser diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 536a723d0cf..cc6b65078c1 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -348,11 +348,11 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { auto data = makeSymbolicTensor(2, DataType::Float); fusion.addInput(data); - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Create nested tensor from dimension 0 - auto nested = asNested(data, offsets, 0); + auto nested = asNested(data, extents, 0); fusion.addOutput(nested); @@ -393,11 +393,11 @@ TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { auto data = makeSymbolicTensor(3, DataType::Float); fusion.addInput(data); - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Partition dimension 1 (middle dimension) - auto nested = asNested(data, offsets, 1); + auto nested = asNested(data, extents, 1); // Verify dimensions: [dim0, component, ragged, dim2] EXPECT_EQ(nested->nDims(), 4); @@ -424,12 +424,12 @@ TEST_F(RaggedIterDomainTest, AsNested1DTensor) { auto data = makeSymbolicTensor(1, DataType::Float); fusion.addInput(data); - // Create offsets tensor - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + // Create extents tensor + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Create nested tensor from the only dimension - auto nested = asNested(data, offsets, 0); + auto nested = asNested(data, extents, 0); fusion.addOutput(nested); @@ -448,38 +448,38 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationNullData) { Fusion fusion; FusionGuard fg(&fusion); - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Null data should throw - EXPECT_THROW(asNested(nullptr, offsets, 0), nvfuser::nvfError); + EXPECT_THROW(asNested(nullptr, extents, 0), nvfuser::nvfError); } -// asNested validation - null offsets -TEST_F(RaggedIterDomainTest, AsNestedValidationNullOffsets) { +// asNested validation - null extents +TEST_F(RaggedIterDomainTest, AsNestedValidationNullExtents) { Fusion fusion; FusionGuard fg(&fusion); auto data = makeSymbolicTensor(2, DataType::Float); 
fusion.addInput(data); - // Null offsets should throw + // Null extents should throw EXPECT_THROW(asNested(data, nullptr, 0), nvfuser::nvfError); } -// asNested validation - multi-dimensional offsets (not yet supported) -TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { +// asNested validation - multi-dimensional extents (not yet supported) +TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimExtents) { Fusion fusion; FusionGuard fg(&fusion); auto data = makeSymbolicTensor(2, DataType::Float); fusion.addInput(data); - // 2D offsets should fail (only 1D supported currently) - auto offsets_2d = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(offsets_2d); + // 2D extents should fail (only 1D supported currently) + auto extents_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(extents_2d); - EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); + EXPECT_THROW(asNested(data, extents_2d, 0), nvfuser::nvfError); } } // namespace nvfuser From 8aa854e17bc7dfc6a08c12a7fd63eb5d3ae43070 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 7 Jan 2026 09:58:52 -0800 Subject: [PATCH 27/31] feedback --- csrc/ops/alias.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index ffc26e308cb..1c412a61f8d 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1274,23 +1274,24 @@ TensorView* asNested( NVF_ERROR(data != nullptr, "asNested: data tensor is null"); NVF_ERROR(extents != nullptr, "asNested: extents tensor is null"); + // Get the logical domain of the input, excluding reductions + auto inp_logical = data->getLogicalDomain() | TensorDomain::kNoReductions; + auto inp_logical_size = std::ranges::distance(inp_logical); + // Only 1D extents tensors are currently supported NVF_ERROR_EQ( - extents->nDims(), + inp_logical_size, 1, "asNested currently only supports 1D extents tensors"); - // Get the logical domain of the input, excluding reductions - auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); - // Clone the logical domain to create the root domain for output std::vector root_domain; - root_domain.reserve(inp_logical.size()); + root_domain.reserve(inp_logical_size); for (auto* id : inp_logical) { root_domain.push_back(id->cloneWithoutRFactor()); } - ragged_dim = wrapDim(ragged_dim, std::ssize(inp_logical)); + ragged_dim = wrapDim(ragged_dim, inp_logical_size); // Partition the specified dimension in root domain // This replaces one IterDomain with (component_id, ragged_id) From 72ae14f1ebac33c7144572558e5ac8c55925023b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 7 Jan 2026 10:17:33 -0800 Subject: [PATCH 28/31] fix --- csrc/ops/alias.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 1c412a61f8d..0fdb03274c0 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1274,16 +1274,17 @@ TensorView* asNested( NVF_ERROR(data != nullptr, "asNested: data tensor is null"); NVF_ERROR(extents != nullptr, "asNested: extents tensor is null"); - // Get the logical domain of the input, excluding reductions - auto inp_logical = data->getLogicalDomain() | TensorDomain::kNoReductions; - auto inp_logical_size = std::ranges::distance(inp_logical); - // Only 1D extents tensors are currently supported NVF_ERROR_EQ( - inp_logical_size, + std::ranges::distance( + extents->getLogicalDomain() | TensorDomain::kNoReductions), 1, "asNested currently only supports 1D extents tensors"); + 
// Get the logical domain of the input, excluding reductions + auto inp_logical = data->getLogicalDomain() | TensorDomain::kNoReductions; + auto inp_logical_size = std::ranges::distance(inp_logical); + // Clone the logical domain to create the root domain for output std::vector root_domain; root_domain.reserve(inp_logical_size); From 4d8acaba484d63b0f2653964ecd8eaa919081f78 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Jan 2026 17:07:37 -0800 Subject: [PATCH 29/31] cleanup --- csrc/ops/utils.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index be50385528c..ccb15fd7f45 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -331,10 +331,7 @@ RaggedIterDomain* newOutputRaggedIterDomain( // equivalent RaggedIterDomain* ref_input_id = input_ids.front()->as(); - return IrBuilder::create( - ref_input_id->extents(), - ref_input_id->getIterType(), - ref_input_id->getParallelType()); + return IterDomainBuilder(ref_input_id).build()->as(); } // Adding these pragmas since gcc-12.2.1 From 3b082ba6fcd636e517610b44eaec8c07436db02a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Jan 2026 22:59:50 -0800 Subject: [PATCH 30/31] cleanup --- csrc/ir/internal_base_nodes.cpp | 4 +--- csrc/ops/utils.cpp | 9 ++++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 698a047db6f..e9a0da1fa2f 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -983,9 +983,7 @@ std::string RaggedIterDomain::toString(int indent_size) const { } IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) { - // Create a new RaggedIterDomain with the same extents and properties - auto cloned = IrBuilder::create( - extents_, getIterType(), getParallelType()); + auto cloned = IterDomainBuilder(this).resetRfactor().build(); // Optionally map the clone with the original in the Exact graph if (map_with_original) { diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index ccb15fd7f45..36e46ef97e2 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -345,13 +345,20 @@ IterDomain* newOutputIterDomain( const std::optional force_iter_type) { NVF_ERROR(!input_ids.empty()); - // If any input ID is a RaggedIterDomain, the output should also be ragged + // If an input ID is a RaggedIterDomain, the output as well as all + // other inputs must be ragged bool has_ragged = std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) { return id->isA(); }); if (has_ragged) { + NVF_ERROR( + std::all_of( + input_ids.begin(), + input_ids.end(), + [](IterDomain* id) { return id->isA(); }), + "All of none input IDs must be ragged"); NVF_ERROR( !force_iter_type.has_value(), "force_iter_type not supported for RaggedIterDomain"); From bfc3da951cb147842b7f9289e382e5cf3d1b60b9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Jan 2026 23:09:21 -0800 Subject: [PATCH 31/31] Update csrc/ops/utils.cpp Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/ops/utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 36e46ef97e2..80f2bfe7ea0 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -358,7 +358,7 @@ IterDomain* newOutputIterDomain( input_ids.begin(), input_ids.end(), [](IterDomain* id) { return id->isA(); }), - "All of none input IDs must be ragged"); + "All or none input IDs must be ragged"); 
NVF_ERROR( !force_iter_type.has_value(), "force_iter_type not supported for RaggedIterDomain");
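
For reference, a minimal usage sketch of the extents-based API introduced by this series, following the pattern of the AsNestedBasic and TensorViewPartition tests above. The two-dimensional input shape, the concrete extents values, and the [10, i1] example are illustrative assumptions, not part of any patch; the sketch assumes the same test-fixture context (Fusion, FusionGuard, makeSymbolicTensor) used in tests/cpp/test_ragged_iter_domain.cpp.

// Sketch only: build a fusion with a ragged dimension via asNested.
Fusion fusion;
FusionGuard fg(&fusion);

// data: a symbolic 2D float tensor; extents: a 1D index tensor holding
// one extent per component (e.g. [3, 5, 2]).
auto data = makeSymbolicTensor(2, DataType::Float);
auto extents = makeSymbolicTensor(1, DataType::Index);
fusion.addInput(data);
fusion.addInput(extents);

// Partition dimension 0 into (component, ragged): a [10, i1] tensor with
// extents [3, 5, 2] becomes a nested [3, [3, 5, 2], i1] view.
auto nested = asNested(data, extents, /*ragged_dim=*/0);
fusion.addOutput(nested);

// The same transform can also be applied in place to a TensorView's
// domain, following the pattern of TensorView::split:
//   tv->partition(/*axis=*/0, extents);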