Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 130 additions & 22 deletions csrc/ir/internal_base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

namespace nvfuser {

IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent)
: start_(_start), extent_(_extent) {
IterDomainBuilder::IterDomainBuilder(Val* start, Val* extent)
: start_(start), extent_(extent) {
NVF_ERROR(
start_ != nullptr && extent_ != nullptr,
"Start and extent are required to build an iter domain.");
Expand All @@ -48,7 +48,11 @@ IterDomainBuilder::IterDomainBuilder(const IterDomain* id)
is_rfactor_domain_(id->isRFactorProduct()),
is_padded_dimension_(id->hasPaddingToMultipleOfWarp()),
is_clustered_dimension_(id->isClusteredBlockDim()),
padded_to_size_(id->getMaybeSizeAfterPadding()) {}
padded_to_size_(id->getMaybeSizeAfterPadding()) {
if (id->isA<RaggedIterDomain>()) {
ragged_extents_ = id->as<RaggedIterDomain>()->extents();
}
}

IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() {
parallel_type_ = ParallelType::Serial;
Expand All @@ -63,60 +67,72 @@ IterDomainBuilder& IterDomainBuilder::resetRfactor() {
return is_rfactor_domain(false);
}

//! Sets the start value held by this builder.
//! \param start Value to record as the start of the domain to be built
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::start(Val* start) {
  start_ = start;
  return *this;
}

//! Sets the extent value held by this builder.
//! \param extent Value to record as the extent of the domain to be built
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::extent(Val* extent) {
  extent_ = extent;
  return *this;
}

//! Sets the expanded extent value held by this builder.
//! \param expanded_extent Value to record as the expanded extent
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* expanded_extent) {
  expanded_extent_ = expanded_extent;
  return *this;
}

//! Sets the stop offset value held by this builder.
//! \param stop_offset Value to record as the stop offset
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::stop_offset(Val* stop_offset) {
  stop_offset_ = stop_offset;
  return *this;
}

//! Sets the parallel type held by this builder.
//! \param parallel_type Parallel type to record for the domain to be built
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::parallel_type(
    ParallelType parallel_type) {
  parallel_type_ = parallel_type;
  return *this;
}

//! Sets the iter type held by this builder.
//! \param iter_type Iter type to record for the domain to be built
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::iter_type(IterType iter_type) {
  iter_type_ = iter_type;
  return *this;
}

//! Sets the rfactor-domain flag held by this builder.
//! \param is_rfactor_domain New value of the flag
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::is_rfactor_domain(
    bool is_rfactor_domain) {
  is_rfactor_domain_ = is_rfactor_domain;
  return *this;
}

//! Sets the padded-dimension flag held by this builder.
//! \param is_padded_dimension New value of the flag
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::is_padded_dimension(
    bool is_padded_dimension) {
  is_padded_dimension_ = is_padded_dimension;
  return *this;
}

//! Sets the optional padded-to size held by this builder.
//! \param padded_to_size Size to record, or std::nullopt for no value
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::padded_to_size(
    std::optional<int64_t> padded_to_size) {
  padded_to_size_ = padded_to_size;
  return *this;
}

//! Records the extents tensor held by this builder for RaggedIterDomain.
//! \param ragged_extents Extents tensor to store (may be nullptr)
//! \return Reference to this builder, so that setter calls can be chained
IterDomainBuilder& IterDomainBuilder::ragged_extents(
    TensorView* ragged_extents) {
  this->ragged_extents_ = ragged_extents;
  return *this;
}

//! Creates the domain described by this builder.
//!
//! Both start and extent must be set. The node is created in the container
//! of the start value. A RaggedIterDomain is created when a ragged extents
//! tensor has been set on the builder; otherwise a plain IterDomain is
//! created.
//!
//! \return The newly created domain
IterDomain* IterDomainBuilder::build() const {
  NVF_ERROR(
      start_ != nullptr && extent_ != nullptr,
      "Start and extent are required to build an iter domain.");

  // The ragged case must be decided before any return: an unconditional
  // return of a plain IterDomain here would make this branch unreachable.
  if (ragged_extents_ != nullptr) {
    return IrBuilder::createInContainer<RaggedIterDomain>(
        start_->container(), *this);
  }

  return IrBuilder::createInContainer<IterDomain>(start_->container(), *this);
}

IterDomain::IterDomain(
Expand Down Expand Up @@ -815,6 +831,77 @@ void validateLoopDomain(

} // namespace

//! Constructs a RaggedIterDomain from the fields of an IterDomainBuilder.
//!
//! The builder fields are forwarded to the IterDomain base constructor with
//! ValType::RaggedIterDomain, and the builder's ragged extents tensor is
//! stored as extents_. The body then checks, in order, every restriction a
//! RaggedIterDomain places on those fields; the first violated check reports
//! an error.
//!
//! \param passkey Passkey forwarded to the IterDomain base constructor
//! \param args Builder holding the parameters of the domain to create
RaggedIterDomain::RaggedIterDomain(
IrBuilderPasskey passkey,
const IterDomainBuilder& args)
: IterDomain(
passkey,
ValType::RaggedIterDomain,
args.start_,
args.extent_,
args.expanded_extent_,
args.stop_offset_,
args.parallel_type_,
args.iter_type_,
args.is_rfactor_domain_,
args.is_padded_dimension_,
args.is_clustered_dimension_,
args.padded_to_size_),
extents_(args.ragged_extents_) {
// Extents must be non-null. This is checked first because the dtype check
// below dereferences extents_.
NVF_ERROR(
extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor");

// Extents must have the Index dtype
NVF_ERROR_EQ(
extents_->dtype(),
DataType::Index,
"RaggedIterDomain extents must have index type, got ",
extents_->dtype());

// Only IterType::Iteration is supported at this moment
NVF_ERROR_EQ(
iter_type_,
IterType::Iteration,
"Only IterType::Iteration is supported: ",
iter_type_);

// RaggedIterDomain has specific requirements on member values:
// start must be zero and extent must be one.
NVF_ERROR(
start_->isZeroInt(),
"RaggedIterDomain start must be zero, got: ",
start_->toInlineString());

// Per the error message, the extent of one is only a placeholder.
NVF_ERROR(
extent_->isOneInt(),
"RaggedIterDomain extent must be one (placeholder), got: ",
extent_->toInlineString());

// An expanded extent must not be set.
NVF_ERROR(
expanded_extent_ == nullptr,
"RaggedIterDomain does not support expanded_extent");

// A stop offset is accepted only when it is unset or zero.
NVF_ERROR(
stop_offset_ == nullptr || stop_offset_->isZeroInt(),
"RaggedIterDomain stop_offset must be nullptr or zero, got: ",
stop_offset_ ? stop_offset_->toInlineString() : "nullptr");

// The rfactor, padded-dimension and clustered-dimension flags must all be
// false, and no padded-to size may be set.
NVF_ERROR(
!is_rfactor_domain_, "RaggedIterDomain does not support rfactor domains");

NVF_ERROR(
!is_padded_dimension_,
"RaggedIterDomain does not support padded dimensions");

NVF_ERROR(
!is_clustered_dimension_,
"RaggedIterDomain does not support clustered dimensions");

NVF_ERROR(
!padded_to_size_.has_value(),
"RaggedIterDomain does not support padded_to_size");
}

RaggedIterDomain::RaggedIterDomain(
IrBuilderPasskey passkey,
TensorView* extents,
Expand Down Expand Up @@ -895,6 +982,20 @@ std::string RaggedIterDomain::toString(int indent_size) const {
return toInlineString(indent_size);
}

//! Clones this domain as a new RaggedIterDomain created from this domain's
//! extents tensor, iter type and parallel type, so that the clone keeps the
//! RaggedIterDomain type.
//!
//! \param map_with_original Requests that the clone be mapped with this
//! domain in the Exact graph. This is not implemented: passing true reports
//! a "Not implemented" error.
//! \return The newly created RaggedIterDomain
IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) {
  // Mapping the clone with the original in the Exact graph is not supported
  // yet. Reject the request before creating anything, so that a failing call
  // does not leave an unused clone behind.
  // TODO: Implement mapping if needed. A review note on this change suggests
  // calling fusion()->registerExactMapping(this, cloned), as the base
  // IterDomain::cloneWithoutRFactor is said to do.
  if (map_with_original) {
    NVF_THROW("Not implemented");
  }

  return IrBuilder::create<RaggedIterDomain>(
      extents_, getIterType(), getParallelType());
}

std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
IterDomain* in,
TensorView* extents) {
Expand Down Expand Up @@ -1472,6 +1573,13 @@ bool TensorDomain::hasVectorize() const {
});
}

bool TensorDomain::hasRaggedIterDomain() const {
return std::any_of(
logical().begin(), logical().end(), [](IterDomain* logical_id) {
return logical_id->isA<RaggedIterDomain>();
});
}

std::optional<int64_t> TensorDomain::getReductionAxis() const {
auto it = std::find_if(
loop_domain_.begin(), loop_domain_.end(), [](const auto& id) {
Expand Down
33 changes: 22 additions & 11 deletions csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct AnalyzeViewResult;
class IterDomainBuilder {
public:
// Match legacy constructor
IterDomainBuilder(Val* _start, Val* _extent);
IterDomainBuilder(Val* start, Val* extent);

// Grab all the parameters from id to set the IterDomainBuilder
IterDomainBuilder(const IterDomain* id);
Expand All @@ -52,15 +52,16 @@ class IterDomainBuilder {
// Resets is_rfactor_domain
IterDomainBuilder& resetRfactor();

IterDomainBuilder& start(Val* _start);
IterDomainBuilder& extent(Val* _extent);
IterDomainBuilder& expanded_extent(Val* _expanded_extent);
IterDomainBuilder& stop_offset(Val* _stop_offset);
IterDomainBuilder& parallel_type(ParallelType _parallel_type);
IterDomainBuilder& iter_type(IterType _iter_type);
IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain);
IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension);
IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size);
IterDomainBuilder& start(Val* start);
IterDomainBuilder& extent(Val* extent);
IterDomainBuilder& expanded_extent(Val* expanded_extent);
IterDomainBuilder& stop_offset(Val* stop_offset);
IterDomainBuilder& parallel_type(ParallelType parallel_type);
IterDomainBuilder& iter_type(IterType iter_type);
IterDomainBuilder& is_rfactor_domain(bool is_rfactor_domain);
IterDomainBuilder& is_padded_dimension(bool is_padded_dimension);
IterDomainBuilder& padded_to_size(std::optional<int64_t> padded_to_size);
IterDomainBuilder& ragged_extents(TensorView* ragged_extents);

IterDomain* build() const;

Expand All @@ -79,6 +80,9 @@ class IterDomainBuilder {
bool is_padded_dimension_ = false;
bool is_clustered_dimension_ = false;
std::optional<int64_t> padded_to_size_ = std::nullopt;

// For RaggedIterDomain: stores the extents tensor
TensorView* ragged_extents_ = nullptr;
};

//! Simply a representation of an annotated 1D iterable from start to extent.
Expand Down Expand Up @@ -122,7 +126,7 @@ class NVF_API IterDomain : public Val {
//!
//! When map_with_original is true, the clone of the original is
//! mapped in the Exact graph.
IterDomain* cloneWithoutRFactor(bool map_with_original = false);
virtual IterDomain* cloneWithoutRFactor(bool map_with_original = false);

//! Clone a vector domains
static std::vector<IterDomain*> clone(
Expand Down Expand Up @@ -448,6 +452,8 @@ class NVF_API IterDomain : public Val {
//! components
class NVF_API RaggedIterDomain : public IterDomain {
public:
RaggedIterDomain(IrBuilderPasskey passkey, const IterDomainBuilder& args);

//! \param extents TensorView containing component extents (must be integer
//! type)
//! \param iter_type Iteration type (Iteration, Reduction, etc.)
Expand Down Expand Up @@ -493,6 +499,9 @@ class NVF_API RaggedIterDomain : public IterDomain {
IterDomain* in,
TensorView* extents);

//! Override cloneWithoutRFactor to preserve RaggedIterDomain type
IterDomain* cloneWithoutRFactor(bool map_with_original = false) override;

private:
//! Extent tensor containing all component extents
//! Can be 1D, 2D, or N-D depending on nesting structure
Expand Down Expand Up @@ -643,6 +652,8 @@ class NVF_API TensorDomain : public Val {

bool hasSymbolicAxis() const;

bool hasRaggedIterDomain() const;

std::optional<int64_t> getReductionAxis() const;

// The input logical domain. The root domain of a consumer should equal the
Expand Down
Loading