Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,22 @@ void IdModel::allocateLoopIndexVariables() {

ParallelType ptype = getParallelType(loop_group);

if (GpuLower::current()->circularBufferInfo().isCircularBufferedIterDomain(
loop_group->front()->as<IterDomain>())) {
// Allocate index variable for each stage of the circular
// buffered loop.
auto indices = std::make_unique<CircularBufferIndices>();
for (auto i :
arange(static_cast<int>(CircularBufferLoopStage::EndOfStages))) {
indices->emplace(
static_cast<CircularBufferLoopStage>(i),
IrBuilder::create<Val>(DataType::Index));
}
circular_buffered_loop_index_variable_map_[loop_group] =
std::move(indices);
continue;
}

Val* loop_index = nullptr;

// TODO: Cleanup needed. ir_utils::isMemoryPartitionedAcross
Expand All @@ -1354,7 +1370,7 @@ void IdModel::allocateLoopIndexVariables() {
if (shouldUseZeroIndex(loop_group, *this) ||
isParallelTypeDeviceDim(ptype)) {
loop_index = fusion_->zeroVal();
} else if (isParallelTypeThread(ptype)) {
} else if (isParallelTypeThread(ptype) || ptype == ParallelType::Stream) {
loop_index = NamedScalar::getParallelIndex(ptype);
}

Expand All @@ -1363,22 +1379,6 @@ void IdModel::allocateLoopIndexVariables() {
continue;
}

if (GpuLower::current()->circularBufferInfo().isCircularBufferedIterDomain(
loop_group->front()->as<IterDomain>())) {
// Allocate index variable for each stage of the circular
// buffered loop.
auto indices = std::make_unique<CircularBufferIndices>();
for (auto i :
arange(static_cast<int>(CircularBufferLoopStage::EndOfStages))) {
indices->emplace(
static_cast<CircularBufferLoopStage>(i),
IrBuilder::create<Val>(DataType::Index));
}
circular_buffered_loop_index_variable_map_[loop_group] =
std::move(indices);
continue;
}

// If enabled, allocate own indices. Otherwise, use the one
// generated for ComputeAtMap for compatibility with the legacy
// indexing
Expand Down Expand Up @@ -1430,6 +1430,12 @@ Val* IdModel::getLoopIndexVariable(
// stage defined, and we just default to using the main stage index.
circular_buffer_loop_stage = CircularBufferLoopStage::Main;
}
NVF_ERROR(
circular_buffered_loop_index_variable_map_.contains(loop_group),
"Failed to find circular buffer index var for: ",
nvfuser::toString(loop_group),
", ",
loop_group->front()->toString());
return circular_buffered_loop_index_variable_map_.at(loop_group)
->at(circular_buffer_loop_stage);
} else {
Expand Down
17 changes: 17 additions & 0 deletions tests/cpp/test_matmul_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class MatmulSchedulerTest : public NVFuserTest {
protected:
MatmulSchedulerTest() : optimization_guard_(false) {}

// Per-test setup: run the base NVFuserTest setup first, then enable
// EnableOption::IdModel with the "all" argument on the current options.
void SetUp() override {
  NVFuserTest::SetUp();

  auto& enable_options = EnableOptionsGuard::getCurOptions();
  enable_options.set(EnableOption::IdModel, {"all"});
}

private:
// Allocation order set by the pass breaks matmul tests
// see issue https://github.com/NVIDIA/Fuser/issues/1810
Expand Down Expand Up @@ -2482,6 +2487,11 @@ class MatmulSchedulerPluginTest : public NVFuserTest {
MatmulSchedulerPluginTest()
: optimization_guard_(false), factory_guard_(testConfigFactory) {}

// Per-test setup for the plugin tests: defer to the base NVFuserTest setup,
// then switch on EnableOption::IdModel with the "all" argument.
void SetUp() override {
  NVFuserTest::SetUp();

  auto& current_options = EnableOptionsGuard::getCurOptions();
  current_options.set(EnableOption::IdModel, {"all"});
}

private:
// Allocation order set by the pass breaks matmul tests
// see issue https://github.com/NVIDIA/Fuser/issues/1810
Expand Down Expand Up @@ -2917,6 +2927,11 @@ class AllocationDomainTest
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
}

// Per-test setup: invoke the parameterized base fixture's setup, then enable
// EnableOption::IdModel with the "all" argument on the current options.
void SetUp() override {
  NVFuserFixtureParamTest::SetUp();

  auto& enable_options = EnableOptionsGuard::getCurOptions();
  enable_options.set(EnableOption::IdModel, {"all"});
}

std::pair<TensorView*, TensorView*> getInputTVs(
int M,
int N,
Expand Down Expand Up @@ -3365,6 +3380,8 @@ class HopperPlusMatmulSchedulerTest
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 2;

EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
}

void TearDown() {
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/test_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class StreamTest : public NVFuserTest {
public:
// Enables the options these tests rely on: HostIrLowering, and IdModel with
// the "all" argument. Both are set on the current options object.
StreamTest() {
  auto& enable_options = EnableOptionsGuard::getCurOptions();
  enable_options.set(EnableOption::HostIrLowering);
  enable_options.set(EnableOption::IdModel, {"all"});
}
};

Expand Down
1 change: 1 addition & 0 deletions tests/cpp/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ void NVFuserTest::SetUp() {
if (!deviceMajorMinorCheck(6)) {
GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
}
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
}

NVFuserTest::~NVFuserTest() {
Expand Down