Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion csrc/scheduler/pointwise_tma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ namespace tma {
// TODO: This can be further relaxed to allow more tensor views with fewer
// dimensions, e.g., outer broadcast inputs [B, I] can also be loaded with TMA.
bool isTvSuitableForTma(const TensorView* tv, int64_t n_valid_dims) {
// When tv has broadcast dimensions, it may cause one of the following issues:
// 1. Merge of iteration domain with broadcast dimension.
// This is not supported by TMA and will trigger tma lowering validation
// error. The TMA domain must be equivalent to the allocation domain of the
// gmem tensor.
// 2. Right side or left side contains a single broadcast domain, which is
// against our 2D tile assumption. This restriction can be lifted if we
// further revise the scheduler.
if (std::any_of(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
[](const IterDomain* id) { return id->isBroadcast(); })) {
return false;
}
return scheduler_utils::nLogicalDims(tv) == n_valid_dims;
};

Expand Down Expand Up @@ -101,6 +115,17 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
fusion, prop, data_cache, /*is_tma =*/true);
params->break_point = bp_info.break_point;

// check which input is suitable for TMA load
// summarize bits of all suitable inputs
// if bits_per_element is 0, no suitable input is found, return nullptr
// Note: we can move this check upfront when dispatching to TMA and non-TMA
// versions, however, we may want to further extend to use break point for
// more fine-grained control of TMA loads, e.g. mixing 2D tile and 1D tile.
const int64_t bits_per_element = getInputBitsPerElement(prop);
if (bits_per_element == 0) {
return nullptr;
}

// ========== Step 1: Compute TMA Domain Dimensions ==========
// The TMA domain splits the entire problem into
// [tma_domain_outer, tma_domain_inner]
Expand Down Expand Up @@ -159,7 +184,6 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
constexpr int64_t cta_per_sm = 8;
const int64_t bits_per_sm = scheduler_utils::getRequiredBitsInFlight();
const int64_t bits_per_cta = bits_per_sm / cta_per_sm;
const int64_t bits_per_element = getInputBitsPerElement(prop);
if (bits_per_element == 0) {
return nullptr;
}
Comment on lines 187 to 189
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: redundant check - bits_per_element == 0 is already checked at line 125-127, so this condition will never be true

Suggested change
if (bits_per_element == 0) {
return nullptr;
}
// bits_per_element already validated at line 125-127

Expand Down
69 changes: 69 additions & 0 deletions tests/cpp/test_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2114,4 +2114,73 @@ TEST_F(TmaPointwiseTestF, MixedPrecisionIllegalTma) {
executor_cache.fusion(), out_tensors, {t0, t1}, __LINE__, __FILE__);
}

TEST_F(TmaPointwiseTestF, OuterDimOne) {
  // A size-1 (broadcast-like) outer dimension should make the pointwise
  // scheduler skip the TMA load path while still producing correct output.
  constexpr int64_t kInnerSize = 8192;
  auto fusion_ptr = std::make_unique<Fusion>();
  auto* fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  auto in0 = makeContigConcreteTensor({1, kInnerSize}, DataType::Float);
  auto in1 = makeContigConcreteTensor({1, kInnerSize}, DataType::Float);
  fusion->addInput(in0);
  fusion->addInput(in1);
  fusion->addOutput(add(in0, in1));

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({1, kInnerSize}, options);
  auto t1 = at::randn({1, kInnerSize}, options);

  auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
  // TMA load must be disabled for this shape.
  EXPECT_FALSE(
      cg_results.heuristic_params->as<PointwiseParams>()->use_tma_load);
  testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
}

TEST_F(TmaPointwiseTestF, InnerDimOne) {
  // A size-1 (broadcast-like) inner dimension should make the pointwise
  // scheduler skip the TMA load path while still producing correct output.
  constexpr int64_t kOuterSize = 8192;
  auto fusion_ptr = std::make_unique<Fusion>();
  auto* fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  auto in0 = makeContigConcreteTensor({kOuterSize, 1}, DataType::Float);
  auto in1 = makeContigConcreteTensor({kOuterSize, 1}, DataType::Float);
  fusion->addInput(in0);
  fusion->addInput(in1);
  fusion->addOutput(add(in0, in1));

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({kOuterSize, 1}, options);
  auto t1 = at::randn({kOuterSize, 1}, options);

  auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
  // TMA load must be disabled for this shape.
  EXPECT_FALSE(
      cg_results.heuristic_params->as<PointwiseParams>()->use_tma_load);
  testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
}

TEST_F(TmaPointwiseTestF, MiddleDimOne) {
  // A size-1 (broadcast-like) middle dimension should make the pointwise
  // scheduler skip the TMA load path while still producing correct output.
  constexpr int64_t kOuterSize = 8192;
  constexpr int64_t kInnerSize = 1024;
  auto fusion_ptr = std::make_unique<Fusion>();
  auto* fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  auto in0 = makeContigConcreteTensor({kOuterSize, 1, kInnerSize}, DataType::Float);
  auto in1 = makeContigConcreteTensor({kOuterSize, 1, kInnerSize}, DataType::Float);
  fusion->addInput(in0);
  fusion->addInput(in1);
  fusion->addOutput(add(in0, in1));

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({kOuterSize, 1, kInnerSize}, options);
  auto t1 = at::randn({kOuterSize, 1, kInnerSize}, options);

  auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
  // TMA load must be disabled for this shape.
  EXPECT_FALSE(
      cg_results.heuristic_params->as<PointwiseParams>()->use_tma_load);
  testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
}
} // namespace nvfuser