From 368c6f1e75e72d99d5bbe06611683f941798a2d8 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Mon, 22 Dec 2025 09:29:35 -0800
Subject: [PATCH 1/2] avoid merging iter and bcast domains for tma

---
 csrc/scheduler/pointwise_tma.cpp | 26 +++++++++++-
 tests/cpp/test_pointwise.cpp     | 69 ++++++++++++++++++++++++++++++++
 2 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/csrc/scheduler/pointwise_tma.cpp b/csrc/scheduler/pointwise_tma.cpp
index 4205af35fe8..34e88dc340d 100644
--- a/csrc/scheduler/pointwise_tma.cpp
+++ b/csrc/scheduler/pointwise_tma.cpp
@@ -56,6 +56,20 @@ namespace tma {
 // TODO: This can be further relaxed to allow more tensor views with fewer
 // dimensions, e.g., outer broadcast inputs [B, I] can also be loaded with TMA.
 bool isTvSuitableForTma(const TensorView* tv, int64_t n_valid_dims) {
+  // When tv has broadcast dimensions, it may cause one of the following issues:
+  // 1. Merge of iteration domain with broadcast dimension.
+  // This is not supported by TMA and will trigger tma lowering validation
+  // error. The TMA domain must be equivalent to the allocation domain of the
+  // gmem tensor.
+  // 2. Righ side or left side containes a single broadcast domain, which is
+  // againt our 2D tile assumption. This restriction can be lifted if further
+  // revise the scheduler.
+  if (std::any_of(
+          tv->getLogicalDomain().begin(),
+          tv->getLogicalDomain().end(),
+          [](const IterDomain* id) { return id->isBroadcast(); })) {
+    return false;
+  }
   return scheduler_utils::nLogicalDims(tv) == n_valid_dims;
 };
 
@@ -101,6 +115,17 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
       fusion, prop, data_cache, /*is_tma =*/true);
   params->break_point = bp_info.break_point;
 
+  // Check which inputs are suitable for TMA loads and sum the bits of all
+  // suitable inputs. If bits_per_element is 0, no suitable input is found and
+  // we return nullptr.
+  // Note: we could move this check upfront, when dispatching to the TMA and
+  // non-TMA versions; however, we may want to use the break point for more
+  // fine-grained control of TMA loads, e.g. mixing 2D and 1D tiles.
+  const int64_t bits_per_element = getInputBitsPerElement(prop);
+  if (bits_per_element == 0) {
+    return nullptr;
+  }
+
   // ========== Step 1: Compute TMA Domain Dimensions ==========
   // The TMA domain splits the entire problem into
   // [tma_domain_outer, tma_domain_inner]
@@ -159,7 +184,6 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
   constexpr int64_t cta_per_sm = 8;
   const int64_t bits_per_sm = scheduler_utils::getRequiredBitsInFlight();
   const int64_t bits_per_cta = bits_per_sm / cta_per_sm;
-  const int64_t bits_per_element = getInputBitsPerElement(prop);
   if (bits_per_element == 0) {
     return nullptr;
   }
diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp
index 259baf663bc..1872275bb5a 100644
--- a/tests/cpp/test_pointwise.cpp
+++ b/tests/cpp/test_pointwise.cpp
@@ -2114,4 +2114,73 @@ TEST_F(TmaPointwiseTestF, MixedPrecisionIllegalTma) {
       executor_cache.fusion(), out_tensors, {t0, t1}, __LINE__, __FILE__);
 }
 
+TEST_F(TmaPointwiseTestF, OuterDimOne) {
+  int64_t dim1 = 8192;
+  DataType dtype = DataType::Float;
+  auto fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+  auto tv0 = makeContigConcreteTensor({1, dim1}, dtype);
+  auto tv1 = makeContigConcreteTensor({1, dim1}, dtype);
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+  auto tv2 = add(tv0, tv1);
+  fusion->addOutput(tv2);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({1, dim1}, options);
+  auto t1 = at::randn({1, dim1}, options);
+
+  auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
+  auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
+  EXPECT_FALSE(pparams->use_tma_load);
+  testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
+}
+
+TEST_F(TmaPointwiseTestF, InnerDimOne) {
+  int64_t dim0 = 8192;
+  DataType dtype = DataType::Float;
+  auto fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+  auto tv0 = makeContigConcreteTensor({dim0, 1}, dtype);
+  auto tv1 = makeContigConcreteTensor({dim0, 1}, dtype);
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+  auto tv2 = add(tv0, tv1);
+  fusion->addOutput(tv2);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({dim0, 1}, options);
+  auto t1 = at::randn({dim0, 1}, options);
+
+  auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
+  auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
+  EXPECT_FALSE(pparams->use_tma_load);
+  testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
+}
+
+TEST_F(TmaPointwiseTestF, MiddleDimOne) {
+  int64_t dim0 = 8192;
+  int64_t dim2 = 1024;
+  DataType dtype = DataType::Float;
+  auto fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+  auto tv0 = makeContigConcreteTensor({dim0, 1, dim2}, dtype);
+  auto tv1 = makeContigConcreteTensor({dim0, 1, dim2}, dtype);
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+  auto tv2 = add(tv0, tv1);
+  fusion->addOutput(tv2);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({dim0, 1, dim2}, options);
+  auto t1 = at::randn({dim0, 1, dim2}, options);
+
+  auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
+  auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
+  EXPECT_FALSE(pparams->use_tma_load);
+  testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
+}
 } // namespace nvfuser

From ea7f1c86ce2fda9a0477284d8b60f834ffaac392 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Mon, 22 Dec 2025 09:34:21 -0800
Subject: [PATCH 2/2] comment

---
 csrc/scheduler/pointwise_tma.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/csrc/scheduler/pointwise_tma.cpp b/csrc/scheduler/pointwise_tma.cpp
index 34e88dc340d..1b93585340b 100644
--- a/csrc/scheduler/pointwise_tma.cpp
+++ b/csrc/scheduler/pointwise_tma.cpp
@@ -61,9 +61,9 @@ bool isTvSuitableForTma(const TensorView* tv, int64_t n_valid_dims) {
   // This is not supported by TMA and will trigger tma lowering validation
   // error. The TMA domain must be equivalent to the allocation domain of the
   // gmem tensor.
-  // 2. Righ side or left side containes a single broadcast domain, which is
-  // againt our 2D tile assumption. This restriction can be lifted if further
-  // revise the scheduler.
+  // 2. Right side or left side contains a single broadcast domain, which is
+  // against our 2D tile assumption. This restriction can be lifted if we
+  // further revise the scheduler.
   if (std::any_of(
           tv->getLogicalDomain().begin(),
           tv->getLogicalDomain().end(),
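
Note (illustrative only, not part of the patch series): the check added in isTvSuitableForTma rejects TMA loads for any input whose logical domain contains a broadcast IterDomain, which is what the three new tests exercise through size-1 dimensions. A minimal standalone C++ sketch of that predicate follows; the Dim struct and suitableForTma function are hypothetical stand-ins for nvfuser's IterDomain and tma::isTvSuitableForTma, not real nvfuser API.

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for an IterDomain: just an extent plus a broadcast flag.
struct Dim {
  long extent;
  bool is_broadcast;
};

// Mirrors the patch's predicate: a tensor is TMA-eligible only if no logical
// dimension is a broadcast.
bool suitableForTma(const std::vector<Dim>& logical_domain) {
  return !std::any_of(
      logical_domain.begin(), logical_domain.end(), [](const Dim& d) {
        return d.is_broadcast;
      });
}

int main() {
  // Shapes mirroring the new tests, where the size-1 dims act as broadcast
  // domains and the tests expect use_tma_load == false.
  std::printf("%d\n", suitableForTma({{1, true}, {8192, false}}));                 // 0
  std::printf("%d\n", suitableForTma({{8192, false}, {1, true}}));                 // 0
  std::printf("%d\n", suitableForTma({{8192, false}, {1, true}, {1024, false}}));  // 0
  // A plain 2D tensor with no broadcast dimension stays TMA-eligible.
  std::printf("%d\n", suitableForTma({{8192, false}, {1024, false}}));             // 1
  return 0;
}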