Skip to content

Commit 8477d3d

Browse files
Enable fused RMSNorm dLN + add through CUDNN (NVIDIA#2778)
* add cudnn dln+add
  Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
* try fixing cudnn build issue
  Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
* guard against cudnn version
  Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* change itype to wtype for add in rmsnorm_bwd
  Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
* remove dead code
  Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
* remove dangling todo
  Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
---------
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d2625e5 commit 8477d3d

3 files changed

Lines changed: 38 additions & 13 deletions

File tree

transformer_engine/common/normalization/common.cpp

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -395,6 +395,23 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
395395
std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]);
396396
if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned.");
397397
}
398+
// Fuse the add for BackwardAdd stage
399+
if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) {
400+
NVTE_CHECK(cudnnGetVersion() >= 92100,
401+
"Fused BackwardAdd requires cuDNN >= 9.21.0, but found ", cudnnGetVersion());
402+
403+
_add = _graph.tensor(fe::graph::Tensor_attributes()
404+
.set_name("add")
405+
.set_dim({batch_dim, hidden_dim, 1, 1})
406+
.set_stride({hidden_dim, 1, hidden_dim, hidden_dim})
407+
.set_data_type(get_cudnn_fe_dtype(wtype)));
408+
auto add_options = fe::graph::Pointwise_attributes()
409+
.set_mode(fe::PointwiseMode_t::ADD)
410+
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
411+
auto _dx_with_add = _graph.pointwise(_dx, _add, add_options);
412+
_dx->set_output(false).set_data_type(get_cudnn_fe_dtype(itype));
413+
_dx = _dx_with_add;
414+
}
398415
_dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
399416
_dgamma->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
400417
}
@@ -467,13 +484,16 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_
467484
void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
468485
void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
469486
void* workspace_dptr, cudaStream_t stream) {
470-
// cuDNN does not currently support fused backward+add
471-
NVTE_CHECK(add_dptr == nullptr);
472-
473487
// Binding data pointers to graph tensors
474488
_variant_pack = {
475489
{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
476490

491+
// Bind the add tensor for fused backward+add
492+
if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) {
493+
NVTE_CHECK(add_dptr != nullptr, "add_dptr must not be null for BackwardAdd");
494+
_variant_pack.insert({{_add, add_dptr}});
495+
}
496+
477497
if (_zero_centered)
478498
_variant_pack.insert({{_scalar_offset, reinterpret_cast<void*>(this->_scalar_dptr.get())},
479499
{_gamma_zero, gamma_dptr}});

transformer_engine/common/normalization/common.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -294,7 +294,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
294294
std::shared_ptr<fe::graph::Tensor_attributes> _z_mx_row, _z_mx_col, _sf_row, _sf_col;
295295
const bool _training;
296296
// BWD
297-
std::shared_ptr<fe::graph::Tensor_attributes> _dz, _dx, _dgamma, _dbeta;
297+
std::shared_ptr<fe::graph::Tensor_attributes> _dz, _dx, _dgamma, _dbeta, _add;
298298

299299
fe::graph::Graph _graph;
300300
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> _variant_pack;

transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -206,16 +206,21 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
206206
CheckOutputTensor(*dgamma, "dgamma");
207207
}
208208

209-
// cuDNN does not currently support fused backward+add
210-
NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te;
211-
212-
// TE backend does not currently support zero_centered_gamma_in_weight_dtype
213-
NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(),
214-
"zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add");
215-
216-
bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
217-
dz.data.dptr, dgamma->data.dptr, add.data.dptr);
209+
NVTE_Norm_Backend norm_backend;
210+
bool is_aligned = true;
218211
bool gamma_in_weight_dtype = false;
212+
if (use_cudnn_norm_bwd()) {
213+
norm_backend = NVTE_Norm_Backend::Cudnn;
214+
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
215+
} else {
216+
norm_backend = NVTE_Norm_Backend::Te;
217+
// TE backend does not currently support zero_centered_gamma_in_weight_dtype
218+
NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(),
219+
"zero_centered_gamma_in_weight_dtype is currently not supported "
220+
"for rmsnorm_bwd_add with TE backend");
221+
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
222+
dz.data.dptr, dgamma->data.dptr, add.data.dptr);
223+
}
219224

220225
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
221226
norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd,

0 commit comments

Comments (0)