From 0364029c4c474b7ea1c1cce5680319fb52d60a21 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 1 Dec 2025 14:42:04 -0800 Subject: [PATCH 01/11] PoC of the changes Signed-off-by: Przemek Tredak --- tests/pytorch/attention/test_attention.py | 2 +- .../dot_product_attention.py | 2 +- transformer_engine/pytorch/module/base.py | 108 +++++---- .../pytorch/module/grouped_linear.py | 110 ++++----- .../pytorch/module/layernorm_linear.py | 156 ++++++------- .../pytorch/module/layernorm_mlp.py | 214 +++++++++--------- transformer_engine/pytorch/module/linear.py | 142 ++++++------ 7 files changed, 380 insertions(+), 354 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 4aedcff1b83..3ce9d7444b3 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2751,7 +2751,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: - with self.prepare_forward(inp, num_gemms=3) as inp: + with self.prepare_forward_ctx(inp, num_gemms=3) as inp: out = _custom_mha_fp8.apply( inp, self.qkv_weight, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index f506035c1ef..e2e55849292 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1000,7 +1000,7 @@ def forward( cases. It is ignored for other backends and when context parallelism is enabled. """ - with self.prepare_forward( + with self.prepare_forward_ctx( query_layer, num_gemms=3, allow_non_contiguous=True, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index acf92332817..95ee4e2f32f 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -49,6 +49,7 @@ is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype, get_nvtx_range_context, + _nvtx_enabled, ) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe @@ -640,12 +641,15 @@ def __init__(self) -> None: "fp8_parameters", } + def fast_set_attr(self, name: str, value: Any) -> None: + self.__dict__[name] = value + def __setattr__(self, name: str, value: Any) -> None: if name in TransformerEngineBaseModule._fast_setattr_names: # torch.nn.Module has a custom __setattr__ that handles # modules, parameters, and buffers. This is unnecessary # overhead when setting plain attrs. - self.__dict__[name] = value + self.fast_set_attr(name, value) else: # Default case super().__setattr__(name, value) @@ -926,7 +930,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.activation_dtype = torch_get_autocast_gpu_dtype() + self.fast_set_attr("activation_dtype", torch_get_autocast_gpu_dtype()) return # All checks after this have already been performed once, thus skip @@ -941,7 +945,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: "Data types for parameters must match when outside of autocasted region. " f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" ) - self.activation_dtype = dtype + self.fast_set_attr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -970,48 +974,51 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - _original_recipe = self.fp8_meta.get("recipe", None) - - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - fp8_enabled = self.fp8 or self.fp8_calibration - self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - - if self.fp8_parameters or fp8_enabled: - if ( - self.fp8_initialized - and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] - ): + meta = self.fp8_meta + + fp8 = FP8GlobalStateManager.is_fp8_enabled() + fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_set_attr("fp8_parameters", fp8_parameters) + self.fast_set_attr("fp8", fp8) + self.fast_set_attr("fp8_calibration", fp8_calibration) + fp8_enabled = fp8 or fp8_calibration + meta["fp8_checkpoint"] = fp8_enabled + + _original_recipe = None + + if fp8_parameters or fp8_enabled: + _original_recipe = meta.get("recipe", None) + if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe: # FP8 init has already been run and recipe is the same, don't do anything. return - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() else: # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + self.fast_set_attr("fp8_initialized", False) return - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + if fp8_parameters and not self.fp8_initialized: + meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(meta["recipe"]) if fp8_enabled: # Set FP8 and other FP8 metadata - self.fp8_meta["num_gemms"] = num_gemms - self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + meta["num_gemms"] = num_gemms + meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Set FP8_MAX per tensor according to recipe - if hasattr(self.fp8_meta["recipe"], "fp8_format"): - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + if hasattr(meta["recipe"], "fp8_format"): + meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd + meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + self.init_fp8_meta_tensors(meta["recipe"]) self.fp8_initialized = True - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - _current_recipe = self.fp8_meta["recipe"] + _current_recipe = meta["recipe"] if _original_recipe is not None and not ( issubclass(_current_recipe.__class__, _original_recipe.__class__) or issubclass(_original_recipe.__class__, _current_recipe.__class__) @@ -1024,22 +1031,18 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Clear cached workspaces as they were created with the old recipe/quantizer type self._fp8_workspaces.clear() - @contextmanager def prepare_forward( self, inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, allow_different_data_and_param_types: bool = False, - ) -> Generator[torch.Tensor, None, None]: - """Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - self.allow_different_data_and_param_types = allow_different_data_and_param_types - self.forwarded_at_least_once = True + ) -> torch.Tensor: + """Checks and prepare for FWD execution.""" + self.fast_set_attr( + "allow_different_data_and_param_types", allow_different_data_and_param_types + ) + self.fast_set_attr("forwarded_at_least_once", True) # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): @@ -1070,13 +1073,32 @@ def prepare_forward( if self.training and is_fp8_activation_recompute_enabled(): FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) - with get_nvtx_range_context(self.__class__.__name__ + " forward"): - if not allow_non_contiguous and not inp.is_contiguous(): - inp = inp.contiguous() - yield inp + # with get_nvtx_range_context(self.__class__.__name__ + " forward"): + if _nvtx_enabled(): + torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward") + if not allow_non_contiguous and not inp.is_contiguous(): + inp = inp.contiguous() + return inp + def end_forward(self): + delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) + if _nvtx_enabled(): + torch.cuda.nvtx.range_pop() + + @contextmanager + def prepare_forward_ctx( + self, + inp: torch.Tensor, + num_gemms: int = 1, + allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, + ) -> Generator[torch.Tensor, None, None]: + yield self.prepare_forward( + inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types + ) + self.end_forward() def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 004c95c3725..42f5185248c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -787,60 +787,62 @@ def forward( is_grad_enabled = torch.is_grad_enabled() - with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: - weight_tensors = self._get_weight_tensors() - bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] - - quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() - - if debug: - if self.no_debug_features_active(list(chain(*quantizers))): - debug = False - quantizers = self._get_quantizers() - - if isinstance(weight_tensors, QuantizedTensorStorage): - raise RuntimeError("FP8 weights are not supported in debug mode.") - - ( - input_quantizers, - weight_quantizers, - output_quantizers, - grad_input_quantizers, - grad_weight_quantizers, - grad_output_quantizers, - ) = quantizers - - if is_grad_enabled: - linear_fn = _GroupedLinear.apply - autograd_ctx = [] - else: - linear_fn = _GroupedLinear.forward - autograd_ctx = [None] - - non_tensor_args = ( - m_splits, - self.apply_bias, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizers, - weight_quantizers, - output_quantizers, - grad_input_quantizers, - grad_weight_quantizers, - grad_output_quantizers, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.sequence_parallel, - self.activation_dtype, - is_grad_enabled, - self, - None, # skip_fp8_weight_update - self.save_original_input, - debug, - ) - out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + inp = self.prepare_forward(inp, num_gemms=self.num_gemms) + weight_tensors = self._get_weight_tensors() + bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + + quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() + + if debug: + if self.no_debug_features_active(list(chain(*quantizers))): + debug = False + quantizers = self._get_quantizers() + + if isinstance(weight_tensors, QuantizedTensorStorage): + raise RuntimeError("FP8 weights are not supported in debug mode.") + + ( + input_quantizers, + weight_quantizers, + output_quantizers, + grad_input_quantizers, + grad_weight_quantizers, + grad_output_quantizers, + ) = quantizers + + if is_grad_enabled: + linear_fn = _GroupedLinear.apply + autograd_ctx = [] + else: + linear_fn = _GroupedLinear.forward + autograd_ctx = [None] + + non_tensor_args = ( + m_splits, + self.apply_bias, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizers, + weight_quantizers, + output_quantizers, + grad_input_quantizers, + grad_weight_quantizers, + grad_output_quantizers, + self.fuse_wgrad_accumulation, + is_cpu_offload_enabled(), + self.sequence_parallel, + self.activation_dtype, + is_grad_enabled, + self, + None, # skip_fp8_weight_update + self.save_original_input, + debug, + ) + out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + + self.end_forward() if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 667c199c49f..62a6681d0dc 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1514,87 +1514,89 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + inp = self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer - ) as inp: + ) - # Get concatenated weight and bias tensors - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + # Get concatenated weight and bias tensors + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if is_grad_enabled: - fwd_fn = _LayerNormLinear.apply - autograd_ctx = [] - else: - fwd_fn = _LayerNormLinear.forward - autograd_ctx = [None] - non_tensor_args = ( - self.eps, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - self.return_layernorm_output, - self.return_layernorm_output_gathered, - is_grad_enabled, - self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, - self.ub_name, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - debug, - ) - out = fwd_fn( - *autograd_ctx, - inp, - self.layer_norm_weight, - self.layer_norm_bias, - weight_tensor, - bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, - non_tensor_args, - ) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if is_grad_enabled: + fwd_fn = _LayerNormLinear.apply + autograd_ctx = [] + else: + fwd_fn = _LayerNormLinear.forward + autograd_ctx = [None] + non_tensor_args = ( + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + self.fuse_wgrad_accumulation, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + is_grad_enabled, + self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.normalization, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_rs_dgrad, + self.ub_bulk_wgrad, + self.ub_bulk_dgrad, + self.ub_name, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + debug, + ) + out = fwd_fn( + *autograd_ctx, + inp, + self.layer_norm_weight, + self.layer_norm_bias, + weight_tensor, + bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, + non_tensor_args, + ) + + self.end_forward() if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 56e050fe886..e8b63b981aa 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -2052,115 +2052,117 @@ def forward( if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True - with self.prepare_forward(inp, num_gemms=2) as inp: + inp = self.prepare_forward(inp, num_gemms=2) - quantizers = ( - self._get_quantizers(fp8_output, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, is_grad_enabled) + quantizers = ( + self._get_quantizers(fp8_output, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, is_grad_enabled) - # Get quantizers - ( - fc1_input_quantizer, - fc1_weight_quantizer, - fc1_output_quantizer, - fc1_grad_input_quantizer, - fc1_grad_weight_quantizer, - fc1_grad_output_quantizer, - fc2_input_quantizer, - fc2_weight_quantizer, - fc2_output_quantizer, - fc2_grad_input_quantizer, - fc2_grad_weight_quantizer, - fc2_grad_output_quantizer, - ) = quantizers - - # Get weight tensors - fc1_weight, fc2_weight = self._get_weight_tensors() - fc1_bias = self.fc1_bias if self.use_bias else None - fc2_bias = self.fc2_bias if self.use_bias else None - if not self.fp8: - if isinstance(fc1_weight, Float8Tensor): - fc1_weight = fc1_weight.dequantize() - if isinstance(fc2_weight, Float8Tensor): - fc2_weight = fc2_weight.dequantize() - - # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode - if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): - self.bias_gelu_nvfusion = False + # Get quantizers + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = quantizers - if is_grad_enabled: - fwd_fn = _LayerNormMLP.apply - autograd_ctx = [] - else: - fwd_fn = _LayerNormMLP.forward - autograd_ctx = [None] - - non_tensor_args = ( - self.eps, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - fc1_input_quantizer, - fc1_weight_quantizer, - fc1_output_quantizer, - fc1_grad_input_quantizer, - fc1_grad_weight_quantizer, - fc1_grad_output_quantizer, - fc2_input_quantizer, - fc2_weight_quantizer, - fc2_output_quantizer, - fc2_grad_input_quantizer, - fc2_grad_weight_quantizer, - fc2_grad_output_quantizer, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.return_layernorm_output, - self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion and not self.fp8 and not debug, - self.set_parallel_mode, - is_grad_enabled, - self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.activation, - self.activation_params, - self.normalization, - self.ub_overlap_ag, - self.ub_overlap_rs, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.gemm_gelu_fusion and not debug, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.checkpoint, - debug, - ) - out = fwd_fn( - *autograd_ctx, - inp, - self.layer_norm_weight, - self.layer_norm_bias, - fc1_weight, - fc1_bias, - fc2_weight, - fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, - non_tensor_args, - ) + # Get weight tensors + fc1_weight, fc2_weight = self._get_weight_tensors() + fc1_bias = self.fc1_bias if self.use_bias else None + fc2_bias = self.fc2_bias if self.use_bias else None + if not self.fp8: + if isinstance(fc1_weight, Float8Tensor): + fc1_weight = fc1_weight.dequantize() + if isinstance(fc2_weight, Float8Tensor): + fc2_weight = fc2_weight.dequantize() + + # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode + if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): + self.bias_gelu_nvfusion = False + + if is_grad_enabled: + fwd_fn = _LayerNormMLP.apply + autograd_ctx = [] + else: + fwd_fn = _LayerNormMLP.forward + autograd_ctx = [None] + + non_tensor_args = ( + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + self.fuse_wgrad_accumulation, + fc1_input_quantizer, + fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + self.bias_gelu_nvfusion and not self.fp8 and not debug, + self.set_parallel_mode, + is_grad_enabled, + self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.activation, + self.activation_params, + self.normalization, + self.ub_overlap_ag, + self.ub_overlap_rs, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.gemm_gelu_fusion and not debug, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + self.checkpoint, + debug, + ) + out = fwd_fn( + *autograd_ctx, + inp, + self.layer_norm_weight, + self.layer_norm_bias, + fc1_weight, + fc1_bias, + fc2_weight, + fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, + non_tensor_args, + ) + + self.end_forward() if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b65f7005eb3..760e809f8ce 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1397,81 +1397,79 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( - inp, - allow_non_contiguous=isinstance(inp, QuantizedTensor), - ) as inp: + inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if is_grad_enabled: - linear_fn = _Linear.apply - autograd_ctx = [] - else: - linear_fn = _Linear.forward - autograd_ctx = [None] - - non_tensor_args = ( - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - is_grad_enabled, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.ub_name, - fp8_output, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.save_original_input, - debug, - ) - out = linear_fn( - *autograd_ctx, - weight_tensor, - inp, - bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, - non_tensor_args, - ) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if is_grad_enabled: + linear_fn = _Linear.apply + autograd_ctx = [] + else: + linear_fn = _Linear.forward + autograd_ctx = [None] + + non_tensor_args = ( + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + self.fuse_wgrad_accumulation, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + is_grad_enabled, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_name, + fp8_output, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + self.save_original_input, + debug, + ) + out = linear_fn( + *autograd_ctx, + weight_tensor, + inp, + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, + non_tensor_args, + ) + self.end_forward() if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) From 70e916d7dbdea255be26182659be5c2843dc2b71 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 2 Dec 2025 14:45:07 -0800 Subject: [PATCH 02/11] Early exit from the Free function for the empty tensor Signed-off-by: Przemek Tredak --- transformer_engine/common/transformer_engine.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 8d9563b789f..3777ea980bd 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -450,9 +450,9 @@ class TensorAllocator { } void Free(NVTETensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid tensor."); free_list.push_back(index); // Clean up @@ -560,9 +560,9 @@ class GroupedTensorAllocator { } void Free(NVTEGroupedTensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor."); free_list.push_back(index); // Clean up From d308d43a58765178a3eec4469592e23ba0a8f6e5 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 2 Dec 2025 14:49:01 -0800 Subject: [PATCH 03/11] Use the proper function for nvtx range Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/base.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 95ee4e2f32f..9606febff5c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -49,7 +49,8 @@ is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype, get_nvtx_range_context, - _nvtx_enabled, + nvtx_range_push, + nvtx_range_pop, ) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe @@ -1073,9 +1074,7 @@ def prepare_forward( if self.training and is_fp8_activation_recompute_enabled(): FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) - # with get_nvtx_range_context(self.__class__.__name__ + " forward"): - if _nvtx_enabled(): - torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward") + nvtx_range_push(self.__class__.__name__ + " forward") if not allow_non_contiguous and not inp.is_contiguous(): inp = inp.contiguous() return inp @@ -1084,8 +1083,7 @@ def end_forward(self): delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) - if _nvtx_enabled(): - torch.cuda.nvtx.range_pop() + nvtx_range_pop() @contextmanager def prepare_forward_ctx( From 8ca4ca6fca4126e5fdeb79d4abc67623d413f876 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 12 Dec 2025 11:06:39 -0800 Subject: [PATCH 04/11] Only do mark_not_offload when the cpu_offloading is enabled Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 760e809f8ce..e6b25a5fe39 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -428,7 +428,7 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - mark_not_offload(weight, weightmat, bias) + mark_not_offload(weight, weightmat, bias) # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, From 0a7934b18168c9ef339cc501106dbb4e445594e5 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 15 Dec 2025 15:40:23 -0800 Subject: [PATCH 05/11] First pass on making the setattr issue not come back Signed-off-by: Przemek Tredak --- qa/L0_pytorch_debug_unittest/test.sh | 16 +++---- qa/L0_pytorch_unittest/test.sh | 52 +++++++++++----------- qa/L1_pytorch_distributed_unittest/test.sh | 24 +++++----- qa/L1_pytorch_onnx_unittest/test.sh | 2 +- qa/L1_pytorch_thunder_integration/test.sh | 2 +- transformer_engine/pytorch/module/base.py | 36 ++++++++------- 6 files changed, 67 insertions(+), 65 deletions(-) diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index a176d21b15b..2420c3f54e8 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -28,16 +28,16 @@ mkdir -p "$XML_LOG_DIR" pip install pytest==8.2.1 || error_exit "Failed to install pytest" -pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" -NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" +NVTE_TORCH_COMPILE=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 512c01db42b..202fdac0b00 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -24,32 +24,32 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" -NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" +NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index b1e3a3e15cb..ff5c30b54a8 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -22,16 +22,16 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" # debug tests @@ -42,9 +42,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_ : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} -pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" # standard numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index 303c5c281ab..668343c2122 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -6,4 +6,4 @@ : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L1_pytorch_thunder_integration/test.sh b/qa/L1_pytorch_thunder_integration/test.sh index edf3f2eb841..ecfde921c70 100644 --- a/qa/L1_pytorch_thunder_integration/test.sh +++ b/qa/L1_pytorch_thunder_integration/test.sh @@ -9,7 +9,7 @@ set -x mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.1.1 pytest-benchmark==5.1.0 -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py # Check return code # Note: Return code 5 is fine. Lightning tests are skipped on systems diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 9606febff5c..7b8573791c1 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -642,18 +642,20 @@ def __init__(self) -> None: "fp8_parameters", } - def fast_set_attr(self, name: str, value: Any) -> None: + def fast_setattr(self, name: str, value: Any) -> None: self.__dict__[name] = value + def module_setattr(self, name: str, value: Any) -> None: + super().__setattr__(name, value) + def __setattr__(self, name: str, value: Any) -> None: - if name in TransformerEngineBaseModule._fast_setattr_names: - # torch.nn.Module has a custom __setattr__ that handles - # modules, parameters, and buffers. This is unnecessary - # overhead when setting plain attrs. - self.fast_set_attr(name, value) - else: - # Default case - super().__setattr__(name, value) + warnings.warn( + """The default implementation of torch.nn.Module introduces significant CPU overhead + when setting attributes and is therefore not recommended. Please use the explicit calls + (fast_setattr for setting regular values and module_setattr for setting parameters, + children modules and buffers).""", + RuntimeWarning) + self.module_setattr(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ @@ -931,7 +933,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.fast_set_attr("activation_dtype", torch_get_autocast_gpu_dtype()) + self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype()) return # All checks after this have already been performed once, thus skip @@ -946,7 +948,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: "Data types for parameters must match when outside of autocasted region. " f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" ) - self.fast_set_attr("activation_dtype", dtype) + self.fast_setattr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -980,9 +982,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: fp8 = FP8GlobalStateManager.is_fp8_enabled() fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - self.fast_set_attr("fp8_parameters", fp8_parameters) - self.fast_set_attr("fp8", fp8) - self.fast_set_attr("fp8_calibration", fp8_calibration) + self.fast_setattr("fp8_parameters", fp8_parameters) + self.fast_setattr("fp8", fp8) + self.fast_setattr("fp8_calibration", fp8_calibration) fp8_enabled = fp8 or fp8_calibration meta["fp8_checkpoint"] = fp8_enabled @@ -996,7 +998,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() else: # If fp8 isn't enabled, turn off and return. - self.fast_set_attr("fp8_initialized", False) + self.fast_setattr("fp8_initialized", False) return if fp8_parameters and not self.fp8_initialized: @@ -1040,10 +1042,10 @@ def prepare_forward( allow_different_data_and_param_types: bool = False, ) -> torch.Tensor: """Checks and prepare for FWD execution.""" - self.fast_set_attr( + self.fast_setattr( "allow_different_data_and_param_types", allow_different_data_and_param_types ) - self.fast_set_attr("forwarded_at_least_once", True) + self.fast_setattr("forwarded_at_least_once", True) # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): From a086631f03ffbe763ec31ff6dba0dd6d8973ace5 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 15 Dec 2025 15:51:30 -0800 Subject: [PATCH 06/11] Actually add pytest.ini Signed-off-by: Przemek Tredak --- tests/pytorch/pytest.ini | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 tests/pytorch/pytest.ini diff --git a/tests/pytorch/pytest.ini b/tests/pytorch/pytest.ini new file mode 100644 index 00000000000..49111713b80 --- /dev/null +++ b/tests/pytorch/pytest.ini @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +[pytest] +filterWarnings= + error:RuntimeWarning + From 1ef79cd9599710753d5c78fc95b0fd5e14f1d90e Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 15 Dec 2025 16:28:11 -0800 Subject: [PATCH 07/11] Changes to __init__ Signed-off-by: Przemek Tredak --- tests/pytorch/pytest.ini | 4 +- transformer_engine/pytorch/module/base.py | 50 +++++++++-------------- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/tests/pytorch/pytest.ini b/tests/pytorch/pytest.ini index 49111713b80..e90989721b6 100644 --- a/tests/pytorch/pytest.ini +++ b/tests/pytorch/pytest.ini @@ -3,6 +3,6 @@ # See LICENSE for license information. [pytest] -filterWarnings= - error:RuntimeWarning +filterwarnings= + error::RuntimeWarning diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 7b8573791c1..9de2e6a5b1d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -606,42 +606,32 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." - self.name = None - self.next_iter_when_debug_should_be_run = 0 - self.fp8_initialized = False - self.fp8 = False - self.fp8_calibration = False - self.fp8_meta = {} + self.fast_setattr("name", None) + self.fast_setattr("next_iter_when_debug_should_be_run", 0) + self.fast_setattr("fp8_initialized", False) + self.fast_setattr("fp8", False) + self.fast_setattr("fp8_calibration", False) + self.fast_setattr("fp8_meta", {}) self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None - self.fp8_meta_tensors_initialized = False - self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} - self.tp_group = None - self.tp_size = 1 - self.sequence_parallel = False - self.param_init_meta = {} - self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() - self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() - self.fsdp_wrapped = False - self.fsdp_group = None - self._fp8_workspaces: Dict[str, QuantizedTensor] = {} - self.activation_dtype: Optional[torch.dtype] = None - self.wgrad_accumulation_and_reduce_hooks = [] - self.wgrad_store = None + self.fast_setattr("fp8_meta_tensors_initialized", False) + self.fast_setattr("quantizers", {"scaling_fwd": {}, "scaling_bwd": {}}) + self.fast_setattr("tp_group", None) + self.fast_setattr("tp_size", 1) + self.fast_setattr("sequence_parallel", False) + self.fast_setattr("param_init_meta", {}) + self.fast_setattr("primary_weights_in_fp8", FP8GlobalStateManager.with_fp8_parameters()) + self.fast_setattr("preserve_high_precision_init_val", FP8GlobalStateManager.with_high_precision_init_val()) + self.fast_setattr("fsdp_wrapped", False) + self.fast_setattr("fsdp_group", None) + self.fast_setattr("_fp8_workspaces", {}) + self.fast_setattr("activation_dtype", None) + self.fast_setattr("wgrad_accumulation_and_reduce_hooks", []) + self.fast_setattr("wgrad_store", None) if not TEDebugState.debug_enabled: TEDebugState.initialize() - # Names of attributes that can be set quickly (see __setattr__ - # method) - _fast_setattr_names: Set[str] = { - "activation_dtype", - "fp8", - "fp8_initialized", - "fp8_calibration", - "fp8_parameters", - } - def fast_setattr(self, name: str, value: Any) -> None: self.__dict__[name] = value From c4e380f404a17864e0dedc9b08a4aba011807796 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 15 Dec 2025 17:08:07 -0800 Subject: [PATCH 08/11] A different way Signed-off-by: Przemek Tredak --- .../dot_product_attention.py | 2 + transformer_engine/pytorch/module/base.py | 42 +++++++++---------- .../pytorch/module/grouped_linear.py | 2 + .../pytorch/module/layernorm_linear.py | 2 + .../pytorch/module/layernorm_mlp.py | 2 + transformer_engine/pytorch/module/linear.py | 2 + 6 files changed, 31 insertions(+), 21 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index e2e55849292..5cae557ca78 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -482,6 +482,8 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) + self.__setattr__ = self.default_setattr + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 9de2e6a5b1d..8044f3761ad 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -606,28 +606,28 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." - self.fast_setattr("name", None) - self.fast_setattr("next_iter_when_debug_should_be_run", 0) - self.fast_setattr("fp8_initialized", False) - self.fast_setattr("fp8", False) - self.fast_setattr("fp8_calibration", False) - self.fast_setattr("fp8_meta", {}) + self.name = None + self.next_iter_when_debug_should_be_run = 0 + self.fp8_initialized = False + self.fp8 = False + self.fp8_calibration = False + self.fp8_meta = {} self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None - self.fast_setattr("fp8_meta_tensors_initialized", False) - self.fast_setattr("quantizers", {"scaling_fwd": {}, "scaling_bwd": {}}) - self.fast_setattr("tp_group", None) - self.fast_setattr("tp_size", 1) - self.fast_setattr("sequence_parallel", False) - self.fast_setattr("param_init_meta", {}) - self.fast_setattr("primary_weights_in_fp8", FP8GlobalStateManager.with_fp8_parameters()) - self.fast_setattr("preserve_high_precision_init_val", FP8GlobalStateManager.with_high_precision_init_val()) - self.fast_setattr("fsdp_wrapped", False) - self.fast_setattr("fsdp_group", None) - self.fast_setattr("_fp8_workspaces", {}) - self.fast_setattr("activation_dtype", None) - self.fast_setattr("wgrad_accumulation_and_reduce_hooks", []) - self.fast_setattr("wgrad_store", None) + self.fp8_meta_tensors_initialized = False + self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} + self.tp_group = None + self.tp_size = 1 + self.sequence_parallel = False + self.param_init_meta = {} + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() + self.fsdp_wrapped = False + self.fsdp_group = None + self._fp8_workspaces: Dict[str, QuantizedTensor] = {} + self.activation_dtype: Optional[torch.dtype] = None + self.wgrad_accumulation_and_reduce_hooks = [] + self.wgrad_store = None if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -638,7 +638,7 @@ def fast_setattr(self, name: str, value: Any) -> None: def module_setattr(self, name: str, value: Any) -> None: super().__setattr__(name, value) - def __setattr__(self, name: str, value: Any) -> None: + def default_setattr(self, name: str, value: Any) -> None: warnings.warn( """The default implementation of torch.nn.Module introduces significant CPU overhead when setting attributes and is therefore not recommended. Please use the explicit calls diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 42f5185248c..e064dba0af2 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -710,6 +710,8 @@ def __init__( if name in (f"weight{i}", f"bias{i}"): param.skip_backward_post_hook = True + self.__setattr__ = self.default_setattr + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 62a6681d0dc..481a3dca863 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1405,6 +1405,8 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + self.__setattr__ = self.default_setattr + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index e8b63b981aa..882734595b2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1960,6 +1960,8 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + self.__setattr__ = self.default_setattr + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e6b25a5fe39..c13975e90cd 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1308,6 +1308,8 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + self.__setattr__ = self.default_setattr + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) From f57deec95192236e5924bf4b33ae2051bc577091 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Dec 2025 21:21:13 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8044f3761ad..a9bd5e4e4e3 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -644,7 +644,8 @@ def default_setattr(self, name: str, value: Any) -> None: when setting attributes and is therefore not recommended. Please use the explicit calls (fast_setattr for setting regular values and module_setattr for setting parameters, children modules and buffers).""", - RuntimeWarning) + RuntimeWarning, + ) self.module_setattr(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: From 8560fcff8455c05bf919a6629cc86bf15b2bdfb4 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 16 Dec 2025 15:58:06 -0800 Subject: [PATCH 10/11] WAR the fact that it is not possible to set __setattr__ dynamically Signed-off-by: Przemek Tredak --- .../dot_product_attention/dot_product_attention.py | 2 +- transformer_engine/pytorch/module/base.py | 8 +++++++- transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 3 ++- 6 files changed, 13 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 5cae557ca78..b9ab967ec2b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -482,7 +482,7 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a9bd5e4e4e3..95c3acf78f3 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -638,7 +638,7 @@ def fast_setattr(self, name: str, value: Any) -> None: def module_setattr(self, name: str, value: Any) -> None: super().__setattr__(name, value) - def default_setattr(self, name: str, value: Any) -> None: + def _warning_setattr(self, name: str, value: Any) -> None: warnings.warn( """The default implementation of torch.nn.Module introduces significant CPU overhead when setting attributes and is therefore not recommended. Please use the explicit calls @@ -648,6 +648,12 @@ def default_setattr(self, name: str, value: Any) -> None: ) self.module_setattr(name, value) + def _default_setattr(self, name: str, value: Any) -> None: + return self.module_setattr(name, value) + + def __setattr__(self, name: str, value: Any) -> None: + return self._default_setattr(name, value) + def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ Delayed scaling only. diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e064dba0af2..82d30c2cfe5 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -710,7 +710,7 @@ def __init__( if name in (f"weight{i}", f"bias{i}"): param.skip_backward_post_hook = True - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 481a3dca863..ae158a6383a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1405,7 +1405,7 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 882734595b2..dc9513b8697 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1960,7 +1960,7 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c13975e90cd..4c81e25154a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1308,7 +1308,8 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" From b09823ed690304c4af87689f4118782088a516ad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Dec 2025 00:00:07 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 4c81e25154a..dd1d6b3846d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1310,7 +1310,6 @@ def __init__( self._default_setattr = self._warning_setattr - def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe)