diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp
index 292c00a8740..e9a0da1fa2f 100644
--- a/csrc/ir/internal_base_nodes.cpp
+++ b/csrc/ir/internal_base_nodes.cpp
@@ -30,8 +30,8 @@ namespace nvfuser {
 
-IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent)
-    : start_(_start), extent_(_extent) {
+IterDomainBuilder::IterDomainBuilder(Val* start, Val* extent)
+    : start_(start), extent_(extent) {
   NVF_ERROR(
       start_ != nullptr && extent_ != nullptr,
       "Start and extent are required to build an iter domain.");
@@ -48,7 +48,11 @@ IterDomainBuilder::IterDomainBuilder(const IterDomain* id)
       is_rfactor_domain_(id->isRFactorProduct()),
       is_padded_dimension_(id->hasPaddingToMultipleOfWarp()),
       is_clustered_dimension_(id->isClusteredBlockDim()),
-      padded_to_size_(id->getMaybeSizeAfterPadding()) {}
+      padded_to_size_(id->getMaybeSizeAfterPadding()) {
+  if (id->isA<RaggedIterDomain>()) {
+    ragged_extents_ = id->as<RaggedIterDomain>()->extents();
+  }
+}
 
 IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() {
   parallel_type_ = ParallelType::Serial;
@@ -63,52 +67,58 @@ IterDomainBuilder& IterDomainBuilder::resetRfactor() {
   return is_rfactor_domain(false);
 }
 
-IterDomainBuilder& IterDomainBuilder::start(Val* _start) {
-  start_ = _start;
+IterDomainBuilder& IterDomainBuilder::start(Val* start) {
+  start_ = start;
   return *this;
 }
 
-IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) {
-  extent_ = _extent;
+IterDomainBuilder& IterDomainBuilder::extent(Val* extent) {
+  extent_ = extent;
   return *this;
 }
 
-IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) {
-  expanded_extent_ = _expanded_extent;
+IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* expanded_extent) {
+  expanded_extent_ = expanded_extent;
   return *this;
 }
 
-IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) {
-  stop_offset_ = _stop_offset;
+IterDomainBuilder& IterDomainBuilder::stop_offset(Val* stop_offset) {
+  stop_offset_ = stop_offset;
   return *this;
 }
 
 IterDomainBuilder& IterDomainBuilder::parallel_type(
-    ParallelType _parallel_type) {
-  parallel_type_ = _parallel_type;
+    ParallelType parallel_type) {
+  parallel_type_ = parallel_type;
   return *this;
 }
 
-IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) {
-  iter_type_ = _iter_type;
+IterDomainBuilder& IterDomainBuilder::iter_type(IterType iter_type) {
+  iter_type_ = iter_type;
   return *this;
 }
 
 IterDomainBuilder& IterDomainBuilder::is_rfactor_domain(
-    bool _is_rfactor_domain) {
-  is_rfactor_domain_ = _is_rfactor_domain;
+    bool is_rfactor_domain) {
+  is_rfactor_domain_ = is_rfactor_domain;
   return *this;
 }
 
 IterDomainBuilder& IterDomainBuilder::is_padded_dimension(
-    bool _is_padded_dimension) {
-  is_padded_dimension_ = _is_padded_dimension;
+    bool is_padded_dimension) {
+  is_padded_dimension_ = is_padded_dimension;
   return *this;
 }
 
 IterDomainBuilder& IterDomainBuilder::padded_to_size(
-    std::optional<int64_t> _padded_to_size) {
-  padded_to_size_ = _padded_to_size;
+    std::optional<int64_t> padded_to_size) {
+  padded_to_size_ = padded_to_size;
+  return *this;
+}
+
+IterDomainBuilder& IterDomainBuilder::ragged_extents(
+    TensorView* ragged_extents) {
+  ragged_extents_ = ragged_extents;
   return *this;
 }
 
@@ -116,7 +126,13 @@ IterDomain* IterDomainBuilder::build() const {
   NVF_ERROR(
       start_ != nullptr && extent_ != nullptr,
       "Start and extent are required to build an iter domain.");
-  return IrBuilder::createInContainer<IterDomain>(start_->container(), *this);
+
+  if (ragged_extents_ != nullptr) {
+    return IrBuilder::createInContainer<RaggedIterDomain>(
+        start_->container(), *this);
+  } else {
+    return IrBuilder::createInContainer<IterDomain>(start_->container(), *this);
+  }
 }
 
 IterDomain::IterDomain(
@@ -815,6 +831,77 @@ void validateLoopDomain(
 
 } // namespace
 
+RaggedIterDomain::RaggedIterDomain(
+    IrBuilderPasskey passkey,
+    const IterDomainBuilder& args)
+    : IterDomain(
+          passkey,
+          ValType::RaggedIterDomain,
+          args.start_,
+          args.extent_,
+          args.expanded_extent_,
+          args.stop_offset_,
+          args.parallel_type_,
+          args.iter_type_,
+          args.is_rfactor_domain_,
+          args.is_padded_dimension_,
+          args.is_clustered_dimension_,
+          args.padded_to_size_),
+      extents_(args.ragged_extents_) {
+  // Extents must be non-null
+  NVF_ERROR(
+      extents_ != nullptr,
+      "RaggedIterDomain requires non-null extents tensor");
+
+  // Extents must have integer dtype
+  NVF_ERROR_EQ(
+      extents_->dtype(),
+      DataType::Index,
+      "RaggedIterDomain extents must have index type, got ",
+      extents_->dtype());
+
+  // Only IterType::Iteration is supported at this moment
+  NVF_ERROR_EQ(
+      iter_type_,
+      IterType::Iteration,
+      "Only IterType::Iteration is supported: ",
+      iter_type_);
+
+  // RaggedIterDomain has specific requirements on member values
+  NVF_ERROR(
+      start_->isZeroInt(),
+      "RaggedIterDomain start must be zero, got: ",
+      start_->toInlineString());
+
+  NVF_ERROR(
+      extent_->isOneInt(),
+      "RaggedIterDomain extent must be one (placeholder), got: ",
+      extent_->toInlineString());
+
+  NVF_ERROR(
+      expanded_extent_ == nullptr,
+      "RaggedIterDomain does not support expanded_extent");
+
+  NVF_ERROR(
+      stop_offset_ == nullptr || stop_offset_->isZeroInt(),
+      "RaggedIterDomain stop_offset must be nullptr or zero, got: ",
+      stop_offset_ ? stop_offset_->toInlineString() : "nullptr");
+
+  NVF_ERROR(
+      !is_rfactor_domain_,
+      "RaggedIterDomain does not support rfactor domains");
+
+  NVF_ERROR(
+      !is_padded_dimension_,
+      "RaggedIterDomain does not support padded dimensions");
+
+  NVF_ERROR(
+      !is_clustered_dimension_,
+      "RaggedIterDomain does not support clustered dimensions");
+
+  NVF_ERROR(
+      !padded_to_size_.has_value(),
+      "RaggedIterDomain does not support padded_to_size");
+}
+
 RaggedIterDomain::RaggedIterDomain(
     IrBuilderPasskey passkey,
     TensorView* extents,
@@ -895,6 +982,18 @@ std::string RaggedIterDomain::toString(int indent_size) const {
   return toInlineString(indent_size);
 }
 
+IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) {
+  auto cloned = IterDomainBuilder(this).resetRfactor().build();
+
+  // Optionally map the clone with the original in the Exact graph
+  if (map_with_original) {
+    // TODO: Implement mapping if needed
+    NVF_THROW("Not implemented");
+  }
+
+  return cloned;
+}
+
 std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
     IterDomain* in,
     TensorView* extents) {
@@ -1472,6 +1571,13 @@ bool TensorDomain::hasVectorize() const {
   });
 }
 
+bool TensorDomain::hasRaggedIterDomain() const {
+  return std::any_of(
+      logical().begin(), logical().end(), [](IterDomain* logical_id) {
+        return logical_id->isA<RaggedIterDomain>();
+      });
+}
+
 std::optional<int64_t> TensorDomain::getReductionAxis() const {
   auto it = std::find_if(
       loop_domain_.begin(), loop_domain_.end(), [](const auto& id) {
diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h
index c5fad115ba3..9d40fcfbf1e 100644
--- a/csrc/ir/internal_base_nodes.h
+++ b/csrc/ir/internal_base_nodes.h
@@ -40,7 +40,7 @@ struct AnalyzeViewResult;
 class IterDomainBuilder {
  public:
   // Match legacy constructor
-  IterDomainBuilder(Val* _start, Val* _extent);
+  IterDomainBuilder(Val* start, Val* extent);
 
   // Grab all the parameters from id to set the IterDomainBuilder
   IterDomainBuilder(const IterDomain* id);
@@ -52,15 +52,16 @@ class IterDomainBuilder {
   // Resets is_rfactor_domain
   IterDomainBuilder& resetRfactor();
 
-  IterDomainBuilder& start(Val* _start);
-  IterDomainBuilder& extent(Val* _extent);
-  IterDomainBuilder& expanded_extent(Val* _expanded_extent);
-  IterDomainBuilder& stop_offset(Val* _stop_offset);
-  IterDomainBuilder& parallel_type(ParallelType _parallel_type);
-  IterDomainBuilder& iter_type(IterType _iter_type);
-  IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain);
-  IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension);
-  IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size);
+  IterDomainBuilder& start(Val* start);
+  IterDomainBuilder& extent(Val* extent);
+  IterDomainBuilder& expanded_extent(Val* expanded_extent);
+  IterDomainBuilder& stop_offset(Val* stop_offset);
+  IterDomainBuilder& parallel_type(ParallelType parallel_type);
+  IterDomainBuilder& iter_type(IterType iter_type);
+  IterDomainBuilder& is_rfactor_domain(bool is_rfactor_domain);
+  IterDomainBuilder& is_padded_dimension(bool is_padded_dimension);
+  IterDomainBuilder& padded_to_size(std::optional<int64_t> padded_to_size);
+  IterDomainBuilder& ragged_extents(TensorView* ragged_extents);
 
   IterDomain* build() const;
 
@@ -79,6 +80,9 @@ class IterDomainBuilder {
   bool is_padded_dimension_ = false;
   bool is_clustered_dimension_ = false;
   std::optional<int64_t> padded_to_size_ = std::nullopt;
+
+  // For RaggedIterDomain: stores the extents tensor
+  TensorView* ragged_extents_ = nullptr;
 };
 
 //! Simply a representation of an annotated 1D iterable from start to extent.
@@ -122,7 +126,7 @@ class NVF_API IterDomain : public Val {
   //!
   //! When map_with_original is true, the clone of the original is
   //! mapped in the Exact graph.
-  IterDomain* cloneWithoutRFactor(bool map_with_original = false);
+  virtual IterDomain* cloneWithoutRFactor(bool map_with_original = false);
 
   //! Clone a vector domains
   static std::vector<IterDomain*> clone(
@@ -448,6 +452,8 @@ class NVF_API IterDomain : public Val {
 //! components
 class NVF_API RaggedIterDomain : public IterDomain {
  public:
+  RaggedIterDomain(IrBuilderPasskey passkey, const IterDomainBuilder& args);
+
   //! \param extents TensorView containing component extents (must be integer
   //! type)
   //! \param iter_type Iteration type (Iteration, Reduction, etc.)
@@ -493,6 +499,9 @@ class NVF_API RaggedIterDomain : public IterDomain {
       IterDomain* in,
       TensorView* extents);
 
+  //! Override cloneWithoutRFactor to preserve RaggedIterDomain type
+  IterDomain* cloneWithoutRFactor(bool map_with_original = false) override;
+
  private:
   //! Extent tensor containing all component extents
   //! Can be 1D, 2D, or N-D depending on nesting structure
@@ -643,6 +652,8 @@ class NVF_API TensorDomain : public Val {
 
   bool hasSymbolicAxis() const;
 
+  bool hasRaggedIterDomain() const;
+
   std::optional<int64_t> getReductionAxis() const;
 
   // The input logical domain. The root domain of a consumer should equal the
diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp
index 0fdb03274c0..d5b1990fa80 100644
--- a/csrc/ops/alias.cpp
+++ b/csrc/ops/alias.cpp
@@ -142,6 +142,12 @@ TensorView* reshape(TensorView* inp_tv, const std::vector<Val*>& new_sizes) {
       "Unsupported input tensor to reshape as its axes may be partial: ",
       inp_tv->toString());
 
+  NVF_CHECK(
+      !inp_tv->domain()->hasRaggedIterDomain(),
+      "Reshape operation is not supported for tensors with RaggedIterDomain. "
" + "Input tensor: ", + inp_tv->toString()); + auto static_reshape_output = tryStaticReshape(inp_tv, inp_dom, new_sizes); if (static_reshape_output) { return static_reshape_output; @@ -239,6 +245,12 @@ TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) { end_dim); NVF_CHECK(start_dim <= end_dim, "start_dim must be <= end_dim"); + NVF_CHECK( + !x->domain()->hasRaggedIterDomain(), + "Flatten operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + x->toString()); + if (start_dim == end_dim) { return x; } @@ -518,6 +530,11 @@ TensorView* pad( const std::vector& pad_widths, Val* value, std::optional iter_type_opt) { + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Padding a tensor with RaggedIterDomain not supported: ", + inp->toString()); + DataType dt = inp->getDataType().value(); if (!value) { // Create a zero of the appropriate type @@ -624,6 +641,14 @@ TensorView* cat( bool manual_padding) { NVF_CHECK(!inputs.empty(), "No input tensor given"); + NVF_CHECK( + std::ranges::none_of( + inputs, + [](TensorView* inp_tv) { + return inp_tv->domain()->hasRaggedIterDomain(); + }), + "Concat with a tensor with RaggedIterDomain not supported"); + const auto dtype = inputs.at(0)->getDataType().value(); std::vector> inp_doms; @@ -783,7 +808,12 @@ TensorView* slice( NVF_CHECK_EQ( ndims, std::ssize(ranges), - "The range vector must have the same number of Slice descriptors.") + "The range vector must have the same number of Slice descriptors."); + + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Slicing a tensor with RaggedIterDomain not supported: ", + inp->toString()); ExpressionEvaluator expr_eval; @@ -1058,6 +1088,12 @@ TensorView* broadcast( TensorView* expand(TensorView* inp, const std::vector& expanded_sizes) { auto inp_domain = TensorDomain::noReductions(inp->getLogicalDomain()); + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Expand operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + inp->toString()); + NVF_CHECK_GE(expanded_sizes.size(), inp_domain.size()); inp = ops::maybe_broadcast_inner_to_rank(inp, expanded_sizes.size()); @@ -1180,6 +1216,12 @@ TensorView* expand_as(TensorView* inp, TensorView* other) { TensorView* repeat( TensorView* inp_tv, const std::vector& repeat_times) { + NVF_CHECK( + !inp_tv->domain()->hasRaggedIterDomain(), + "Repeat operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + inp_tv->toString()); + const auto ndims = TensorDomain::noReductions(inp_tv->getLogicalDomain()).size(); @@ -1281,6 +1323,11 @@ TensorView* asNested( 1, "asNested currently only supports 1D extents tensors"); + NVF_CHECK( + !data->domain()->hasRaggedIterDomain(), + "Multiple level of nesting is not supported: ", + data->toString()); + // Get the logical domain of the input, excluding reductions auto inp_logical = data->getLogicalDomain() | TensorDomain::kNoReductions; auto inp_logical_size = std::ranges::distance(inp_logical); diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 64f68cbf1a4..17b80d965d2 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -1209,6 +1209,14 @@ TensorView* newForReduction( " of tensor ", tv); } + NVF_CHECK( + !id->isA(), + "Cannot reduce a RaggedIterDomain. Reduction of ragged dimensions is " + "not supported. " + "Tried to reduce ID = ", + id, + " of tensor ", + tv); new_id = IterDomainBuilder(id) // If the domain is being reduced, but it's coming in as an // expanded extent, we need to realize the expand. 
diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp
index fb2f1b3feda..b8975c9ac66 100644
--- a/csrc/ops/indexing.cpp
+++ b/csrc/ops/indexing.cpp
@@ -24,6 +24,12 @@ TensorView* select(TensorView* tv, int64_t dim, Val* index) {
   auto dom = TensorDomain::noReductions(tv->getLogicalDomain());
   NVF_CHECK(!dom.empty(), "select can not be applied to 0d tensor.");
 
+  NVF_CHECK(
+      !tv->domain()->hasRaggedIterDomain(),
+      "Select operation is not supported for tensors with RaggedIterDomain. "
+      "Input tensor: ",
+      tv->toString());
+
   std::vector<IterDomain*> new_root;
   new_root.reserve(dom.size() - 1);
   dim = wrapDim(dim, (int64_t)dom.size());
@@ -46,6 +52,20 @@ TensorView* indexSelect(
     TensorView* lookup_tv,
     int64_t dim,
     TensorView* index_tv) {
+  NVF_CHECK(
+      !lookup_tv->domain()->hasRaggedIterDomain(),
+      "IndexSelect operation is not supported for tensors with "
+      "RaggedIterDomain. "
+      "Input tensor (lookup_tv): ",
+      lookup_tv->toString());
+
+  NVF_CHECK(
+      !index_tv->domain()->hasRaggedIterDomain(),
+      "IndexSelect operation is not supported for tensors with "
+      "RaggedIterDomain. "
+      "Index tensor (index_tv): ",
+      index_tv->toString());
+
   DataType dtype = lookup_tv->getDataType().value();
   NVF_CHECK(
       dtype != DataType::Null, "Invalid datatype provided for new value.");
@@ -131,6 +151,18 @@ TensorView* indexPutAccumulate(
 
 // torch.gather
 TensorView* gather(TensorView* inp, int64_t dim, TensorView* index) {
+  NVF_CHECK(
+      !inp->domain()->hasRaggedIterDomain(),
+      "Gather operation is not supported for tensors with RaggedIterDomain. "
+      "Input tensor (inp): ",
+      inp->toString());
+
+  NVF_CHECK(
+      !index->domain()->hasRaggedIterDomain(),
+      "Gather operation is not supported for tensors with RaggedIterDomain. "
+      "Index tensor (index): ",
+      index->toString());
+
   auto inp_domain = TensorDomain::noReductions(inp->getLogicalDomain());
   auto idx_domain = TensorDomain::noReductions(index->getLogicalDomain());
   NVF_CHECK(
@@ -168,6 +200,26 @@ TensorView* scatter(
     TensorView* index,
     Val* src,
     std::optional<BinaryOpType> accumulate_op) {
+  NVF_CHECK(
+      !self->domain()->hasRaggedIterDomain(),
+      "Scatter operation is not supported for tensors with RaggedIterDomain. "
+      "Input tensor (self): ",
+      self->toString());
+
+  NVF_CHECK(
+      !index->domain()->hasRaggedIterDomain(),
+      "Scatter operation is not supported for tensors with RaggedIterDomain. "
+      "Index tensor (index): ",
+      index->toString());
+
+  if (src->isA<TensorView>()) {
+    NVF_CHECK(
+        !src->as<TensorView>()->domain()->hasRaggedIterDomain(),
+        "Scatter operation is not supported for tensors with RaggedIterDomain. "
" + "Source tensor (src): ", + src->toString()); + } + auto self_dom = TensorDomain::noReductions(self->getLogicalDomain()); auto idx_dom = TensorDomain::noReductions(index->getLogicalDomain()); diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 2db0b424d55..80f2bfe7ea0 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -315,6 +315,25 @@ std::vector mapLinearOpIterDomains( return mapping; } +RaggedIterDomain* newOutputRaggedIterDomain( + const std::vector& input_ids) { + NVF_ERROR( + std::ranges::all_of( + input_ids, + [](IterDomain* input_id) { + return input_id->isA(); + }), + "All input iter domains must be RaggedIterDomain"); + + NVF_ERROR(!input_ids.empty()); + + // Just using the first ragged ID as all input IDs are assumed to be + // equivalent + RaggedIterDomain* ref_input_id = input_ids.front()->as(); + + return IterDomainBuilder(ref_input_id).build()->as(); +} + // Adding these pragmas since gcc-12.2.1 // incorrectly reports a warning with the use of evaluate #if defined(__GNUC__) && !defined(__clang__) @@ -324,6 +343,28 @@ std::vector mapLinearOpIterDomains( IterDomain* newOutputIterDomain( const std::vector& input_ids, const std::optional force_iter_type) { + NVF_ERROR(!input_ids.empty()); + + // If an input ID is a RaggedIterDomain, the output as well as all + // other inputs must be ragged + bool has_ragged = + std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) { + return id->isA(); + }); + + if (has_ragged) { + NVF_ERROR( + std::all_of( + input_ids.begin(), + input_ids.end(), + [](IterDomain* id) { return id->isA(); }), + "All or none input IDs must be ragged"); + NVF_ERROR( + !force_iter_type.has_value(), + "force_iter_type not supported for RaggedIterDomain"); + return newOutputRaggedIterDomain(input_ids); + } + // For the start and stop offsets, take the maximum of input axes. // For now, the offsets of both start and stop are always integer // constant, so we can statically compute them. It is unclear diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index 44a98242a4d..3ceadc4aa6a 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -92,6 +92,12 @@ std::vector mapLinearOpIterDomains( size_t out_size, bool k_bcast); +// Creates an output RaggedIterDomain from input RaggedIterDomains at the same +// dimension position. All inputs must be RaggedIterDomain. Uses the extents, +// IterType, and ParallelType from the first input. +RaggedIterDomain* newOutputRaggedIterDomain( + const std::vector& input_ids); + // Takes a vector of aligned input iterdomains to create the output iterdomain. // This is used if the input iterdomains are not trivially mapped to the output // iterdomains. For eg: MatmulOp. 
diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp
index cc6b65078c1..e346ac2d021 100644
--- a/tests/cpp/test_ragged_iter_domain.cpp
+++ b/tests/cpp/test_ragged_iter_domain.cpp
@@ -482,4 +482,487 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimExtents) {
   EXPECT_THROW(asNested(data, extents_2d, 0), nvfuser::nvfError);
 }
 
+TEST_F(RaggedIterDomainTest, LoadStoreWithNestedTensor) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor from dimension 0
+  auto nested = asNested(data, offsets, 0);
+
+  // This should still be a nested tensor
+  auto copy_of_nested = set(nested);
+
+  fusion.addOutput(copy_of_nested);
+
+  // Verify the output is a new TensorView
+  EXPECT_TRUE(copy_of_nested != nullptr);
+  EXPECT_NE(copy_of_nested, data);
+  EXPECT_TRUE(copy_of_nested->isA<TensorView>());
+
+  // Verify copy_of_nested tensor has 3 dimensions: [component, ragged,
+  // original_dim1]
+  EXPECT_EQ(copy_of_nested->nDims(), 3);
+
+  // First axis should be a regular IterDomain (component)
+  EXPECT_TRUE(copy_of_nested->axis(0)->isStrictlyA<IterDomain>());
+  EXPECT_FALSE(copy_of_nested->axis(0)->isA<RaggedIterDomain>());
+
+  // Second axis should be a RaggedIterDomain
+  EXPECT_TRUE(copy_of_nested->axis(1)->isA<RaggedIterDomain>());
+
+  // Third axis should be the original second dimension
+  EXPECT_TRUE(copy_of_nested->axis(2)->isStrictlyA<IterDomain>());
+
+  // The copy of the nested tensor does not inherit the Partition op
+  EXPECT_TRUE(copy_of_nested->axis(0)->definition() == nullptr);
+}
+
+// Test binary operations with nested tensors
+TEST_F(RaggedIterDomainTest, BinaryOpWithNestedTensors) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // Create two 2D input tensors
+  auto data1 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data1);
+
+  auto data2 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data2);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensors from both inputs
+  auto nested1 = asNested(data1, offsets, 0);
+  auto nested2 = asNested(data2, offsets, 0);
+
+  // Perform binary operation. The result should be a nested tensor
+  auto result = add(nested1, nested2);
+
+  fusion.addOutput(result);
+
+  // Verify the result has 3 dimensions: [component, ragged, original_dim1]
+  EXPECT_EQ(result->nDims(), 3);
+
+  // First axis should be a regular IterDomain (component)
+  EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
+  EXPECT_FALSE(result->axis(0)->isA<RaggedIterDomain>());
+
+  // Second axis should be a RaggedIterDomain
+  EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
+
+  // Third axis should be the original second dimension
+  EXPECT_TRUE(result->axis(2)->isStrictlyA<IterDomain>());
+}
+
+// Test binary operation with mixed inputs (one ragged, one not) - should error
+TEST_F(RaggedIterDomainTest, BinaryOpMixedInputsError) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data1 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data1);
+
+  auto data2 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data2);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor from first input only
+  auto nested1 = asNested(data1, offsets, 0);
+
+  // Try to add a nested tensor to a non-nested tensor
+  // This should fail because one is ragged and one is not
+  EXPECT_THROW(add(nested1, data2), nvfuser::nvfError);
+}
+
+// Test binary operation with different offsets
+TEST_F(RaggedIterDomainTest, BinaryOpDifferentRaggedStructures) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data1 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data1);
+
+  auto data2 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data2);
+
+  auto offsets1 = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets1);
+
+  auto offsets2 = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets2);
+
+  // Create nested tensors with different offset tensors
+  auto nested1 = asNested(data1, offsets1, 0);
+  auto nested2 = asNested(data2, offsets2, 0);
+
+  // This would be an error if, for example, the values of the offset
+  // tensors are not equivalent, but, like binary ops with normal
+  // tensors, we assume their shapes are indeed compatible
+  auto result = add(nested1, nested2);
+  fusion.addOutput(result);
+
+  EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
+}
+
+// Test unary operations with nested tensors
+TEST_F(RaggedIterDomainTest, UnaryOpWithNestedTensors) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor
+  auto nested = asNested(data, offsets, 0);
+
+  // Perform unary operation: neg
+  auto result = neg(nested);
+
+  fusion.addOutput(result);
+
+  // Verify the result preserves RaggedIterDomain structure
+  EXPECT_EQ(result->nDims(), 3);
+  EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
+  EXPECT_TRUE(result->axis(2)->isStrictlyA<IterDomain>());
+}
+
+// Test broadcast with nested tensors
+TEST_F(RaggedIterDomainTest, BroadcastWithNestedTensors) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  auto result = broadcast(nested, {false, false, false, true});
+
+  fusion.addOutput(result);
+
+  // Result should be: [component, ragged, dim1, broadcast_dim]
+  EXPECT_EQ(result->nDims(), 4);
+  EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
+  EXPECT_TRUE(result->axis(2)->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(result->axis(3)->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(result->axis(3)->isBroadcast());
+}
+
+// Test squeeze on non-ragged dimension
+TEST_F(RaggedIterDomainTest, SqueezeNonRaggedDim) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // First broadcast to add a dimension: [component, ragged, dim1, 1]
+  auto broadcasted = broadcast(nested, {false, false, false, true});
+
+  // Then squeeze the broadcast dimension (dimension index 3)
+  auto result = squeeze(broadcasted, {3});
+
+  fusion.addOutput(result);
+
+  // Result should be: [component, ragged, dim1]
+  EXPECT_EQ(result->nDims(), 3);
+  EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
+  EXPECT_TRUE(result->axis(2)->isStrictlyA<IterDomain>());
+}
+
+// Test unsqueeze with nested tensors
+TEST_F(RaggedIterDomainTest, UnsqueezeWithNestedTensors) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // Unsqueeze to add dimension at the end
+  auto result = unsqueeze(nested, -1);
+
+  fusion.addOutput(result);
+
+  // Result should be: [component, ragged, dim1, 1]
+  EXPECT_EQ(result->nDims(), 4);
+  EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
+  EXPECT_TRUE(result->axis(2)->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(result->axis(3)->isStrictlyA<IterDomain>());
+}
+
+// Test permute/transpose with nested tensors
+TEST_F(RaggedIterDomainTest, PermuteWithNestedTensors) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // Permute dimensions: swap ragged and dim1
+  auto result = permute(nested, {0, 2, 1});
+
+  fusion.addOutput(result);
+
+  // Result should be: [component, dim1, ragged]
+  EXPECT_EQ(result->nDims(), 3);
+  EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(result->axis(1)->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(result->axis(2)->isA<RaggedIterDomain>());
+}
+
+// Test reduction on non-ragged dimension
+TEST_F(RaggedIterDomainTest, ReductionOnNonRaggedDim) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // Reduce along the last dimension (non-ragged)
+  auto result = sum(nested, {2});
+
+  fusion.addOutput(result);
+
+  // Result should be: [component, ragged]
+  // Get non-reduction dimensions
+  auto non_reduction_domain =
+      TensorDomain::noReductions(result->getLogicalDomain());
+
+  EXPECT_EQ(non_reduction_domain.size(), 2);
+  EXPECT_TRUE(non_reduction_domain[0]->isStrictlyA<IterDomain>());
+  EXPECT_TRUE(non_reduction_domain[1]->isA<RaggedIterDomain>());
+}
+
+// Test reduction on ragged dimension - should error
+TEST_F(RaggedIterDomainTest, ReductionOnRaggedDimError) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // Try to reduce along the ragged dimension (axis 1)
+  // This should throw an error because reducing a RaggedIterDomain is not
+  // allowed
+  EXPECT_THROW(sum(nested, {1}), nvfuser::nvfError);
+}
+
+// Test reduction on component dimension - should error (TODO)
+TEST_F(RaggedIterDomainTest, ReductionOnComponentDimError) {
+  GTEST_SKIP() << "TODO: Implement validation to prevent reduction of "
+                  "component dimension. "
+               << "Currently there is no explicit marking of which IterDomains "
+                  "are component dimensions, "
+               << "so this validation cannot be implemented yet.";
+
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // Try to reduce along the component dimension (axis 0)
+  // This should throw an error because reducing component dimensions is not
+  // allowed. The component dimension defines the batch structure of the ragged
+  // tensor, and reducing it would destroy the ragged structure
+  EXPECT_THROW(sum(nested, {0}), nvfuser::nvfError);
+}
+
+// Test reshape with nested tensors - should error
+TEST_F(RaggedIterDomainTest, ReshapeWithNestedTensorsError) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // Try to reshape - this should throw an error because reshape is not
+  // supported for tensors with RaggedIterDomain
+  std::vector<Val*> new_shape = {
+      IrBuilder::create<Val>(-1L, DataType::Index), nested->axis(2)->extent()};
+  EXPECT_THROW(reshape(nested, new_shape), nvfuser::nvfError);
+}
+
+// Test flatten with nested tensors - should error
+TEST_F(RaggedIterDomainTest, FlattenWithNestedTensorsError) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // Try to flatten - this should throw an error because flatten is not
+  // supported for tensors with RaggedIterDomain
+  EXPECT_THROW(flatten(nested, 0, 2), nvfuser::nvfError);
+}
+
+// Test slice on ragged dimension - should error
+TEST_F(RaggedIterDomainTest, SliceRaggedDimensionError) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // Try to slice the ragged dimension (axis 1)
+  // This should error because resize on RaggedIterDomain is not allowed
+  EXPECT_THROW(
+      slice(
+          nested,
+          {{fusion.zeroVal(), fusion.oneVal()},
+           {fusion.zeroVal(), fusion.oneVal()},
+           {fusion.zeroVal(), nested->axis(2)->extent()}}),
+      nvfuser::nvfError);
+}
+
+// Test cat on ragged dimension - should error
+TEST_F(RaggedIterDomainTest, CatRaggedDimensionError) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data1 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data1);
+
+  auto data2 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data2);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensors with same structure
+  auto nested1 = asNested(data1, offsets, 0);
+  auto nested2 = asNested(data2, offsets, 0);
+
+  // Try to concatenate along ragged dimension (axis 1)
+  // This should error because cat would need to resize RaggedIterDomain
+  EXPECT_THROW(cat({nested1, nested2}, 1), nvfuser::nvfError);
+}
+
+// Test cat on non-ragged dimension - currently also errors
+TEST_F(RaggedIterDomainTest, CatNonRaggedDimensionError) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data1 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data1);
+
+  auto data2 = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data2);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensors with same structure
+  auto nested1 = asNested(data1, offsets, 0);
+  auto nested2 = asNested(data2, offsets, 0);
+
+  // Try to concatenate along non-ragged dimension (axis 2)
+  // Currently cat rejects all tensors with RaggedIterDomain for safety.
+  // In the future, concatenating along non-ragged dimensions could be
+  // supported
+  EXPECT_THROW(cat({nested1, nested2}, 2), nvfuser::nvfError);
+}
+
+// Test pad on ragged dimension - should error
+TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto data = makeSymbolicTensor(2, DataType::Float);
+  fusion.addInput(data);
+
+  auto offsets = makeSymbolicTensor(1, DataType::Index);
+  fusion.addInput(offsets);
+
+  // Create nested tensor: [component, ragged, dim1]
+  auto nested = asNested(data, offsets, 0);
+
+  // Try to pad the ragged dimension (axis 1)
+  // This should error because pad uses resize on RaggedIterDomain
+  std::vector<Val*> pad_widths = {
+      fusion.zeroVal(),
+      fusion.zeroVal(), // component: no padding
+      fusion.oneVal(),
+      fusion.oneVal(), // ragged: PADDING - should error
+      fusion.zeroVal(),
+      fusion.zeroVal() // dim1: no padding
+  };
+
+  EXPECT_THROW(pad(nested, pad_widths), nvfuser::nvfError);
+}
+
 } // namespace nvfuser
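
For reviewers, a minimal sketch of the front-end flow this patch enables, modeled directly on the tests above (all identifiers are the nvfuser C++ API exercised in this PR; scheduling and execution are out of scope here):

// Sketch only: mirrors LoadStoreWithNestedTensor / UnaryOpWithNestedTensors.
Fusion fusion;
FusionGuard fg(&fusion);

// Dense data plus a 1D extents/offsets tensor, both fusion inputs.
TensorView* data = makeSymbolicTensor(2, DataType::Float);
TensorView* offsets = makeSymbolicTensor(1, DataType::Index);
fusion.addInput(data);
fusion.addInput(offsets);

// asNested() partitions dimension 0 into [component, ragged]; the ragged
// axis is a RaggedIterDomain carrying the extents TensorView.
TensorView* nested = asNested(data, offsets, 0);

// Elementwise ops propagate the ragged axis through newOutputIterDomain(),
// which dispatches to newOutputRaggedIterDomain() when any input ID is ragged.
TensorView* out = neg(nested);
fusion.addOutput(out);

// out->axis(1)->isA<RaggedIterDomain>() holds, while reshape/slice/pad/cat,
// reductions over the ragged axis, and gather/scatter-style indexing are
// rejected by the NVF_CHECK guards added above.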