diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index e0f104bdb..c4d4f7940 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -10,7 +10,7 @@ on:
   pull_request:
 
 env:
-  PRIMUS_TURBO_COMMIT: 935a6314b1e6f57894c198a0b8def3053a8e903f # feat(mxfp4): Add mxfp4 gemm and fix scale calculate of mxfp8 (#179)
+  PRIMUS_TURBO_COMMIT: 5233748e9c5c5795a6484ab31ece47c442d29ec2 # feat(mxfp4): refactor gemm mxfp4 and mxfp8. fuse transpose, hadamard transform and quantization. (#195)
 
 jobs:
   code-lint:
diff --git a/primus/backends/megatron/core/enums.py b/primus/backends/megatron/core/enums.py
new file mode 100644
index 000000000..1dd3113a8
--- /dev/null
+++ b/primus/backends/megatron/core/enums.py
@@ -0,0 +1,14 @@
+###############################################################################
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+###############################################################################
+
+import enum
+
+
+class Fp4Recipe(str, enum.Enum):
+    """FP4 recipe names: nvfp4, mxfp4."""
+
+    nvfp4 = "nvfp4"
+    mxfp4 = "mxfp4"
diff --git a/primus/backends/megatron/core/extensions/primus_turbo.py b/primus/backends/megatron/core/extensions/primus_turbo.py
index d3953aa0d..620f00724 100644
--- a/primus/backends/megatron/core/extensions/primus_turbo.py
+++ b/primus/backends/megatron/core/extensions/primus_turbo.py
@@ -31,11 +31,14 @@
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
 from megatron.training.global_vars import get_args
 from primus_turbo.pytorch.core.low_precision import (
+    Float4QuantConfig,
     Float8QuantConfig,
     Format,
+    ScaleDtype,
     ScalingGranularity,
     ScalingStrategy,
     check_fp8_support,
+    check_mxfp4_support,
     check_mxfp8_support,
 )
 from torch import Tensor
@@ -61,39 +64,97 @@ def use_split_wgrad_op():
     return False
 
 
-class PrimusTurboFloat8QuantConfig(Float8QuantConfig):
+class PrimusTurboQuantConfig:
+
+    def __init__(
+        self,
+        format: Format = Format.E4M3,
+        granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+        strategy: ScalingStrategy = ScalingStrategy.DYNAMIC,
+        scale_dtype: ScaleDtype = ScaleDtype.FP32,
+        block_size: int = None,
+    ):
+        self._is_fp4 = False
+        self._is_fp8 = False
+        if format == Format.E2M1_X2:
+            # FP4
+            self._quant_config = Float4QuantConfig(
+                format=format,
+                granularity=granularity,
+                strategy=strategy,
+                scale_dtype=scale_dtype,
+                block_size=block_size,
+            )
+            self._is_fp4 = True
+        else:
+            # FP8
+            self._quant_config = Float8QuantConfig(
+                format=format,
+                granularity=granularity,
+                strategy=strategy,
+                scale_dtype=scale_dtype,
+                block_size=block_size,
+            )
+            self._is_fp8 = True
+
+    def is_fp4(self):
+        return self._is_fp4
+
+    def is_fp8(self):
+        return self._is_fp8
 
     def block_scaling(self):
-        return self.granularity == ScalingGranularity.BLOCKWISE and self.strategy == ScalingStrategy.DYNAMIC
+        return (
+            self._quant_config.granularity == ScalingGranularity.BLOCKWISE
+            and self._quant_config.strategy == ScalingStrategy.DYNAMIC
+        )
 
     def current_scaling(self):
-        return self.granularity == ScalingGranularity.TENSORWISE and self.strategy == ScalingStrategy.DYNAMIC
+        return (
+            self._quant_config.granularity == ScalingGranularity.TENSORWISE
+            and self._quant_config.strategy == ScalingStrategy.DYNAMIC
+        )
 
     def mxfp8_scaling(self):
         # NOTE: The mxfp8 recipe only support e4m3 format in megatron-lm backend.
         return (
-            self.granularity == ScalingGranularity.MX_BLOCKWISE
-            and self.strategy == ScalingStrategy.DYNAMIC
-            and self.format == Format.E4M3
+            self._quant_config.granularity == ScalingGranularity.MX_BLOCKWISE
+            and self._quant_config.strategy == ScalingStrategy.DYNAMIC
+            and self._quant_config.format == Format.E4M3
         )
 
+    def mxfp4_scaling(self):
+        return (
+            self._quant_config.granularity == ScalingGranularity.MX_BLOCKWISE
+            and self._quant_config.strategy == ScalingStrategy.DYNAMIC
+            and self._quant_config.format == Format.E2M1_X2
+            and self._quant_config.scale_dtype == ScaleDtype.E8M0
+        )
 
-class PrimusTurboFP8GlobalStateManager(FP8GlobalStateManager):
-    PRIMUS_TURBO_FP8_QUANT_CONFIG: PrimusTurboFloat8QuantConfig = None
+
+class PrimusTurboLowPrecisionGlobalStateManager(FP8GlobalStateManager):
+    PRIMUS_TURBO_QUANT_CONFIG: PrimusTurboQuantConfig = None
     PRIMUS_TURBO_FP8_ENABLED: bool = False
+    PRIMUS_TURBO_FP4_ENABLED: bool = False
 
     @classmethod
     def is_turbo_fp8_enabled(cls) -> bool:
         """Is FP8 enabled"""
         return cls.PRIMUS_TURBO_FP8_ENABLED
 
+    @classmethod
+    def is_turbo_fp4_enabled(cls) -> bool:
+        """Is FP4 enabled"""
+        return cls.PRIMUS_TURBO_FP4_ENABLED
+
     @classmethod
     def reset(cls) -> None:
         """Reset the global state"""
         FP8GlobalStateManager.reset()
         cls.PRIMUS_TURBO_FP8_ENABLED = False
-        cls.PRIMUS_TURBO_FP8_QUANT_CONFIG = None
+        cls.PRIMUS_TURBO_FP4_ENABLED = False
+        cls.PRIMUS_TURBO_QUANT_CONFIG = None
 
     @classmethod
     def fp8_autocast_enter(
@@ -104,7 +165,7 @@ def fp8_autocast_enter(
         fp8_group: Optional[dist_group_type] = None,
         _graph: bool = False,
         enabled_turbo: bool = False,
-        turbo_fp8_quant_config: Optional[PrimusTurboFloat8QuantConfig] = None,
+        turbo_quant_config: Optional[PrimusTurboQuantConfig] = None,
     ) -> None:
         FP8GlobalStateManager.fp8_autocast_enter(
             enabled=enabled,
@@ -114,29 +175,32 @@ def fp8_autocast_enter(
             _graph=_graph,
         )
 
-        turbo_fp8_quant_config = (
-            PrimusTurboFloat8QuantConfig() if turbo_fp8_quant_config is None else turbo_fp8_quant_config
-        )
+        # Default is fp8 tensorwise
+        turbo_quant_config = PrimusTurboQuantConfig() if turbo_quant_config is None else turbo_quant_config
 
-        cls.PRIMUS_TURBO_FP8_ENABLED = enabled_turbo
-        cls.PRIMUS_TURBO_FP8_QUANT_CONFIG = turbo_fp8_quant_config
+        cls.PRIMUS_TURBO_FP8_ENABLED = enabled_turbo and turbo_quant_config.is_fp8()
+        cls.PRIMUS_TURBO_FP4_ENABLED = enabled_turbo and turbo_quant_config.is_fp4()
+        cls.PRIMUS_TURBO_QUANT_CONFIG = turbo_quant_config
 
         if enabled_turbo:
             fp8_available, reason_for_no_fp8 = check_fp8_support()
             assert fp8_available, reason_for_no_fp8
-            if turbo_fp8_quant_config.mxfp8_scaling():
+            if turbo_quant_config.mxfp8_scaling():
                 mxfp8_available, reason_for_no_mxfp8 = check_mxfp8_support()
                 assert mxfp8_available, reason_for_no_mxfp8
+            if turbo_quant_config.mxfp4_scaling():
+                mxfp4_available, reason_for_no_mxfp4 = check_mxfp4_support()
+                assert mxfp4_available, reason_for_no_mxfp4
 
     @classmethod
-    def get_turbo_fp8_quant_config(cls) -> PrimusTurboFloat8QuantConfig:
-        """Return the turbo's fp8 quant_config"""
-        return cls.PRIMUS_TURBO_FP8_QUANT_CONFIG
+    def get_turbo_quant_config(cls) -> PrimusTurboQuantConfig:
+        """Return the turbo's quant_config"""
+        return cls.PRIMUS_TURBO_QUANT_CONFIG
 
     @classmethod
     def get_fp8_autocast_state(
         cls,
-    ) -> Tuple[bool, bool, Recipe, dist_group_type, bool, bool, PrimusTurboFloat8QuantConfig]:
+    ) -> Tuple[bool, bool, Recipe, dist_group_type, bool, bool, bool, bool, PrimusTurboQuantConfig]:
         """FP8 autocast state getter"""
         return (
             cls.FP8_ENABLED,
@@ -146,15 +210,14 @@
             cls.IS_FIRST_FP8_MODULE,
             cls.FP8_GRAPH_CAPTURING,
             cls.PRIMUS_TURBO_FP8_ENABLED,
-            cls.PRIMUS_TURBO_FP8_QUANT_CONFIG,
+            cls.PRIMUS_TURBO_FP4_ENABLED,
+            cls.PRIMUS_TURBO_QUANT_CONFIG,
         )
 
     @classmethod
     def set_fp8_autocast_state(
         cls,
-        fp8_state: Tuple[
-            bool, bool, DelayedScaling, dist_group_type, bool, bool, PrimusTurboFloat8QuantConfig
-        ],
+        fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool, bool, bool, bool, PrimusTurboQuantConfig],
     ) -> None:
         """FP8 autocast state setter"""
         (
@@ -165,7 +228,8 @@ def set_fp8_autocast_state(
             cls.IS_FIRST_FP8_MODULE,
             cls.FP8_GRAPH_CAPTURING,
             cls.PRIMUS_TURBO_FP8_ENABLED,
-            cls.PRIMUS_TURBO_FP8_QUANT_CONFIG,
+            cls.PRIMUS_TURBO_FP4_ENABLED,
+            cls.PRIMUS_TURBO_QUANT_CONFIG,
         ) = fp8_state
 
 
@@ -177,23 +241,23 @@ def primus_turbo_fp8_autocast(
     fp8_group: Optional[dist_group_type] = None,
     _graph: bool = False,
     enabled_turbo: bool = False,
-    turbo_fp8_quant_config: Optional[PrimusTurboFloat8QuantConfig] = None,
+    turbo_quant_config: Optional[PrimusTurboQuantConfig] = None,
 ) -> None:  # type: ignore
-    fp8_state = PrimusTurboFP8GlobalStateManager.get_fp8_autocast_state()
-    PrimusTurboFP8GlobalStateManager.fp8_autocast_enter(
+    fp8_state = PrimusTurboLowPrecisionGlobalStateManager.get_fp8_autocast_state()
+    PrimusTurboLowPrecisionGlobalStateManager.fp8_autocast_enter(
         enabled=enabled,
         calibrating=calibrating,
        fp8_recipe=fp8_recipe,
         fp8_group=fp8_group,
         _graph=_graph,
         enabled_turbo=enabled_turbo,
-        turbo_fp8_quant_config=turbo_fp8_quant_config,
+        turbo_quant_config=turbo_quant_config,
    )
     try:
         yield
     finally:
-        PrimusTurboFP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
-        PrimusTurboFP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
+        PrimusTurboLowPrecisionGlobalStateManager.set_fp8_autocast_state(fp8_state)
+        PrimusTurboLowPrecisionGlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
 
 
 class PrimusTurboAttention(te.pytorch.DotProductAttention):
@@ -419,8 +483,8 @@ def forward(
         input_ = input_.contiguous()
         input_ = input_.view(-1, original_shape[-1])
 
-        if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
-            quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+        if PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp8_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
             if quant_config.block_scaling():
                 fp8_gemm = pt.ops.gemm_fp8_blockwise
             elif quant_config.current_scaling() or quant_config.mxfp8_scaling():
@@ -429,6 +493,14 @@
                 raise ValueError("Not support quant config.")
 
             out = fp8_gemm(input_, weights, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
+        elif PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp4_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
+            if quant_config.mxfp4_scaling():
+                fp4_gemm = pt.ops.gemm_fp4
+            else:
+                raise ValueError("Unsupported quant config.")
+
+            out = fp4_gemm(input_, weights, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
         else:
             out = self.gemm(input_, weights)
 
@@ -519,8 +591,8 @@ def forward(
         input_ = input_.contiguous()
         input_ = input_.view(-1, original_shape[-1])
 
-        if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
-            quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+        if PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp8_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
             if quant_config.block_scaling():
                 fp8_gemm = pt.ops.gemm_fp8_blockwise
             elif quant_config.current_scaling() or quant_config.mxfp8_scaling():
@@ -529,6 +601,14 @@ def forward(
                 raise ValueError("Not support quant config.")
 
             out = fp8_gemm(input_, weights, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
+        elif PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp4_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
+            if quant_config.mxfp4_scaling():
+                fp4_gemm = pt.ops.gemm_fp4
+            else:
+                raise ValueError("Unsupported quant config.")
+
+            out = fp4_gemm(input_, weights, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
         else:
             out = self.gemm(input_, weights)
 
@@ -615,8 +695,8 @@ def forward(
         input_ = input_.contiguous()
         input_ = input_.view(-1, original_shape[-1])
 
-        if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
-            quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+        if PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp8_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
             if quant_config.block_scaling():
                 fp8_gemm = pt.ops.gemm_fp8_blockwise
             elif quant_config.current_scaling() or quant_config.mxfp8_scaling():
@@ -625,6 +705,14 @@
                 raise ValueError("Not support quant config.")
 
             out = fp8_gemm(input_, weights, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
+        elif PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp4_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
+            if quant_config.mxfp4_scaling():
+                fp4_gemm = pt.ops.gemm_fp4
+            else:
+                raise ValueError("Unsupported quant config.")
+
+            out = fp4_gemm(input_, weights, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
         else:
             out = self.gemm(input_, weights)
 
@@ -706,8 +794,8 @@ def forward(
         input_ = input_.contiguous()
         input_ = input_.view(-1, original_shape[-1])
 
-        if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
-            quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+        if PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp8_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
             if quant_config.block_scaling():
                 fp8_gemm = pt.ops.gemm_fp8_blockwise
             elif quant_config.current_scaling() or quant_config.mxfp8_scaling():
@@ -716,6 +804,14 @@
                 raise ValueError("Not support quant config.")
 
             out = fp8_gemm(input_, weight, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
+        elif PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp4_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
+            if quant_config.mxfp4_scaling():
+                fp4_gemm = pt.ops.gemm_fp4
+            else:
+                raise ValueError("Unsupported quant config.")
+
+            out = fp4_gemm(input_, weight, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
         else:
             out = self.gemm(input_, weight)
         out = out.view(original_shape[0], original_shape[1], -1)
@@ -829,8 +925,8 @@ def forward(self, x):
         norm_out = norm_out.contiguous()
         inp = norm_out.view(-1, original_shape[-1])
 
-        if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
-            quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+        if PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp8_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
             if quant_config.block_scaling():
                 fp8_gemm = pt.ops.gemm_fp8_blockwise
             elif quant_config.current_scaling() or quant_config.mxfp8_scaling():
@@ -839,6 +935,14 @@ def forward(self, x):
                 raise ValueError("Not support quant config.")
 
             out = fp8_gemm(inp, weights, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
+        elif PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp4_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
+            if quant_config.mxfp4_scaling():
+                fp4_gemm = pt.ops.gemm_fp4
+            else:
+                raise ValueError("Unsupported quant config.")
+
+            out = fp4_gemm(inp, weights, trans_a=False, trans_b=True, out_dtype=None, config=quant_config)
         else:
             out = self.gemm(inp, weights)
 
@@ -930,8 +1034,8 @@ def forward(
         tokens_per_expert = tokens_per_expert.to(w1.device)
         assert w1.is_contiguous(), "w1 must be contiguous"
         assert w2.is_contiguous(), "w2 must be contiguous"
-        if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
-            quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+        if PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp8_enabled():
+            quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
             fc1_output = pt.ops.grouped_gemm_fp8(
                 permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False, config=quant_config
             )
@@ -951,8 +1055,8 @@
                intermediate_parallel = self.activation_checkpoint.checkpoint(
                    self.activation_func_with_probs, fc1_output, permuted_probs.unsqueeze(-1)
                )
-                if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
-                    quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+                if PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp8_enabled():
+                    quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
                     fc2_output = pt.ops.grouped_gemm_fp8(
                         intermediate_parallel, w2, tokens_per_expert, trans_b=False, config=quant_config
                     )
@@ -970,8 +1074,8 @@
                intermediate_parallel = self.activation_func_with_probs(
                    fc1_output, permuted_probs.unsqueeze(-1)
                )
-                if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
-                    quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+                if PrimusTurboLowPrecisionGlobalStateManager.is_turbo_fp8_enabled():
+                    quant_config = PrimusTurboLowPrecisionGlobalStateManager.get_turbo_quant_config()
                     fc2_output = pt.ops.grouped_gemm_fp8(
                         intermediate_parallel, w2, tokens_per_expert, trans_b=False, config=quant_config
                     )
diff --git a/primus/backends/megatron/core/fp4_utils.py b/primus/backends/megatron/core/fp4_utils.py
new file mode 100644
index 000000000..1a6336948
--- /dev/null
+++ b/primus/backends/megatron/core/fp4_utils.py
@@ -0,0 +1,209 @@
+###############################################################################
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+###############################################################################
+
+
+"""Utility functions related to FP4 that are used throughout Megatron core"""
+
+from contextlib import nullcontext
+
+from megatron.core import parallel_state
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.utils import is_te_min_version
+
+from primus.backends.megatron.core.enums import Fp4Recipe
+from primus.modules.module_utils import warning_rank_0
+
+# Check if Transformer Engine is installed
+HAVE_TE = False
+try:
+    import transformer_engine  # pylint: disable=W0611
+
+    HAVE_TE = True
+except (ImportError, ModuleNotFoundError):
+    # Transformer Engine not found
+    pass
+
+# Check if Primus-Turbo is installed
+HAVE_TURBO = False
+try:
+    import primus_turbo  # pylint: disable=W0611
+
+    HAVE_TURBO = True
+except (ImportError, ModuleNotFoundError):
+    # Primus-Turbo not found
+    pass
+
+
+WARN_ONCE = True
+
+
+if HAVE_TE and HAVE_TURBO:
+    from primus_turbo.pytorch.core.low_precision import (
+        Format,
+        ScaleDtype,
+        ScalingGranularity,
+    )
+
+    from primus.backends.megatron.core.extensions.primus_turbo import (
+        PrimusTurboQuantConfig,
+    )
+
+    def get_fp4_recipe(config: TransformerConfig):
+        """Return fp4 recipe."""
+        fp4_recipe = None
+        fp4_recipe_none_reason = ""
+        if is_te_min_version("2.7.0.dev0"):
+            if config.fp4_recipe == Fp4Recipe.nvfp4:
+                try:
+                    fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+                except AttributeError:
+                    fp4_recipe_none_reason = "NVFP4BlockScaling recipe is not available in this version of Transformer Engine. Please make sure you are using TE version >= 2.7.0.dev0."
+            else:
+                fp4_recipe_none_reason = "NVFP4BlockScaling is the only supported FP4 recipe. Please make sure you are using a compatible TE version >= 2.7.0.dev0."
+        else:
+            fp4_recipe_none_reason = (
+                "FP4 support requires TransformerEngine version >= 2.7.0.dev0 for NVFP4BlockScaling."
+            )
+
+        return fp4_recipe, fp4_recipe_none_reason
+
+    def get_fp4_quant_config(config: TransformerConfig):
+        """Return fp4 quant config."""
+        fp4_quant_config = None
+        fp4_quant_config_none_reason = ""
+        if config.fp4_recipe == Fp4Recipe.mxfp4:
+            fp4_quant_config = PrimusTurboQuantConfig(
+                granularity=ScalingGranularity.MX_BLOCKWISE,
+                format=Format.E2M1_X2,
+                block_size=32,
+                scale_dtype=ScaleDtype.E8M0,
+            )
+        else:
+            fp4_quant_config_none_reason = "Only MXFP4 is supported in Primus-Turbo."
+
+        return fp4_quant_config, fp4_quant_config_none_reason
+
+    def get_fp4_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
+        """Return fp4 context manager."""
+        num_bf16_layers_at_start = config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0
+        num_bf16_layers_at_end = config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0
+        is_first_layer = layer_no < num_bf16_layers_at_start
+        is_last_layer = layer_no >= config.num_layers - num_bf16_layers_at_end
+
+        need_fp4_context = config.fp4 if not is_init else config.fp4_param
+
+        if not need_fp4_context:
+            fp4_context = nullcontext()
+        elif layer_no >= 0 and config.first_last_layers_bf16 and (is_first_layer or is_last_layer):
+            fp4_context = nullcontext()
+        else:
+            fp4_recipe, fp4_recipe_none_reason = get_fp4_recipe(config)
+            fp4_quant_config, fp4_quant_config_none_reason = get_fp4_quant_config(config)
+
+            global WARN_ONCE
+            if WARN_ONCE:
+                if fp4_recipe is None:
+                    warning_rank_0(
+                        f"TransformerEngine FP4 {config.fp4_recipe} does not work since {fp4_recipe_none_reason}"
+                    )
+                if fp4_quant_config is None:
+                    warning_rank_0(
+                        f"Primus-Turbo FP4 {config.fp4_recipe} does not work since {fp4_quant_config_none_reason}"
+                    )
+                WARN_ONCE = False
+
+            fp4_group = None
+            if parallel_state.model_parallel_is_initialized():
+                fp4_group = parallel_state.get_amax_reduction_group(
+                    with_context_parallel=True, tp_only_amax_red=config.tp_only_amax_red
+                )
+
+            if not is_init:
+                # TE currently uses fp8_autocast for fp8 and fp4 quantization.
+                fp4_context = transformer_engine.pytorch.fp8_autocast(
+                    enabled=True, fp8_recipe=fp4_recipe, fp8_group=fp4_group
+                )
+            else:
+                import inspect
+
+                context_args = {"enabled": True}
+                if "recipe" in inspect.signature(transformer_engine.pytorch.fp8_model_init).parameters:
+                    context_args["recipe"] = fp4_recipe
+                fp4_context = transformer_engine.pytorch.fp8_model_init(**context_args)
+
+        return fp4_context
+
+elif HAVE_TE:
+
+    def get_fp4_recipe(config: TransformerConfig):
+        """Return fp4 recipe."""
+        if is_te_min_version("2.7.0.dev0"):
+            if config.fp4_recipe == Fp4Recipe.nvfp4:
+                try:
+                    fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+                except AttributeError:
+                    raise ValueError(
+                        """NVFP4BlockScaling recipe is not available in this version of
+                        Transformer Engine. Please make sure you are using TE version
+                        >= 2.7.0.dev0."""
+                    )
+            else:
+                raise ValueError(
+                    "NVFP4BlockScaling is the only supported FP4 recipe. "
+                    "Please make sure you are using a compatible TE version >= 2.7.0.dev0."
+                )
+        else:
+            raise ValueError(
+                """FP4 support requires TransformerEngine version >= 2.7.0.dev0
+                for NVFP4BlockScaling."""
+            )
+        return fp4_recipe
+
+    def get_fp4_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
+        """Return fp4 context manager."""
+        num_bf16_layers_at_start = config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0
+        num_bf16_layers_at_end = config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0
+        is_first_layer = layer_no < num_bf16_layers_at_start
+        is_last_layer = layer_no >= config.num_layers - num_bf16_layers_at_end
+
+        need_fp4_context = config.fp4 if not is_init else config.fp4_param
+
+        if not need_fp4_context:
+            fp4_context = nullcontext()
+        elif layer_no >= 0 and config.first_last_layers_bf16 and (is_first_layer or is_last_layer):
+            fp4_context = nullcontext()
+        else:
+            fp4_recipe = get_fp4_recipe(config)
+            fp4_group = None
+            if parallel_state.model_parallel_is_initialized():
+                fp4_group = parallel_state.get_amax_reduction_group(
+                    with_context_parallel=True, tp_only_amax_red=config.tp_only_amax_red
+                )
+
+            if not is_init:
+                # TE currently uses fp8_autocast for fp8 and fp4 quantization.
+                fp4_context = transformer_engine.pytorch.fp8_autocast(
+                    enabled=True, fp8_recipe=fp4_recipe, fp8_group=fp4_group
+                )
+            else:
+                import inspect
+
+                context_args = {"enabled": True}
+                if "recipe" in inspect.signature(transformer_engine.pytorch.fp8_model_init).parameters:
+                    context_args["recipe"] = fp4_recipe
+                fp4_context = transformer_engine.pytorch.fp8_model_init(**context_args)
+
+        return fp4_context
+
+else:
+
+    def get_fp4_recipe(config: TransformerConfig):
+        """Return None when Transformer Engine is not available."""
+        return None
+
+    def get_fp4_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
+        """Return nullcontext when Transformer Engine is not available."""
+        return nullcontext()
diff --git a/primus/backends/megatron/core/fp8_utils.py b/primus/backends/megatron/core/fp8_utils.py
index 4e357a326..5058359d0 100644
--- a/primus/backends/megatron/core/fp8_utils.py
+++ b/primus/backends/megatron/core/fp8_utils.py
@@ -7,6 +7,8 @@
 """Utility functions related to FP8 that are used throughout Megatron core"""
 
 from contextlib import nullcontext
 
+import torch
+from megatron.core.fp8_utils import is_first_last_bf16_layer
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.utils import is_te_min_version
@@ -46,7 +48,7 @@
     from primus_turbo.pytorch.core.low_precision import ScaleDtype, ScalingGranularity
 
     from primus.backends.megatron.core.extensions.primus_turbo import (
-        PrimusTurboFloat8QuantConfig,
+        PrimusTurboQuantConfig,
     )
 
     def te_fp8_format_mapping(te_format):
@@ -61,6 +63,98 @@
 
         return format_mapping[te_format]
 
+    def get_fp8_recipe(config: TransformerConfig):
+        """Return fp8 recipe.
+
+        Arguments:
+            config (TransformerConfig): Configuration object.
+
+        Returns:
+            FP8 recipe: Transformer Engine FP8 recipe.
+            FP8 None reason: reason why the fp8 recipe is None.
+        """
+        if config.fp8 == "e4m3":
+            fp8_format = transformer_engine.common.recipe.Format.E4M3
+        elif config.fp8 == "hybrid":
+            fp8_format = transformer_engine.common.recipe.Format.HYBRID
+        else:
+            raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
+
+        # Select fp8 recipe (TE version >= 2.1.0).
+        fp8_recipe = None
+        fp8_recipe_none_reason = ""
+        if is_te_min_version("2.1.0"):
+            if config.fp8_recipe == Fp8Recipe.delayed:
+                fp8_recipe = TEDelayedScaling(
+                    config=config,
+                    fp8_format=fp8_format,
+                    override_linear_precision=(False, False, not config.fp8_wgrad),
+                )
+            elif config.fp8_recipe == Fp8Recipe.tensorwise and is_te_min_version("2.2.0.dev0"):
+                fp8_recipe = transformer_engine.common.recipe.Float8CurrentScaling(
+                    fp8_format=fp8_format, fp8_dpa=config.fp8_dot_product_attention
+                )
+            elif config.fp8_recipe == Fp8Recipe.blockwise and is_te_min_version("2.3.0.dev0"):
+                fp8_recipe = transformer_engine.common.recipe.Float8BlockScaling(fp8_format=fp8_format)
+            elif config.fp8_recipe == Fp8Recipe.mxfp8:
+                fp8_recipe = transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=fp8_format)
+            else:
+                fp8_recipe_none_reason = "Float8CurrentScaling, MXFP8BlockScaling, Float8BlockScaling and DelayedScaling are the only supported FP8 recipes. Please also make sure you are using a compatible TE version."
+        else:
+            # Assert that the user is using delayed scaling.
+            if config.fp8_recipe == Fp8Recipe.delayed:
+                fp8_recipe = TEDelayedScaling(
+                    config=config,
+                    fp8_format=fp8_format,
+                    override_linear_precision=(False, False, not config.fp8_wgrad),
+                )
+            else:
+                fp8_recipe_none_reason = "Please make sure to use TransformerEngine version >= 2.2.0.dev0 for Float8CurrentScaling, >= 2.1.0 for MXFP8BlockScaling, and >= 2.3.0.dev0 for Float8BlockScaling."
+
+        return fp8_recipe, fp8_recipe_none_reason
+
+    def get_fp8_quant_config(config: TransformerConfig):
+        """Return fp8 quant config.
+
+        Arguments:
+            config (TransformerConfig): Configuration object.
+
+        Returns:
+            FP8 quant config: Primus-Turbo FP8 quant config.
+            FP8 quant config none reason: reason why the fp8 quant config is None.
+        """
+        if config.fp8 == "e4m3":
+            fp8_format = transformer_engine.common.recipe.Format.E4M3
+        elif config.fp8 == "hybrid":
+            fp8_format = transformer_engine.common.recipe.Format.HYBRID
+        else:
+            raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
+
+        fp8_quant_config = None
+        fp8_quant_config_none_reason = ""
+        if config.fp8_recipe == Fp8Recipe.delayed:
+            # NOTE: Primus-Turbo does not support delayed scaling.
+            fp8_quant_config_none_reason = "Primus-Turbo does not support delayed scaling."
+        elif config.fp8_recipe == Fp8Recipe.tensorwise:
+            fp8_quant_config = PrimusTurboQuantConfig(
+                granularity=ScalingGranularity.TENSORWISE, format=te_fp8_format_mapping(fp8_format)
+            )
+        elif config.fp8_recipe == Fp8Recipe.blockwise:
+            fp8_quant_config = PrimusTurboQuantConfig(
+                granularity=ScalingGranularity.BLOCKWISE,
+                format=te_fp8_format_mapping(fp8_format),
+                block_size=SCALING_BLOCK_SIZE,
+            )
+        elif config.fp8_recipe == Fp8Recipe.mxfp8:
+            fp8_quant_config = PrimusTurboQuantConfig(
+                granularity=ScalingGranularity.MX_BLOCKWISE,
+                format=te_fp8_format_mapping(fp8_format),
+                block_size=MX_SCALING_BLOCK_SIZE,
+                scale_dtype=ScaleDtype.E8M0,
+            )
+
+        return fp8_quant_config, fp8_quant_config_none_reason
+
     def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
         """Return fp8 context manager.
 
@@ -76,83 +170,26 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool
         We return nullcontext() when:
             a) not using fp8 to train,
             b) layer_no is a layer that needs to be trained in bf16.
""" - num_bf16_layers_at_start = config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0 - num_bf16_layers_at_end = config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0 - # Since layer_no is a global layer index, additional checks on whether - # we are in the first or last pipeline-parallel rank are not needed. - is_first_layer = layer_no < num_bf16_layers_at_start - is_last_layer = layer_no >= config.num_layers - num_bf16_layers_at_end need_fp8_context = config.fp8 if not is_init else config.fp8_param - if not need_fp8_context: - # bf16 training - fp8_context = nullcontext() - elif layer_no >= 0 and config.first_last_layers_bf16 and (is_first_layer or is_last_layer): - # fp8 training but this layer_no should be bf16 + if not need_fp8_context or is_first_last_bf16_layer(config, layer_no): + # bf16 training or bf16 layer in fp8 training fp8_context = nullcontext() else: # fp8 training and this layer_no is in fp8 - import transformer_engine # To keep out TE dependency when not training in fp8 - - if config.fp8 == "e4m3": - fp8_format = transformer_engine.common.recipe.Format.E4M3 - elif config.fp8 == "hybrid": - fp8_format = transformer_engine.common.recipe.Format.HYBRID - else: - raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") - - # Select TE fp8 recipe and turbo fp8 quant config - fp8_recipe, fp8_recipe_none_reason = None, "" - fp8_quant_config, fp8_quant_config_none_reason = None, "" - if config.fp8_recipe == Fp8Recipe.delayed: - fp8_recipe = TEDelayedScaling( - config=config, - fp8_format=fp8_format, - override_linear_precision=(False, False, not config.fp8_wgrad), - ) - # NOTE: Primus-Turbo not support delayed scaling. - fp8_quant_config_none_reason = "Primus-Turbo not support delayed scaling." - elif config.fp8_recipe == Fp8Recipe.tensorwise: - if is_te_min_version("2.2.0.dev0"): - fp8_recipe = transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=fp8_format) - else: - fp8_recipe_none_reason = "Transformer Engine version < 2.2.0.dev0." - fp8_quant_config = PrimusTurboFloat8QuantConfig( - granularity=ScalingGranularity.TENSORWISE, format=te_fp8_format_mapping(fp8_format) - ) - elif config.fp8_recipe == Fp8Recipe.blockwise: - if is_te_min_version("2.3.0.dev0"): - fp8_recipe = transformer_engine.common.recipe.Float8BlockScaling(fp8_format=fp8_format) - else: - fp8_recipe_none_reason = "Transformer Engine version < 2.3.0.dev0." - fp8_quant_config = PrimusTurboFloat8QuantConfig( - granularity=ScalingGranularity.BLOCKWISE, - format=te_fp8_format_mapping(fp8_format), - block_size=SCALING_BLOCK_SIZE, - ) - elif config.fp8_recipe == Fp8Recipe.mxfp8: - if is_te_min_version("2.1.0.dev0"): - fp8_recipe = transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=fp8_format) - else: - fp8_recipe_none_reason = "Transformer Engine version < 2.1.0.dev0" - fp8_quant_config = PrimusTurboFloat8QuantConfig( - granularity=ScalingGranularity.MX_BLOCKWISE, - format=te_fp8_format_mapping(fp8_format), - block_size=MX_SCALING_BLOCK_SIZE, - scale_dtype=ScaleDtype.E8M0, - ) + fp8_recipe, fp8_recipe_none_reason = get_fp8_recipe(config) + fp8_quant_config, fp8_quant_config_none_reason = get_fp8_quant_config(config) global WARN_ONCE if WARN_ONCE: if fp8_recipe is None: warning_rank_0( - f"WARNING: TransformerEngine FP8 {config.fp8_recipe} not work since {fp8_recipe_none_reason}." 
+ f"TransformerEngine FP8 {config.fp8_recipe} not work since {fp8_recipe_none_reason}" ) - if fp8_quant_config is None: warning_rank_0( - f"WARNING: Primus-Turbo FP8 {config.fp8_recipe} not work since {fp8_quant_config_none_reason}." + f"Primus-Turbo FP8 {config.fp8_recipe} not work since {fp8_quant_config_none_reason}" ) WARN_ONCE = False @@ -185,7 +222,7 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool if "preserve_high_precision_init_val" in ( inspect.signature(transformer_engine.pytorch.fp8_model_init).parameters ): - context_args["preserve_high_precision_init_val"] = True + context_args["preserve_high_precision_init_val"] = torch.is_grad_enabled() fp8_context = transformer_engine.pytorch.fp8_model_init(**context_args) # First / last layer in bf16 isn't supported with delayed scaling since it @@ -202,6 +239,59 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool from megatron.core.enums import Fp8Recipe from megatron.core.extensions.transformer_engine import TEDelayedScaling + def get_fp8_recipe(config: TransformerConfig): + """Return fp8 recipe. + + Arguments: + config (TransformerConfig): Configuration object. + + Returns: + FP8 recipe. + """ + if config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + # Select fp8 recipe (TE version >= 2.1.0). + fp8_recipe = None + if is_te_min_version("2.1.0"): + if config.fp8_recipe == Fp8Recipe.delayed: + fp8_recipe = TEDelayedScaling( + config=config, + fp8_format=fp8_format, + override_linear_precision=(False, False, not config.fp8_wgrad), + ) + elif config.fp8_recipe == Fp8Recipe.tensorwise and is_te_min_version("2.2.0.dev0"): + fp8_recipe = transformer_engine.common.recipe.Float8CurrentScaling( + fp8_format=fp8_format, fp8_dpa=config.fp8_dot_product_attention + ) + elif config.fp8_recipe == Fp8Recipe.blockwise and is_te_min_version("2.3.0.dev0"): + fp8_recipe = transformer_engine.common.recipe.Float8BlockScaling(fp8_format=fp8_format) + elif config.fp8_recipe == Fp8Recipe.mxfp8: + fp8_recipe = transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=fp8_format) + else: + raise ValueError( + "Float8CurrentScaling, MXFP8BlockScaling, Float8BlockwiseScaling and " + "DelayedScaling are the only supported FP8 recipes. Please also make sure " + "you are using a compatible TE version." + ) + else: + # Assert that the user is using delayed scaling. + assert config.fp8_recipe == Fp8Recipe.delayed, ( + "Please make sure to use TransformerEngine version >= 2.2.0.dev0 for " + "Float8CurrentScaling, >= 2.1.0 for MXFP8BlockScaling, and >= 2.3.0.dev0 for " + "Float8BlockScaling." + ) + fp8_recipe = TEDelayedScaling( + config=config, + fp8_format=fp8_format, + override_linear_precision=(False, False, not config.fp8_wgrad), + ) + return fp8_recipe + def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False): """Return fp8 context manager. @@ -217,65 +307,15 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool We return nullcontext() when: a) not using fp8 to train, b) layer_no is a layer that needs to be trained in bf16. 
""" - num_bf16_layers_at_start = config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0 - num_bf16_layers_at_end = config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0 - # Since layer_no is a global layer index, additional checks on whether - # we are in the first or last pipeline-parallel rank are not needed. - is_first_layer = layer_no < num_bf16_layers_at_start - is_last_layer = layer_no >= config.num_layers - num_bf16_layers_at_end need_fp8_context = config.fp8 if not is_init else config.fp8_param - if not need_fp8_context: - # bf16 training - fp8_context = nullcontext() - elif layer_no >= 0 and config.first_last_layers_bf16 and (is_first_layer or is_last_layer): - # fp8 training but this layer_no should be bf16 + if not need_fp8_context or is_first_last_bf16_layer(config, layer_no): + # bf16 training or bf16 layer in fp8 training fp8_context = nullcontext() else: # fp8 training and this layer_no is in fp8 - import transformer_engine # To keep out TE dependency when not training in fp8 - - if config.fp8 == "e4m3": - fp8_format = transformer_engine.common.recipe.Format.E4M3 - elif config.fp8 == "hybrid": - fp8_format = transformer_engine.common.recipe.Format.HYBRID - else: - raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") - - # Select fp8 recipe (TE version >= 2.1.0). - fp8_recipe = None - if is_te_min_version("2.1.0"): - if config.fp8_recipe == Fp8Recipe.delayed: - fp8_recipe = TEDelayedScaling( - config=config, - fp8_format=fp8_format, - override_linear_precision=(False, False, not config.fp8_wgrad), - ) - elif config.fp8_recipe == Fp8Recipe.tensorwise and is_te_min_version("2.2.0.dev0"): - fp8_recipe = transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=fp8_format) - elif config.fp8_recipe == Fp8Recipe.blockwise and is_te_min_version("2.3.0.dev0"): - fp8_recipe = transformer_engine.common.recipe.Float8BlockScaling(fp8_format=fp8_format) - elif config.fp8_recipe == Fp8Recipe.mxfp8: - fp8_recipe = transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=fp8_format) - else: - raise ValueError( - "Float8CurrentScaling, MXFP8BlockScaling, Float8BlockwiseScaling and " - "DelayedScaling are the only supported FP8 recipes. Please also make sure " - "you are using a compatible TE version." - ) - else: - # Assert that the user is using delayed scaling. - assert config.fp8_recipe == Fp8Recipe.delayed, ( - "Please make sure to use TransformerEngine version >= 2.2.0.dev0 for " - "Float8CurrentScaling, >= 2.1.0 for MXFP8BlockScaling, and >= 2.3.0.dev0 for " - "Float8BlockScaling." 
-                )
-                fp8_recipe = TEDelayedScaling(
-                    config=config,
-                    fp8_format=fp8_format,
-                    override_linear_precision=(False, False, not config.fp8_wgrad),
-                )
+            fp8_recipe = get_fp8_recipe(config)
 
             fp8_group = None
             if parallel_state.model_parallel_is_initialized():
@@ -298,7 +338,7 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool
             if "preserve_high_precision_init_val" in (
                 inspect.signature(transformer_engine.pytorch.fp8_model_init).parameters
             ):
-                context_args["preserve_high_precision_init_val"] = True
+                context_args["preserve_high_precision_init_val"] = torch.is_grad_enabled()
             fp8_context = transformer_engine.pytorch.fp8_model_init(**context_args)
 
             # First / last layer in bf16 isn't supported with delayed scaling since it
@@ -312,6 +352,10 @@
 
 else:
 
+    def get_fp8_recipe(config: TransformerConfig):
+        """Returns None since TE is not available."""
+        return None
+
     def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
         """Returns dummy fp8 context manager since TE is not available."""
         return nullcontext()
diff --git a/primus/backends/megatron/patches/fp4_patches.py b/primus/backends/megatron/patches/fp4_patches.py
new file mode 100644
index 000000000..725aa3907
--- /dev/null
+++ b/primus/backends/megatron/patches/fp4_patches.py
@@ -0,0 +1,70 @@
+###############################################################################
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+###############################################################################
+
+"""
+Megatron FP4 patches
+"""
+
+from primus.core.patches import PatchContext, get_args, register_patch
+from primus.modules.module_utils import log_rank_0
+
+
+@register_patch(
+    "megatron.core.enums",
+    backend="megatron",
+    phase="before_train",
+    description="Override Megatron enums to use Primus implementation when fp4 is enabled",
+    condition=lambda ctx: getattr(get_args(ctx), "fp4", False),
+)
+def patch_enums(ctx: PatchContext):
+    from megatron.core import enums
+
+    from primus.backends.megatron.core.enums import Fp4Recipe
+
+    log_rank_0("[Patch:megatron.core.enums] Overriding enums for fp4=True")
+
+    enums.Fp4Recipe = Fp4Recipe
+    log_rank_0(f"[Patch:megatron.core.enums] Patched {enums.__name__}.Fp4Recipe")
+
+
+@register_patch(
+    "megatron.core.fp4_utils",
+    backend="megatron",
+    phase="before_train",
+    description="Override Megatron get_fp4_context to use Primus implementation when fp4 is enabled",
+    condition=lambda ctx: getattr(get_args(ctx), "fp4", False),
+)
+def patch_fp4_context(ctx: PatchContext):
+    """
+    Patch Megatron's get_fp4_context functions to use the Primus implementation.
+
+    Behavior (moved from MegatronTrainer.patch_fp4_context):
+    - When module_config.fp4 is True, replace get_fp4_context in:
+        * megatron.core.transformer.transformer_block
+        * megatron.core.ssm.mamba_block
+        * megatron.core.transformer.multi_token_prediction
+        * megatron.core.fp4_utils
+      with Primus's ROCm-friendly get_fp4_context.
+ """ + from megatron.core import fp4_utils + from megatron.core.ssm import mamba_block + from megatron.core.transformer import multi_token_prediction, transformer_block + + from primus.backends.megatron.core.fp4_utils import get_fp4_context + + log_rank_0("[Patch:megatron.fp4.context] Overriding get_fp4_context for fp4=True") + + # Patch get_fp4_context in all relevant modules + modules_to_patch = [ + transformer_block, + mamba_block, + multi_token_prediction, + fp4_utils, + ] + + for module in modules_to_patch: + module.get_fp4_context = get_fp4_context + log_rank_0(f"[Patch:megatron.fp4.context] Patched {module.__name__}.get_fp4_context") diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index a7e9d433d..ff33990f1 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -725,7 +725,20 @@ def initialize_megatron( if args.decoder_pipeline_manual_split_list is not None: from .utils import validate_args_modified - validate_args_modified(args, args_defaults) + ori_code = "if args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None:" + new_code = ( + "if args.decoder_pipeline_manual_split_list is None and " + ori_code.split("if ")[-1] + ) + + validate_args_modified(args, args_defaults, ori_code=ori_code, new_code=new_code) + elif args.fp4 is not None: + # TODO(ruibin): Remove it when ROCm TE upgrade to 2.7.0.dev0 + from .utils import validate_args_modified + + ori_code = """raise ValueError("--fp4-format requires Transformer Engine >= 2.7.0.dev0 for NVFP4BlockScaling support.")""" + new_code = """pass""" + + validate_args_modified(args, args_defaults, ori_code=ori_code, new_code=new_code) else: validate_args(args, args_defaults) diff --git a/primus/modules/trainer/megatron/utils.py b/primus/modules/trainer/megatron/utils.py index 0d59e955e..5bb8350f5 100644 --- a/primus/modules/trainer/megatron/utils.py +++ b/primus/modules/trainer/megatron/utils.py @@ -209,10 +209,11 @@ def validate_args_modifier(func, modification): exec(modified_source, func.__globals__, namespace) return namespace[func.__name__] - ori_code = ( - "if args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None:" - ) - new_code = "if args.decoder_pipeline_manual_split_list is None and " + ori_code.split("if ")[-1] + ori_code = kwargs.pop("ori_code", None) + new_code = kwargs.pop("new_code", None) + + assert ori_code is not None and new_code is not None, "ori_code and new_code must be provided." + megatron.training.arguments.validate_args = validate_args_modifier( megatron.training.arguments.validate_args, lambda s: s.replace(ori_code, new_code) ) @@ -489,6 +490,13 @@ def validate_args_on_rocm(args): args.fp8_recipe in support_fp8_recipe ), f"{args.fp8_recipe} recipe is not support when enable `use_turbo_parallel_linear`." + # Turbo FP4 linear check + if args.fp4 and args.use_turbo_parallel_linear: + support_fp4_recipe = ["mxfp4"] + assert ( + args.fp4_recipe in support_fp4_recipe + ), f"{args.fp4_recipe} recipe is not support when enable `use_turbo_parallel_linear`." + # NOTE: mxfp8 environment variable must be set to 1 to enable mxfp8 recipe on ROCm. if args.fp8_recipe == "mxfp8": assert (