From eac15d8fae860a471502b48f233b5e07d6efd303 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 22 Dec 2025 12:08:11 -0800 Subject: [PATCH 1/9] Add parameter best auto-pruning for Minitron Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 3 +- modelopt/torch/nas/plugins/megatron.py | 20 +- modelopt/torch/opt/searcher.py | 3 +- modelopt/torch/prune/__init__.py | 2 - .../torch/prune/plugins/mcore_minitron.py | 242 ++++++++++++++---- modelopt/torch/utils/network.py | 2 +- modelopt/torch/utils/plugins/__init__.py | 3 + .../torch/utils/plugins/megatron_model.py | 57 +++++ .../torch/nas_prune/minitron_common.py | 11 +- .../test_megatron_gpt_dynamic_modules.py | 77 +++++- .../test_megatron_mamba_dynamic_modules.py | 7 +- .../test_mcore_gpt_minitron_pruning.py | 105 ++------ .../test_mcore_mamba_minitron_pruning.py | 2 +- 13 files changed, 386 insertions(+), 148 deletions(-) create mode 100644 modelopt/torch/utils/plugins/megatron_model.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 978ac209d..424d6d542 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support for parallel draft heads in Eagle speculative decoding. - Add support to enable custom emulated quantization backend. See :meth:`register_quant_backend `` for more details. See an example in ``tests/unit/torch/quantization/test_custom_backend.py``. - Add ``examples/llm_qad`` for QAD training with Megatron-LM. +- Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning using ``export_config``. See `examples/pruning/README.md `_ for more details on its usage. **Deprecations** @@ -84,7 +85,7 @@ NVIDIA Model Optimizer Changelog (Linux) **Documentation** -- Add general guidelines for Minitron pruning and distillation. See `examples/pruning/README.md `_ for more details. +- Add general guidelines for Minitron pruning and distillation. See `pruning guidelines `_ for more details. - Added example for exporting QLoRA checkpoint for vLLM deployment. Refer to `examples/llm_qat/README.md `_ for more details 0.37 (2025-10-08) diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index a25012709..5814009d0 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -27,8 +27,6 @@ from megatron.core.models.gpt import GPTModel from megatron.core.parallel_state import ( get_data_parallel_group, - get_pipeline_model_parallel_group, - get_tensor_model_parallel_group, is_pipeline_first_stage, is_pipeline_last_stage, ) @@ -54,13 +52,8 @@ from modelopt.torch.opt.searcher import ConstraintsDict from modelopt.torch.trace import Symbol from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import ( - get_module_device, - make_divisible, - param_num_from_forward, - print_rank_0, - random, -) +from modelopt.torch.utils import make_divisible, print_rank_0, random +from modelopt.torch.utils.plugins import param_num_megatron from ..algorithms import ( MODULE_TYPE_TO_CONSTRAINTS_FUNC, @@ -1045,7 +1038,6 @@ def modify( *, hidden_size_divisor: int = 1, ffn_hidden_size_divisor: int = 1, - mamba_num_heads_divisor: int = 1, mamba_head_dim_divisor: int = 1, num_moe_experts_divisor: int = 1, ): @@ -1054,7 +1046,6 @@ def modify( Args: hidden_size_divisor: The divisor of the hidden_size. 
ffn_hidden_size_divisor: The divisor of the mlp ffn_hidden_size. - mamba_num_heads_divisor: The divisor of the mamba num_heads. mamba_head_dim_divisor: The divisor of the mamba head_dim. num_moe_experts_divisor: The divisor of the number of MoE experts. """ @@ -1065,7 +1056,6 @@ def modify( for layer in self.decoder.layers: layer.modify( ffn_hidden_size_divisor=ffn_hidden_size_divisor, - mamba_num_heads_divisor=mamba_num_heads_divisor, mamba_head_dim_divisor=mamba_head_dim_divisor, num_moe_experts_divisor=num_moe_experts_divisor, ) @@ -1142,11 +1132,7 @@ def constraint_eval_funcs(self) -> dict[str, ConstraintEvalFunc]: def _get_params(self, _: ConstraintsRes | None = None) -> float: """Get number of model parameters from forward pass.""" - params = param_num_from_forward(self.model, args=self.dummy_input, unit=1.0) - reduced_params = torch.Tensor([params]).to(device=get_module_device(self.model)) - torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) - torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) - return reduced_params.item() + return param_num_megatron(self.model, from_forward=True, args=self.dummy_input) def _get_flops(self, _: ConstraintsRes | None = None) -> float: """Get inference FLOPs.""" diff --git a/modelopt/torch/opt/searcher.py b/modelopt/torch/opt/searcher.py index 3052289cb..9e73b143c 100644 --- a/modelopt/torch/opt/searcher.py +++ b/modelopt/torch/opt/searcher.py @@ -35,7 +35,7 @@ import torch.nn as nn from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import no_stdout, run_forward_loop +from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop LimitsTuple = tuple[float, float] ConstraintsDict = dict[str, str | float | dict | None] @@ -212,6 +212,7 @@ def construct_forward_loop( return None def forward_loop_with_silence_check(m: nn.Module) -> None: + print_rank_0("Running forward loop...") with no_stdout() if silent else nullcontext(): if data_loader is not None: run_forward_loop( diff --git a/modelopt/torch/prune/__init__.py b/modelopt/torch/prune/__init__.py index aac5f7e87..847b22e9d 100644 --- a/modelopt/torch/prune/__init__.py +++ b/modelopt/torch/prune/__init__.py @@ -19,8 +19,6 @@ simplifies the overall workflow to accommodate for the simpler nature of pruning algorithms. """ -# nas is a required - so let's check if it's available -import modelopt.torch.nas from modelopt.torch.utils import import_plugin from . import fastnas, gradnas, plugins diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index db6769b7b..40d5d608b 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -24,9 +24,9 @@ Actual dynamic module implementations are at :mod:`modelopt.torch.nas.plugins.megatron`. 
""" -import copy from collections.abc import Callable from functools import partial +from itertools import product from typing import Any from warnings import warn @@ -43,6 +43,7 @@ reduce_from_tensor_model_parallel_region, ) from pydantic import create_model +from tqdm import tqdm from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.nas.plugins.megatron import ( @@ -57,7 +58,7 @@ _DynamicTransformerLayer, ) from modelopt.torch.nas.registry import DMRegistry -from modelopt.torch.nas.utils import get_subnet_config, sort_parameters +from modelopt.torch.nas.utils import get_subnet_config, sample, sample_and_reset, sort_parameters from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules from modelopt.torch.opt.conversion import ApplyModeError from modelopt.torch.opt.dynamic import DynamicModule, DynamicSpace @@ -71,6 +72,7 @@ from modelopt.torch.opt.utils import named_hparams from modelopt.torch.utils import distributed as dist from modelopt.torch.utils import get_module_device, print_rank_0 +from modelopt.torch.utils.plugins import param_num_megatron from ..pruning import PruneModeRegistry @@ -172,6 +174,7 @@ class MCoreMinitronSearcher(BaseSearcher): activations_per_rank: list[dict[str, torch.Tensor]] layer_scores: dict[int, torch.Tensor] + top_k_candidates_per_constraint: dict[float, list[tuple[dict, float]]] @property def default_search_config(self) -> SearchConfig: @@ -181,12 +184,20 @@ def default_search_config(self) -> SearchConfig: "max_iter_data_loader": 1024, "skip_sorting": False, "scores_path": None, + # Additional search config for parameter-based pruning + "max_width_pruning": 0.5, # Maximum fraction per width hyperparameter to prune + "max_depth_pruning": 0.25, # Maximum fraction per depth hyperparameter to prune + "top_k": 10, # Number of candidates to consider for score_func validation } @property def default_state_dict(self) -> SearchStateDict: """Return default state dict for importance scores and activations from forward loop.""" - return {"activations_per_rank": [], "layer_scores": {}} + return { + "activations_per_rank": [], + "layer_scores": {}, + "top_k_candidates_per_constraint": {}, + } def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: """Sanitize the search config dict.""" @@ -200,32 +211,41 @@ def before_search(self) -> None: super().before_search() # Check that the constraint is valid - assert self.constraints.keys() == {"export_config"}, ( - "Only `export_config` constraint is supported for pruning!" - ) - - self.constraints["export_config"] = copy.deepcopy(self.constraints["export_config"]) - export_config = self.constraints["export_config"] - if "num_query_groups" in export_config: - warn("num_query_groups is no longer supported (since 0.41)! It will be ignored.") - if export_config["num_query_groups"] != self.model.config.num_query_groups: # type: ignore[index] - raise ValueError(f"num_query_groups must be {self.model.config.num_query_groups}!") - export_config.pop("num_query_groups") # type: ignore[union-attr] - assert isinstance(export_config, dict) # to keep mypy happy - assert export_config.keys() <= SUPPORTED_HPARAMS, ( - f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config.keys()}" - ) + assert len(self.constraints) == 1 and next(iter(self.constraints.keys())) in { + "export_config", + "params", + }, "Only `export_config` or `params` constraint is supported!" 
+ + if "export_config" in self.constraints: + export_config = self.constraints["export_config"] + assert isinstance(export_config, dict) # to keep mypy happy + if "num_query_groups" in export_config: + warn("num_query_groups is no longer supported (since 0.41)! It will be ignored.") + if export_config["num_query_groups"] != self.model.config.num_query_groups: + raise ValueError( + f"num_query_groups must be {self.model.config.num_query_groups}!" + ) + export_config.pop("num_query_groups") + assert export_config.keys() <= SUPPORTED_HPARAMS, ( + f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config=}" + ) - # Only sort the parameters that are to be pruned - # If a user only prunes depth, we should not sort width parameters - self.hps_to_sort = SUPPORTED_HPARAMS & export_config.keys() + # Only sort the parameters that are to be pruned + # If a user only prunes depth, we should not sort width parameters + self.hps_to_sort = set(export_config.keys()) + else: + assert isinstance(self.constraints["params"], float), "params must be a float!" + assert self.has_score, "score_func (e.g. MMLU) is required for parameter-based pruning!" + export_config = None + # Sort all parameters for parameter-based pruning + self.hps_to_sort = SUPPORTED_HPARAMS for n, hp in named_hparams(self.model, unique=True): hp_name = n.split(".")[-1] if hp.is_configurable: # Make sure configurable hparams are the ones with right names else implementation needs to be fixed! assert hp_name in SUPPORTED_HPARAMS, f"[ImplError] Invalid hparam {hp_name}!" - if hp_name in export_config: + if export_config is not None and hp_name in export_config: assert export_config[hp_name] in hp.choices, ( f"Invalid choice {export_config[hp_name]} for {n}! Available choices: {hp.choices}" ) @@ -243,10 +263,8 @@ def run_search(self) -> None: registry = ImportanceEstimatorRegistry(unwrapped_model) if self.layer_scores and self.activations_per_rank: # Available from checkpoint - print_rank_0("Loading activations and scores per rank from checkpoint...") registry.set_activations_and_layer_scores(self.activations_per_rank, self.layer_scores) elif not self.config["skip_sorting"]: - print_rank_0("Running forward loop...") assert self.forward_loop is not None is_training = self.model.training self.model.eval() @@ -265,8 +283,17 @@ def run_search(self) -> None: else: sort_parameters(self.model, self.hps_to_sort, verbose=True) + if "params" in self.constraints: + export_config = self.search_best_arch_by_params( + max_params=self.constraints["params"], # type: ignore[arg-type] + max_width_pruning=self.config["max_width_pruning"], + max_depth_pruning=self.config["max_depth_pruning"], + top_k=self.config["top_k"], + ) + else: + export_config = self.constraints["export_config"] + # Prune homogeneously - export_config = self.constraints["export_config"] assert isinstance(export_config, dict) # to keep mypy happy for n, hp in named_hparams(self.model, configurable=True): hp_name = n.split(".")[-1] @@ -281,20 +308,150 @@ def run_search(self) -> None: layers_to_drop = [layer for layer, _ in sorted_layers[num_layers_hp.active :]] # type: ignore[misc] drop_mcore_language_model_layers(self.model, layers_to_drop=layers_to_drop) + # Update model config with pruned architecture # kv_channels can be None so we need to save original from original hidden_size and num_attention_heads - model_cfg = self.model.config - orig_kv_channels = getattr(model_cfg, "kv_channels") + orig_kv_channels = self.model.config.kv_channels if orig_kv_channels is None: - 
orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr( - model_cfg, "num_attention_heads" + orig_kv_channels = ( + self.model.config.hidden_size // self.model.config.num_attention_heads ) - setattr(model_cfg, "kv_channels", orig_kv_channels) - for n in SUPPORTED_HPARAMS: - if n in export_config: - setattr(model_cfg, n, export_config[n]) + self.model.config.kv_channels = orig_kv_channels + for hp_name, hp_value in export_config.items(): + setattr(self.model.config, hp_name, hp_value) registry.cleanup() + def search_best_arch_by_params( + self, + max_params: float, + max_width_pruning: float = 0.5, + max_depth_pruning: float = 0.25, + top_k: int = 10, + ) -> dict: + """Search for the best architecture based on the given parameters constraints. + + We perform a grid-search over the search space to find subnets (homogeneous) fitting the constraints. + Top-k candidates (sorted by param count) are then validated using the score_func (e.g. MMLU) + and the best subnet is returned. + + Args: + max_params: Maximum number of parameters for the pruned model. + max_width_pruning: Maximum fraction per width hyperparameter to prune (default: 0.5). + Only top (1 - max_width_pruning) choices will be considered. + max_depth_pruning: Maximum fraction per depth hyperparameter to prune (default: 0.25). + Only top (1 - max_depth_pruning) choices will be considered. + top_k: Number of candidates to consider for score_func validation. + + Returns: + export_config: Dictionary mapping hyperparameter names to their pruned values. + """ + print_rank_0( + f"\nSearching for the best pruned architecture under {max_params / 1e9:.2f}B params constraints" + ) + + # 1. Find available search space choices (across all PP ranks) + hp_choices = {} + for n, hp in named_hparams(self.model, configurable=True): + hp_name = n.split(".")[-1] + hp_choices[hp_name] = hp.choices + all_pp_search_spaces = [None] * get_pipeline_model_parallel_world_size() + torch.distributed.all_gather_object( + all_pp_search_spaces, hp_choices, group=get_pipeline_model_parallel_group() + ) + hp_choices = {k: v for d in all_pp_search_spaces for k, v in d.items()} # type: ignore[attr-defined] + + # 2. Perform grid-search over the search space to find subnets fitting the constraints + if max_params not in self.top_k_candidates_per_constraint: + search_space_configs = MCoreMinitronSearcher._generate_search_space_combos( + hp_choices, # type: ignore[arg-type] + max_width_pruning, + max_depth_pruning, + ) + sample(self.model, sample_func=max) # reset for sanity + selected: list[tuple[dict, float]] = [] + for config in tqdm( + search_space_configs, + desc=f"Finding top {top_k} candidates fitting the constraints...", + disable=not dist.is_master(), + ): + # Convert search space config to fnmatch pattern and sample function + # Use partial to bind each value at creation time (avoid late-binding closure issue) + sample_func = { + f"*.{k}": partial(lambda val, choices: val, v) for k, v in config.items() + } + with sample_and_reset(self.model, sample_func=sample_func): # type: ignore[arg-type] + candidate_params = param_num_megatron(self.model) + if candidate_params <= max_params: + selected.append((config, candidate_params)) + assert len(selected) > 0, "No subnets found fitting the constraints!" 
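The ``partial`` used in the ``sample_func`` construction above works around Python's late binding of closure variables; a standalone illustration of the pitfall it avoids (independent of modelopt):

    from functools import partial

    config = {"hidden_size": 1024, "num_layers": 8}

    # Late binding: each lambda closes over the same loop variable `v`, so by the
    # time it is called, every entry sees the value from the last iteration.
    late = {k: (lambda choices: v) for k, v in config.items()}
    print([fn(None) for fn in late.values()])   # [8, 8]

    # Binding the value at creation time with functools.partial fixes this.
    bound = {k: partial(lambda val, choices: val, v) for k, v in config.items()}
    print([fn(None) for fn in bound.values()])  # [1024, 8]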
+ self.top_k_candidates_per_constraint[max_params] = sorted( + selected, key=lambda x: x[1], reverse=True + )[:top_k] + self.save_search_checkpoint(verbose=True) + else: + print_rank_0(f"Using top {top_k} candidates from checkpoint") + top_k_candidates = self.top_k_candidates_per_constraint[max_params] + + # 3. Validate top-k candidates using the score_func and return the best subnet + # TODO: update this + best = top_k_candidates[0][0] + + return best + + @staticmethod + def _generate_search_space_combos( + search_space: dict[str, list], + max_width_pruning: float = 0.5, + max_depth_pruning: float = 0.25, + ) -> list[dict[str, Any]]: + """Generate all possible combinations of hyperparameters from the search space. + + Args: + search_space: Dictionary mapping hyperparameter names to their possible sorted choices. + Example: {"hidden_size": [1024, 2048, 3072, 4096], "num_layers": [1, 2, ..., 31, 32]} + max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.5). + Only top (1 - max_width_pruning) choices will be considered. + max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.25). + Only top (1 - max_depth_pruning) choices will be considered. + + Returns: + List of configuration dictionaries, where each dictionary maps hyperparameter + names to their chosen values. Example: + [ + {"hidden_size": 1024, "num_layers": 1}, + {"hidden_size": 1024, "num_layers": 2}, + ... + {"hidden_size": 4096, "num_layers": 32}, + ] + """ + print_rank_0( + f"\nOnly considering atmost {(max_width_pruning * 100):.0f}% for width and " + f"{max_depth_pruning * 100:.0f}% for depth pruning hparams" + ) + + filtered_ss = { + k: sorted(v)[int((1 - max_depth_pruning) * len(v)) :] + if k == "num_layers" + else sorted(v)[int((1 - max_width_pruning) * len(v)) :] + for k, v in search_space.items() + } + + ss_size = 1 + for k, v in filtered_ss.items(): + print_rank_0(f"\tSearch space for {k}: {v}") + ss_size *= len(v) + print_rank_0(f"\tTotal search space in consideration: {ss_size}\n") + + hparam_names = list(filtered_ss.keys()) + hparam_choices_lists = [filtered_ss[name] for name in hparam_names] + + search_space_combos = [ + dict(zip(hparam_names, choices)) for choices in product(*hparam_choices_lists) + ] + assert len(search_space_combos) == ss_size + + return search_space_combos + MCoreMinitronConfig: type[ModeloptBaseConfig] = create_model( "MCoreMinitronConfig", @@ -302,17 +459,17 @@ def run_search(self) -> None: registry=DMRegistry, default_rules={ "megatron.core.models.gpt.GPTModel": { - "hidden_size_divisor": 64, - "ffn_hidden_size_divisor": 64, - "num_moe_experts_divisor": 1, + "hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 256, + "num_moe_experts_divisor": 8, }, **( { "megatron.core.models.mamba.MambaModel": { - "hidden_size_divisor": 64, - "ffn_hidden_size_divisor": 64, - "mamba_head_dim_divisor": 4, - "num_moe_experts_divisor": 1, + "hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 256, + "mamba_head_dim_divisor": 8, + "num_moe_experts_divisor": 8, } } if HAS_MAMBA @@ -325,9 +482,7 @@ def run_search(self) -> None: def get_mcore_minitron_config( - channel_divisor: int = 64, - mamba_head_dim_divisor: int = 4, - num_moe_experts_divisor: int = 1, + channel_divisor: int = 256, mamba_head_dim_divisor: int = 8, num_moe_experts_divisor: int = 8 ) -> ModeloptBaseConfig: """Get a MCoreMinitronConfig with the given channel divisor instead of default.""" config = MCoreMinitronConfig() @@ -562,6 +717,7 @@ def set_activations_and_layer_scores( 
activations_per_rank: List of dicts from module name to activations. Should match PP size. layer_scores: Dict from layer_number (1-indexed) to score. """ + print_rank_0("Loading activations and scores per rank from checkpoint...") rank = get_pipeline_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() assert len(activations_per_rank) == pp_size, ( diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index 1940295c3..e18c85c3b 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -142,7 +142,7 @@ def param_num_from_forward( Returns: The number of parameters from the model's forward pass in the given unit. - This can helpful for dynamic modules, where the state dict might contain extra parameters that + This can helpful for MoE or dynamic modules, where the state dict might contain extra parameters that is not actively used in the model, e.g., because of a DynamicModule that is deactivated for the forward pass. We circumvent this issue by just counting parameters of modules that appear in a forward pass. diff --git a/modelopt/torch/utils/plugins/__init__.py b/modelopt/torch/utils/plugins/__init__.py index 517c59914..ac1053aa2 100644 --- a/modelopt/torch/utils/plugins/__init__.py +++ b/modelopt/torch/utils/plugins/__init__.py @@ -23,5 +23,8 @@ with import_plugin("megatron_mmlu"): from .megatron_mmlu import * +with import_plugin("megatron_model"): + from .megatron_model import * + with import_plugin("megatron_preprocess_data"): from .megatron_preprocess_data import * diff --git a/modelopt/torch/utils/plugins/megatron_model.py b/modelopt/torch/utils/plugins/megatron_model.py new file mode 100644 index 000000000..5ea2a7236 --- /dev/null +++ b/modelopt/torch/utils/plugins/megatron_model.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General utilities for Megatron models.""" + +from typing import Any + +import torch +from megatron.core.parallel_state import ( + get_pipeline_model_parallel_group, + get_tensor_model_parallel_group, +) +from megatron.core.transformer.module import MegatronModule + +from ..network import param_num_from_forward + +__all__ = ["param_num_megatron"] + + +def param_num_megatron( + model: MegatronModule, *, from_forward: bool = False, args: Any = None +) -> float: + """Get the number of parameters in the model (reduced across TP and PP ranks). + + Args: + model: The Megatron model. + from_forward: To get the number of params from a forward pass instead of directly counting the params. + This can helpful for MoE or dynamic modules, where the state dict might contain extra parameters that + is not actively used in the model, e.g., because of a DynamicModule that is deactivated for the + forward pass. We circumvent this issue by just counting parameters of modules that appear in a + forward pass. 
+ args: The arguments to pass to the forward pass. Only used if from_forward is True. + + Returns: + The number of parameters in the model (reduced across TP and PP ranks). + """ + if from_forward: + assert args is not None, "args must be provided if from_forward is True" + params = int(param_num_from_forward(model, args, unit=1.0)) + else: + params = sum(p.numel() for p in model.parameters()) + reduced_params = torch.Tensor([params]).to(device=next(model.parameters()).device) + torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) + torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) + return reduced_params.item() diff --git a/tests/_test_utils/torch/nas_prune/minitron_common.py b/tests/_test_utils/torch/nas_prune/minitron_common.py index 856edd38c..97b12a4ca 100644 --- a/tests/_test_utils/torch/nas_prune/minitron_common.py +++ b/tests/_test_utils/torch/nas_prune/minitron_common.py @@ -19,7 +19,16 @@ def prune_minitron(model, export_config, config, channel_divisor=64): return mtp.prune( model, - mode=[("mcore_minitron", mtp.mcore_minitron.get_mcore_minitron_config(channel_divisor))], + mode=[ + ( + "mcore_minitron", + mtp.mcore_minitron.get_mcore_minitron_config( + channel_divisor=channel_divisor, + mamba_head_dim_divisor=4, + num_moe_experts_divisor=1, + ), + ) + ], constraints={"export_config": export_config}, dummy_input=None, # Not used config=config, diff --git a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py index 2679d3090..16b45cdb0 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py @@ -25,6 +25,7 @@ from _test_utils.torch.megatron.models import get_mcore_gpt_model from _test_utils.torch.megatron.utils import run_mcore_inference from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.parallel_state import destroy_model_parallel from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.mlp import MLP from megatron.core.transformer.transformer_layer import TransformerLayer @@ -32,6 +33,7 @@ import modelopt.torch.nas as mtn from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.nas.plugins.megatron import ( + NumAttentionHeadsHp, _DynamicColumnParallelLinear, _DynamicEmbedding, _DynamicLanguageModelEmbedding, @@ -81,7 +83,7 @@ def _test_gpt_search_space( normalization=normalization, ).cuda() - model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) assert isinstance(model, _DynamicMCoreLanguageModel) for m in model.modules(): @@ -153,6 +155,74 @@ def test_expand_head_indices(): assert expand_head_indices(heads, hidden_size_per_head).tolist() == [2, 3, 6, 7, 4, 5, 0, 1] +def test_gpt_self_attention_head_sorting(distributed_setup_size_1): + model = get_mcore_gpt_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + initialize_megatron=True, + num_layers=1, + hidden_size=16, + num_attention_heads=8, + num_query_groups=2, + ffn_hidden_size=16, + activation_func="squared_relu", + ).cuda() + + model = mtn.convert(model, "mcore_minitron") + + self_attn = model.decoder.layers[0].self_attention + assert isinstance(self_attn, _DynamicSelfAttention) + assert 
isinstance(self_attn.linear_qkv, _DynamicQKVColumnParallelLinear) + assert isinstance(self_attn.linear_proj, _DynamicProjRowParallelLinear) + + hp_num_attention_heads = self_attn.get_hparam("num_attention_heads") + assert isinstance(hp_num_attention_heads, NumAttentionHeadsHp) + + # Choices are multiples of num_query_groups (2): [2, 4, 6, 8] + assert hp_num_attention_heads.choices == [2, 4, 6, 8] + assert hp_num_attention_heads._num_query_groups == 2 + + # Set importance and slice order + # Importance per head (group-aware): [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] + # Group 0 (heads 0-3): [2.2, 0.1, 1.1, 2.1] → sorted: [0, 3, 2, 1] + # Group 1 (heads 4-7): [3.0, 2.0, 0.0, 1.0] → sorted: [4, 5, 7, 6] + # Global ranking (group-aware, flattened): [0, 3, 2, 1, 4, 5, 7, 6] + hp_num_attention_heads._get_importance = lambda: torch.tensor( + [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] + ) + # _estimate_head_ranking returns ranking as 1D tensor + expected_ranking = torch.tensor([0, 3, 2, 1, 4, 5, 7, 6]) + hp_num_attention_heads.enforce_order(expected_ranking) + + assert hp_num_attention_heads.active_slice.tolist() == [0, 3, 2, 1, 4, 5, 7, 6] + + # check if we get correct selection of sorted + pruned heads after setting active values + hp_num_attention_heads.active = 4 # top 2 heads per group (2 groups * 2 heads = 4 total) + + # Expected: Top 2 heads from each group: [0, 3] from group 0, [4, 5] from group 1 + expected_q_heads = [0, 3, 4, 5] + # In QKV layout (4 heads/group → 6 QKV heads/group): + # Group 0: Q=[0, 3], K=4, V=5 → QKV indices [0, 3, 4, 5] + # Group 1: Q=[4, 5], K=10, V=11 → QKV indices [6, 7, 10, 11] + expected_qkv_heads = [0, 3, 4, 5, 6, 7, 10, 11] + + assert ( + self_attn.linear_qkv._get_output_size_indices().tolist() + == expand_head_indices( + torch.LongTensor(expected_qkv_heads), model.config.kv_channels + ).tolist() + ) + assert ( + self_attn.linear_proj._get_input_size_indices().tolist() + == expand_head_indices( + torch.LongTensor(expected_q_heads), model.config.kv_channels + ).tolist() + ) + + # Clean up since this is not a spawned process + destroy_model_parallel() + + def _test_gpt_moe_search_space(rank, size): channel_divisor = 4 @@ -183,7 +253,10 @@ def _test_gpt_moe_search_space(rank, size): moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, ).cuda() - model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert( + model, + [("mcore_minitron", get_mcore_minitron_config(channel_divisor, num_moe_experts_divisor=1))], + ) moe = model.decoder.layers[0].mlp assert isinstance(moe, _DynamicMoELayer) diff --git a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index 430b5e261..6a1bc7a8a 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -51,7 +51,7 @@ def _test_mamba_search_space(rank, size): mamba_head_dim_divisor = 4 num_layers = size - hybrid_override_pattern = "M" * size + hybrid_override_pattern = "M" * size # all layers are Mamba layers hidden_size = channel_divisor * 4 mamba_state_dim = channel_divisor mamba_head_dim = mamba_head_dim_divisor * 2 @@ -75,7 +75,10 @@ def _test_mamba_search_space(rank, size): ).cuda() mamba_num_heads = model.decoder.layers[0].mixer.nheads - model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert( + model, + [("mcore_minitron", 
get_mcore_minitron_config(channel_divisor, mamba_head_dim_divisor))], + ) assert isinstance(model, _DynamicMCoreLanguageModel) if is_pipeline_first_stage(): diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index a0d4877bb..2f1eae76b 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -29,20 +29,13 @@ ) from _test_utils.torch.misc import compare_outputs, set_seed from _test_utils.torch.nas_prune.minitron_common import prune_minitron -from megatron.core.parallel_state import destroy_model_parallel from megatron.core.transformer.identity_op import IdentityOp import modelopt.torch.nas as mtn from modelopt.torch.nas.conversion import export_searchspace -from modelopt.torch.nas.plugins.megatron import ( - NumAttentionHeadsHp, - _DynamicProjRowParallelLinear, - _DynamicQKVColumnParallelLinear, - _DynamicSelfAttention, - expand_head_indices, -) from modelopt.torch.prune.plugins.mcore_minitron import ( ImportanceEstimatorRegistry, + MCoreMinitronSearcher, _convert_model_to_dynamic_space, get_mcore_minitron_config, ) @@ -124,74 +117,6 @@ def test_mcore_gpt_parameter_sorting(activation_func): ) -def test_mcore_gpt_self_attention_head_sorting(distributed_setup_size_1): - model = get_mcore_gpt_model( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - initialize_megatron=True, - num_layers=1, - hidden_size=16, - num_attention_heads=8, - num_query_groups=2, - ffn_hidden_size=16, - activation_func="squared_relu", - ).cuda() - - model = mtn.convert(model, "mcore_minitron") - - self_attn = model.decoder.layers[0].self_attention - assert isinstance(self_attn, _DynamicSelfAttention) - assert isinstance(self_attn.linear_qkv, _DynamicQKVColumnParallelLinear) - assert isinstance(self_attn.linear_proj, _DynamicProjRowParallelLinear) - - hp_num_attention_heads = self_attn.get_hparam("num_attention_heads") - assert isinstance(hp_num_attention_heads, NumAttentionHeadsHp) - - # Choices are multiples of num_query_groups (2): [2, 4, 6, 8] - assert hp_num_attention_heads.choices == [2, 4, 6, 8] - assert hp_num_attention_heads._num_query_groups == 2 - - # Set importance and slice order - # Importance per head (group-aware): [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] - # Group 0 (heads 0-3): [2.2, 0.1, 1.1, 2.1] → sorted: [0, 3, 2, 1] - # Group 1 (heads 4-7): [3.0, 2.0, 0.0, 1.0] → sorted: [4, 5, 7, 6] - # Global ranking (group-aware, flattened): [0, 3, 2, 1, 4, 5, 7, 6] - hp_num_attention_heads._get_importance = lambda: torch.tensor( - [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] - ) - # _estimate_head_ranking returns ranking as 1D tensor - expected_ranking = torch.tensor([0, 3, 2, 1, 4, 5, 7, 6]) - hp_num_attention_heads.enforce_order(expected_ranking) - - assert hp_num_attention_heads.active_slice.tolist() == [0, 3, 2, 1, 4, 5, 7, 6] - - # check if we get correct selection of sorted + pruned heads after setting active values - hp_num_attention_heads.active = 4 # top 2 heads per group (2 groups * 2 heads = 4 total) - - # Expected: Top 2 heads from each group: [0, 3] from group 0, [4, 5] from group 1 - expected_q_heads = [0, 3, 4, 5] - # In QKV layout (4 heads/group → 6 QKV heads/group): - # Group 0: Q=[0, 3], K=4, V=5 → QKV indices [0, 3, 4, 5] - # Group 1: Q=[4, 5], K=10, V=11 → QKV indices [6, 7, 10, 11] - expected_qkv_heads = [0, 3, 4, 5, 6, 7, 10, 11] - - assert ( - 
self_attn.linear_qkv._get_output_size_indices().tolist() - == expand_head_indices( - torch.LongTensor(expected_qkv_heads), model.config.kv_channels - ).tolist() - ) - assert ( - self_attn.linear_proj._get_input_size_indices().tolist() - == expand_head_indices( - torch.LongTensor(expected_q_heads), model.config.kv_channels - ).tolist() - ) - - # Clean up since this is not a spawned process - destroy_model_parallel() - - def _test_mcore_gpt_pruning( num_attention_heads, num_query_groups, @@ -430,7 +355,7 @@ def _test_mcore_gpt_moe_parameter_sorting(rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor) + model, get_mcore_minitron_config(channel_divisor=channel_divisor, num_moe_experts_divisor=1) ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks @@ -570,3 +495,29 @@ def test_mcore_gpt_pruning_moe(tmp_path): job=partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores.pth"), backend="nccl", ) + + +def test_generate_search_space_combos(): + ss = { + "hidden_size": [32, 64, 96, 128, 160], + "num_attention_heads": [8, 16, 24, 32], + "num_layers": [1, 2, 3, 4, 5, 6, 7, 8], + } + ss_combos = MCoreMinitronSearcher._generate_search_space_combos( + ss, max_width_pruning=0.5, max_depth_pruning=0.25 + ) + assert len(ss_combos) == 3 * 2 * 2 + assert ss_combos == [ + {"hidden_size": 96, "num_attention_heads": 24, "num_layers": 7}, + {"hidden_size": 96, "num_attention_heads": 24, "num_layers": 8}, + {"hidden_size": 96, "num_attention_heads": 32, "num_layers": 7}, + {"hidden_size": 96, "num_attention_heads": 32, "num_layers": 8}, + {"hidden_size": 128, "num_attention_heads": 24, "num_layers": 7}, + {"hidden_size": 128, "num_attention_heads": 24, "num_layers": 8}, + {"hidden_size": 128, "num_attention_heads": 32, "num_layers": 7}, + {"hidden_size": 128, "num_attention_heads": 32, "num_layers": 8}, + {"hidden_size": 160, "num_attention_heads": 24, "num_layers": 7}, + {"hidden_size": 160, "num_attention_heads": 24, "num_layers": 8}, + {"hidden_size": 160, "num_attention_heads": 32, "num_layers": 7}, + {"hidden_size": 160, "num_attention_heads": 32, "num_layers": 8}, + ] diff --git a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index d6fa9400b..a7f036bbb 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -78,7 +78,7 @@ def _test_mcore_mamba_parameter_sorting(rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor) + model, get_mcore_minitron_config(channel_divisor=channel_divisor, mamba_head_dim_divisor=4) ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks From 0e71881df21e86e5fc9995d9f03b6a33863f0d19 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Tue, 30 Dec 2025 07:36:58 -0800 Subject: [PATCH 2/9] Fix param count calculation + account depth pruning Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .vscode/settings.json | 2 +- modelopt/torch/nas/plugins/megatron.py | 12 +- .../torch/prune/plugins/mcore_minitron.py | 200 ++++++++++++------ modelopt/torch/prune/pruning.py | 2 +- modelopt/torch/utils/logging.py | 2 +- modelopt/torch/utils/network.py | 7 + .../torch/utils/plugins/megatron_model.py | 5 + 
.../torch/nas_prune/minitron_common.py | 1 + .../test_megatron_gpt_dynamic_modules.py | 19 +- .../test_megatron_mamba_dynamic_modules.py | 11 +- .../test_mcore_gpt_minitron_pruning.py | 2 +- 11 files changed, 192 insertions(+), 71 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 0a3a2353e..0e8465ad3 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -40,7 +40,7 @@ "--no-cov", ], "evenBetterToml.schema.enabled": false, // disable toml/json schema since we have custom fields - "python.analysis.extraPaths": [ + "cursorpyright.analysis.extraPaths": [ "./tests/" // add tests to python path just like pytest does in pyproject.toml ], "git.alwaysSignOff": true, diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index 5814009d0..917b6e7c0 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -1040,6 +1040,7 @@ def modify( ffn_hidden_size_divisor: int = 1, mamba_head_dim_divisor: int = 1, num_moe_experts_divisor: int = 1, + num_layers_divisor: int = 1, ): """Modify the dynamic choices of the module according to provided keyword arguments. @@ -1048,10 +1049,15 @@ def modify( ffn_hidden_size_divisor: The divisor of the mlp ffn_hidden_size. mamba_head_dim_divisor: The divisor of the mamba head_dim. num_moe_experts_divisor: The divisor of the number of MoE experts. + num_layers_divisor: The divisor of the number of layers. """ - hp = self.get_hparam("hidden_size") - choices = {int(make_divisible(c, hidden_size_divisor)) for c in hp.choices} # type: ignore[arg-type] - hp.choices = list(set(hp.choices) & choices | {hp.original}) + for hp_name, divisor in [ + ("hidden_size", hidden_size_divisor), + ("num_layers", num_layers_divisor), + ]: + hp = self.get_hparam(hp_name) + choices = {int(make_divisible(c, divisor)) for c in hp.choices} # type: ignore[arg-type] + hp.choices = list(set(hp.choices) & choices | {hp.original}) for layer in self.decoder.layers: layer.modify( diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 40d5d608b..e1c6e3b81 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -37,6 +37,7 @@ get_pipeline_model_parallel_group, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel import ( gather_from_tensor_model_parallel_region, @@ -58,7 +59,7 @@ _DynamicTransformerLayer, ) from modelopt.torch.nas.registry import DMRegistry -from modelopt.torch.nas.utils import get_subnet_config, sample, sample_and_reset, sort_parameters +from modelopt.torch.nas.utils import get_subnet_config, sample, sort_parameters from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules from modelopt.torch.opt.conversion import ApplyModeError from modelopt.torch.opt.dynamic import DynamicModule, DynamicSpace @@ -71,8 +72,7 @@ from modelopt.torch.opt.searcher import BaseSearcher, SearchConfig, SearchStateDict from modelopt.torch.opt.utils import named_hparams from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import get_module_device, print_rank_0 -from modelopt.torch.utils.plugins import param_num_megatron +from modelopt.torch.utils import get_module_device, num2hrb, print_rank_0 from ..pruning import PruneModeRegistry @@ -170,7 +170,15 @@ def drop_mcore_language_model_layers(model: nn.Module, *, 
layers_to_drop: list[i class MCoreMinitronSearcher(BaseSearcher): - """Searcher for Minitron pruning algorithm.""" + """Searcher for Minitron pruning algorithm. + + Available additional config options: + - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.5). + Only top (1 - max_width_pruning) choices will be considered. + - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.25). + Only top (1 - max_depth_pruning) choices will be considered. + - `top_k`: Number of candidates to consider for score_func validation (default: 10). + """ activations_per_rank: list[dict[str, torch.Tensor]] layer_scores: dict[int, torch.Tensor] @@ -185,9 +193,9 @@ def default_search_config(self) -> SearchConfig: "skip_sorting": False, "scores_path": None, # Additional search config for parameter-based pruning - "max_width_pruning": 0.5, # Maximum fraction per width hyperparameter to prune - "max_depth_pruning": 0.25, # Maximum fraction per depth hyperparameter to prune - "top_k": 10, # Number of candidates to consider for score_func validation + "max_width_pruning": 0.5, + "max_depth_pruning": 0.25, + "top_k": 10, } @property @@ -234,7 +242,7 @@ def before_search(self) -> None: # If a user only prunes depth, we should not sort width parameters self.hps_to_sort = set(export_config.keys()) else: - assert isinstance(self.constraints["params"], float), "params must be a float!" + assert isinstance(self.constraints["params"], (int, float)), "params must be a float!" assert self.has_score, "score_func (e.g. MMLU) is required for parameter-based pruning!" export_config = None # Sort all parameters for parameter-based pruning @@ -251,17 +259,17 @@ def before_search(self) -> None: ) hp.reset_choices() # Make sure ConcatHparam choices are updated after modify() - def run_search(self) -> None: - """Run actual search.""" - # Run forward loop to collect activations and sort parameters unwrapped_model = self.model for m in self.model.modules(): if isinstance(m, _DynamicMCoreLanguageModel): unwrapped_model = m break assert isinstance(unwrapped_model, _DynamicMCoreLanguageModel), "Model not supported!" + self.unwrapped_model = unwrapped_model - registry = ImportanceEstimatorRegistry(unwrapped_model) + def run_search(self) -> None: + """Run forward loop to collect activations, sort parameters, and prune the model.""" + registry = ImportanceEstimatorRegistry(self.unwrapped_model) if self.layer_scores and self.activations_per_rank: # Available from checkpoint registry.set_activations_and_layer_scores(self.activations_per_rank, self.layer_scores) elif not self.config["skip_sorting"]: @@ -283,51 +291,75 @@ def run_search(self) -> None: else: sort_parameters(self.model, self.hps_to_sort, verbose=True) + if self.layer_scores: + # sort layers by scores and drop the lowest ones + sorted_layers = [ + layer + for layer, _ in sorted(self.layer_scores.items(), key=lambda x: x[1], reverse=True) + ] + else: + assert ( + self.constraints.keys() == {"export_config"} + and "num_layers" not in self.constraints["export_config"] + ), "Cannot prune `num_layers` without collecting layer scores!" 
+ sorted_layers = None + if "params" in self.constraints: - export_config = self.search_best_arch_by_params( - max_params=self.constraints["params"], # type: ignore[arg-type] - max_width_pruning=self.config["max_width_pruning"], - max_depth_pruning=self.config["max_depth_pruning"], - top_k=self.config["top_k"], - ) + assert sorted_layers is not None + export_config = self.search_best_arch_by_params(sorted_layers=sorted_layers) else: export_config = self.constraints["export_config"] # Prune homogeneously - assert isinstance(export_config, dict) # to keep mypy happy + self._prune( + export_config, prune_depth=True, update_config=True, sorted_layers=sorted_layers + ) + + registry.cleanup() + + def _prune( + self, + export_config: dict, + prune_depth: bool = True, + update_config: bool = True, + *, + sorted_layers: list[int] | None = None, + ) -> None: + """Prune the model homogeneously based on the export_config by setting active choices for configurable hparams. + + Args: + export_config: Dictionary mapping hyperparameter names to their pruned values. + prune_depth: Whether to drop layers based on sorted_layers (default: True). + update_config: Whether to update the model config with the pruned architecture (default: True). + sorted_layers: Sorted list of layers (1-indexed) for depth pruning. + """ + # Prune homogeneously for n, hp in named_hparams(self.model, configurable=True): hp_name = n.split(".")[-1] if hp_name in export_config: hp.active = export_config[hp_name] # Drop layers if depth pruning is enabled - num_layers_hp = unwrapped_model.get_hparam("num_layers") - if num_layers_hp.active != num_layers_hp.max: - # sort layers by scores and drop the lowest ones - sorted_layers = sorted(self.layer_scores.items(), key=lambda x: x[1], reverse=True) - layers_to_drop = [layer for layer, _ in sorted_layers[num_layers_hp.active :]] # type: ignore[misc] - drop_mcore_language_model_layers(self.model, layers_to_drop=layers_to_drop) + if prune_depth: + num_layers_hp = self.unwrapped_model.get_hparam("num_layers") + if num_layers_hp.active != num_layers_hp.max: + assert sorted_layers is not None + layers_to_drop = sorted_layers[num_layers_hp.active :] # type: ignore[misc] + drop_mcore_language_model_layers(self.model, layers_to_drop=layers_to_drop) # Update model config with pruned architecture # kv_channels can be None so we need to save original from original hidden_size and num_attention_heads - orig_kv_channels = self.model.config.kv_channels - if orig_kv_channels is None: - orig_kv_channels = ( - self.model.config.hidden_size // self.model.config.num_attention_heads - ) - self.model.config.kv_channels = orig_kv_channels - for hp_name, hp_value in export_config.items(): - setattr(self.model.config, hp_name, hp_value) - - registry.cleanup() + if update_config: + orig_kv_channels = self.model.config.kv_channels + if orig_kv_channels is None: + orig_kv_channels = ( + self.model.config.hidden_size // self.model.config.num_attention_heads + ) + self.model.config.kv_channels = orig_kv_channels + for hp_name, hp_value in export_config.items(): + setattr(self.model.config, hp_name, hp_value) - def search_best_arch_by_params( - self, - max_params: float, - max_width_pruning: float = 0.5, - max_depth_pruning: float = 0.25, - top_k: int = 10, - ) -> dict: + def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: """Search for the best architecture based on the given parameters constraints. We perform a grid-search over the search space to find subnets (homogeneous) fitting the constraints. 
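To make the depth-pruning path in ``_prune`` concrete, a small worked example with made-up layer scores (real scores come from the forward-loop importance estimators):

    # layer_scores maps 1-indexed layer number -> importance score
    layer_scores = {1: 0.90, 2: 0.15, 3: 0.40, 4: 0.75}
    sorted_layers = [
        layer for layer, _ in sorted(layer_scores.items(), key=lambda x: x[1], reverse=True)
    ]
    # sorted_layers == [1, 4, 3, 2]  (most to least important)

    num_layers_active = 2                               # target depth
    layers_to_drop = sorted_layers[num_layers_active:]  # == [3, 2]
    # drop_mcore_language_model_layers(model, layers_to_drop=[3, 2]) keeps layers 1 and 4.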
@@ -335,18 +367,18 @@ def search_best_arch_by_params( and the best subnet is returned. Args: - max_params: Maximum number of parameters for the pruned model. - max_width_pruning: Maximum fraction per width hyperparameter to prune (default: 0.5). - Only top (1 - max_width_pruning) choices will be considered. - max_depth_pruning: Maximum fraction per depth hyperparameter to prune (default: 0.25). - Only top (1 - max_depth_pruning) choices will be considered. - top_k: Number of candidates to consider for score_func validation. + sorted_layers: Sorted list of layer numbers (1-indexed) for depth pruning. Returns: export_config: Dictionary mapping hyperparameter names to their pruned values. """ + assert sorted(sorted_layers) == list(range(1, self.model.config.num_layers + 1)) + max_params = float(self.constraints["params"]) # type: ignore[arg-type] + max_width_pruning = self.config["max_width_pruning"] + max_depth_pruning = self.config["max_depth_pruning"] + top_k = self.config["top_k"] print_rank_0( - f"\nSearching for the best pruned architecture under {max_params / 1e9:.2f}B params constraints" + f"\nSearching for the best pruned architecture under {num2hrb(max_params)} params constraints..." ) # 1. Find available search space choices (across all PP ranks) @@ -367,23 +399,26 @@ def search_best_arch_by_params( max_width_pruning, max_depth_pruning, ) - sample(self.model, sample_func=max) # reset for sanity + sample(self.model, sample_func=max) # reset to max subnet (for sanity) selected: list[tuple[dict, float]] = [] - for config in tqdm( + for ss_config in tqdm( search_space_configs, desc=f"Finding top {top_k} candidates fitting the constraints...", disable=not dist.is_master(), ): - # Convert search space config to fnmatch pattern and sample function - # Use partial to bind each value at creation time (avoid late-binding closure issue) - sample_func = { - f"*.{k}": partial(lambda val, choices: val, v) for k, v in config.items() - } - with sample_and_reset(self.model, sample_func=sample_func): # type: ignore[arg-type] - candidate_params = param_num_megatron(self.model) - if candidate_params <= max_params: - selected.append((config, candidate_params)) + self._prune(ss_config, prune_depth=False, update_config=False) + layer_ids = None + if ( + "num_layers" in ss_config + and ss_config["num_layers"] < self.model.config.num_layers + ): + layer_ids = sorted_layers[: ss_config["num_layers"]] + candidate_params = _param_num_dynamic(self.model, layer_numbers_to_count=layer_ids) + if candidate_params <= max_params: + selected.append((ss_config, candidate_params)) + sample(self.model, sample_func=max) # reset to max subnet assert len(selected) > 0, "No subnets found fitting the constraints!" + print_rank_0(f"Found {len(selected)} candidates fitting the constraints!") self.top_k_candidates_per_constraint[max_params] = sorted( selected, key=lambda x: x[1], reverse=True )[:top_k] @@ -453,6 +488,41 @@ def _generate_search_space_combos( return search_space_combos +def _param_num_dynamic( + model: _DynamicMCoreLanguageModel, *, layer_numbers_to_count: list[int] | None = None +) -> float: + """Get the number of parameters in the Dynamic Module (reduced across TP and PP ranks). + + Args: + model: GPTModel or MambaModel converted to a DynamicModule. + layer_numbers_to_count: If specified, only count the parameters of the given layer numbers (1-indexed). + Only needed when input is a DynamicModule to correctly count the parameters of the active layers. 
+ """ + + # NOTE: model.parameters() doesnt consider active_slice so we dont get sorted or trimmed parameters! + def get_param_count(mod, name) -> int: + """Use getattr to access parameters correctly.""" + module_path, _, param_name = name.rpartition(".") + submodule = mod.get_submodule(module_path) if module_path else mod + return getattr(submodule, param_name).numel() + + # Account for depth pruning with uneven PP and hybrid models! + params = sum( + get_param_count(model, name) + for name, _ in model.named_parameters() + if ("decoder.layers." not in name or layer_numbers_to_count is None) + ) + if layer_numbers_to_count is not None: + for layer in model.decoder.layers: + if layer.layer_number in layer_numbers_to_count: + params += sum(get_param_count(layer, name) for name, _ in layer.named_parameters()) + + reduced_params = torch.Tensor([params]).to(device=next(model.parameters()).device) + torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) + torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) + return reduced_params.item() + + MCoreMinitronConfig: type[ModeloptBaseConfig] = create_model( "MCoreMinitronConfig", **get_kwargs_for_create_model_with_rules( @@ -462,6 +532,7 @@ def _generate_search_space_combos( "hidden_size_divisor": 256, "ffn_hidden_size_divisor": 256, "num_moe_experts_divisor": 8, + "num_layers_divisor": 2, }, **( { @@ -470,6 +541,7 @@ def _generate_search_space_combos( "ffn_hidden_size_divisor": 256, "mamba_head_dim_divisor": 8, "num_moe_experts_divisor": 8, + "num_layers_divisor": 2, } } if HAS_MAMBA @@ -482,7 +554,11 @@ def _generate_search_space_combos( def get_mcore_minitron_config( - channel_divisor: int = 256, mamba_head_dim_divisor: int = 8, num_moe_experts_divisor: int = 8 + *, + channel_divisor: int = 256, + mamba_head_dim_divisor: int = 8, + num_moe_experts_divisor: int = 8, + num_layers_divisor: int = 2, ) -> ModeloptBaseConfig: """Get a MCoreMinitronConfig with the given channel divisor instead of default.""" config = MCoreMinitronConfig() @@ -497,6 +573,8 @@ def _set_divisors(c): c[k] = mamba_head_dim_divisor elif k == "num_moe_experts_divisor": c[k] = num_moe_experts_divisor + elif k == "num_layers_divisor": + c[k] = num_layers_divisor _set_divisors(config) return config diff --git a/modelopt/torch/prune/pruning.py b/modelopt/torch/prune/pruning.py index cdc4e7d8f..50a4850ea 100644 --- a/modelopt/torch/prune/pruning.py +++ b/modelopt/torch/prune/pruning.py @@ -78,7 +78,7 @@ def prune( constraints = {"params": "60%"} # Specify export_config with pruned hyperparameters - # This is supported and required if the model is converted via ``mcore_minitron`` mode. + # This is supported only if the model is converted via ``mcore_minitron`` mode. 
constraints = { "export_config": { "ffn_hidden_size": 128, diff --git a/modelopt/torch/utils/logging.py b/modelopt/torch/utils/logging.py index b8e7aecce..ada1b5361 100644 --- a/modelopt/torch/utils/logging.py +++ b/modelopt/torch/utils/logging.py @@ -46,7 +46,7 @@ def num2hrb(num: float, suffix="") -> str: """Convert big floating number to human readable string.""" step = 1000 # step between units is 1000 - units = ["", "K", "M", "G", "T", "P", "E"] + units = ["", "K", "M", "B", "T", "P", "E"] while abs(num) >= step and len(units) > 1: num /= step units.pop(0) diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index e18c85c3b..a6a5dfc3e 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -112,6 +112,13 @@ def param_num(network: nn.Module, trainable_only: bool = False, unit=1e6) -> flo Returns: The number of parameters in the model in the given unit. """ + from modelopt.torch.opt.dynamic import DynamicModule + + if isinstance(network, DynamicModule): + # NOTE: model.parameters() doesnt consider active_slice so we dont get sorted or trimmed parameters! + raise NotImplementedError( + "param_num doesn't support DynamicModule. Please use param_num_from_forward instead." + ) return ( sum( p.numel() if not trainable_only or p.requires_grad else 0 diff --git a/modelopt/torch/utils/plugins/megatron_model.py b/modelopt/torch/utils/plugins/megatron_model.py index 5ea2a7236..7e4c90c21 100644 --- a/modelopt/torch/utils/plugins/megatron_model.py +++ b/modelopt/torch/utils/plugins/megatron_model.py @@ -46,9 +46,14 @@ def param_num_megatron( Returns: The number of parameters in the model (reduced across TP and PP ranks). """ + from modelopt.torch.opt.dynamic import DynamicModule + if from_forward: assert args is not None, "args must be provided if from_forward is True" params = int(param_num_from_forward(model, args, unit=1.0)) + elif isinstance(model, DynamicModule): + # NOTE: model.parameters() doesnt consider active_slice so we dont get sorted or trimmed parameters! 
+ raise NotImplementedError("DynamicModule input is not supported without from_forward.") else: params = sum(p.numel() for p in model.parameters()) reduced_params = torch.Tensor([params]).to(device=next(model.parameters()).device) diff --git a/tests/_test_utils/torch/nas_prune/minitron_common.py b/tests/_test_utils/torch/nas_prune/minitron_common.py index 97b12a4ca..318ffdda8 100644 --- a/tests/_test_utils/torch/nas_prune/minitron_common.py +++ b/tests/_test_utils/torch/nas_prune/minitron_common.py @@ -26,6 +26,7 @@ def prune_minitron(model, export_config, config, channel_divisor=64): channel_divisor=channel_divisor, mamba_head_dim_divisor=4, num_moe_experts_divisor=1, + num_layers_divisor=1, ), ) ], diff --git a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py index 16b45cdb0..412b9ad37 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py @@ -83,7 +83,15 @@ def _test_gpt_search_space( normalization=normalization, ).cuda() - mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert( + model, + [ + ( + "mcore_minitron", + get_mcore_minitron_config(channel_divisor=channel_divisor, num_layers_divisor=1), + ) + ], + ) assert isinstance(model, _DynamicMCoreLanguageModel) for m in model.modules(): @@ -255,7 +263,14 @@ def _test_gpt_moe_search_space(rank, size): mtn.convert( model, - [("mcore_minitron", get_mcore_minitron_config(channel_divisor, num_moe_experts_divisor=1))], + [ + ( + "mcore_minitron", + get_mcore_minitron_config( + channel_divisor=channel_divisor, num_moe_experts_divisor=1, num_layers_divisor=1 + ), + ) + ], ) moe = model.decoder.layers[0].mlp diff --git a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index 6a1bc7a8a..868647eb6 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -77,7 +77,16 @@ def _test_mamba_search_space(rank, size): mtn.convert( model, - [("mcore_minitron", get_mcore_minitron_config(channel_divisor, mamba_head_dim_divisor))], + [ + ( + "mcore_minitron", + get_mcore_minitron_config( + channel_divisor=channel_divisor, + mamba_head_dim_divisor=mamba_head_dim_divisor, + num_layers_divisor=1, + ), + ) + ], ) assert isinstance(model, _DynamicMCoreLanguageModel) diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 2f1eae76b..81bcae04e 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -78,7 +78,7 @@ def _test_mcore_gpt_parameter_sorting(activation_func, rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor) + model, get_mcore_minitron_config(channel_divisor=channel_divisor) ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks From 04cf10c5a4d4504f3ed21f131449709e19aa6f80 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Fri, 2 Jan 2026 04:16:35 -0800 Subject: [PATCH 3/9] Add score calculation logic Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- 
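Note (not part of the commit message): the score path added below assumes a user-supplied score function in the search config, as already required for ``params``-based pruning in patch 1. A minimal sketch of such a callable; the evaluation body is a placeholder (e.g. an MMLU harness or accuracy on a small held-out set), and only the contract of returning a scalar where higher is better matters:

    def score_func(model) -> float:
        # Placeholder: run any lightweight evaluation you trust on the current
        # subnet and return a scalar score (higher = better).
        return run_small_eval(model)  # hypothetical helper

    config = {"score_func": score_func, "top_k": 10}  # merged into the config passed to mtp.prune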
modelopt/torch/nas/plugins/megatron.py | 3 +- .../torch/prune/plugins/mcore_minitron.py | 51 ++++++++++++++++--- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index 917b6e7c0..7447af5ef 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -52,7 +52,7 @@ from modelopt.torch.opt.searcher import ConstraintsDict from modelopt.torch.trace import Symbol from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import make_divisible, print_rank_0, random +from modelopt.torch.utils import make_divisible, random from modelopt.torch.utils.plugins import param_num_megatron from ..algorithms import ( @@ -627,7 +627,6 @@ def modify( def _export_reinit_token_dispatcher(self) -> None: """Reinitialize the token dispatcher after pruning.""" - print_rank_0("Reinitializing token dispatcher after pruning") if hasattr(moe_utils, "get_default_model_comm_pgs"): model_comm_pgs = moe_utils.get_default_model_comm_pgs() else: diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index e1c6e3b81..5ed8fad0d 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -25,6 +25,7 @@ """ from collections.abc import Callable +from dataclasses import dataclass from functools import partial from itertools import product from typing import Any @@ -54,6 +55,7 @@ _DynamicMambaMixer, _DynamicMCoreLanguageModel, _DynamicMLP, + _DynamicMoELayer, _DynamicSelfAttention, _DynamicSequentialMLP, _DynamicTransformerLayer, @@ -169,6 +171,13 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i model.config.num_layers = new_num_layers +@dataclass +class CandidateSubnet: + ss_config: dict + params: float + score: float | None + + class MCoreMinitronSearcher(BaseSearcher): """Searcher for Minitron pruning algorithm. @@ -182,7 +191,8 @@ class MCoreMinitronSearcher(BaseSearcher): activations_per_rank: list[dict[str, torch.Tensor]] layer_scores: dict[int, torch.Tensor] - top_k_candidates_per_constraint: dict[float, list[tuple[dict, float]]] + # Dict from params constraint to list of tuples (ss_config, params, score) + top_k_candidates_per_constraint: dict[float, list[CandidateSubnet]] @property def default_search_config(self) -> SearchConfig: @@ -359,6 +369,12 @@ def _prune( for hp_name, hp_value in export_config.items(): setattr(self.model.config, hp_name, hp_value) + # Reinitialize the MoE token dispatcher after pruning + for m in self.model.modules(): + if isinstance(m, _DynamicMoELayer): + m._export_reinit_token_dispatcher() + break + def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: """Search for the best architecture based on the given parameters constraints. 
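To make the new bookkeeping concrete, a toy, self-contained illustration (made-up numbers) of how `CandidateSubnet` entries flow through the search: filter by the params budget, keep the largest `top_k`, score them, and pick the best:

```python
from dataclasses import dataclass

@dataclass
class CandidateSubnet:
    ss_config: dict
    params: float
    score: float | None

max_params, top_k = 6.0e9, 2
evaluated = [
    CandidateSubnet({"hidden_size": 4096, "num_layers": 32}, 7.1e9, None),  # over budget
    CandidateSubnet({"hidden_size": 3072, "num_layers": 32}, 5.9e9, None),
    CandidateSubnet({"hidden_size": 3072, "num_layers": 28}, 5.4e9, None),
    CandidateSubnet({"hidden_size": 2560, "num_layers": 32}, 5.2e9, None),
]
selected = [c for c in evaluated if c.params <= max_params]
top_candidates = sorted(selected, key=lambda c: c.params, reverse=True)[:top_k]
for candidate, score in zip(top_candidates, [0.62, 0.60]):  # stand-in for the real score_func
    candidate.score = score
best = max(top_candidates, key=lambda c: c.score)
print(best.ss_config, best.params, best.score)
```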
@@ -400,7 +416,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: max_depth_pruning, ) sample(self.model, sample_func=max) # reset to max subnet (for sanity) - selected: list[tuple[dict, float]] = [] + selected = [] for ss_config in tqdm( search_space_configs, desc=f"Finding top {top_k} candidates fitting the constraints...", @@ -415,23 +431,44 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: layer_ids = sorted_layers[: ss_config["num_layers"]] candidate_params = _param_num_dynamic(self.model, layer_numbers_to_count=layer_ids) if candidate_params <= max_params: - selected.append((ss_config, candidate_params)) + selected.append(CandidateSubnet(ss_config, candidate_params, None)) sample(self.model, sample_func=max) # reset to max subnet assert len(selected) > 0, "No subnets found fitting the constraints!" print_rank_0(f"Found {len(selected)} candidates fitting the constraints!") self.top_k_candidates_per_constraint[max_params] = sorted( - selected, key=lambda x: x[1], reverse=True + selected, key=lambda x: x.params, reverse=True )[:top_k] self.save_search_checkpoint(verbose=True) else: print_rank_0(f"Using top {top_k} candidates from checkpoint") top_k_candidates = self.top_k_candidates_per_constraint[max_params] + print_rank_0(f"\n====================\nTop {top_k} candidates:") + for candidate in top_k_candidates: + print_rank_0(f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params") + print_rank_0("====================\n") + # 3. Validate top-k candidates using the score_func and return the best subnet - # TODO: update this - best = top_k_candidates[0][0] + for candidate in tqdm( + top_k_candidates, + desc=f"Validating top {top_k} candidates on given score_func...", + disable=not dist.is_master(), + ): + if candidate.score is None: # not restored from checkpoint + self._prune(candidate.ss_config, prune_depth=False, update_config=False) + candidate.score = self.eval_score(silent=True) + sample(self.model, sample_func=max) # reset to max subnet + self.save_search_checkpoint(verbose=False) + print_rank_0( + f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score" + ) - return best + dist.barrier() + best = max(top_k_candidates, key=lambda x: x.score) # type: ignore[arg-type, return-value] + print_rank_0( + f"\n[BEST SUBNET] {best.ss_config} -> {num2hrb(best.params)} params, {best.score:.4f} score\n" + ) + return best.ss_config @staticmethod def _generate_search_space_combos( From 514e8d99dc366d9741baea50938dda4ab32b120f Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Sun, 4 Jan 2026 11:11:56 -0800 Subject: [PATCH 4/9] Fix param count for shared embeddings and output layer Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/nas/plugins/megatron.py | 101 +----------------- .../torch/prune/plugins/mcore_minitron.py | 38 ++++++- modelopt/torch/utils/plugins/__init__.py | 3 - .../torch/utils/plugins/megatron_generate.py | 2 + modelopt/torch/utils/plugins/megatron_mmlu.py | 2 + .../torch/utils/plugins/megatron_model.py | 62 ----------- .../utils/plugins/megatron_preprocess_data.py | 2 + 7 files changed, 42 insertions(+), 168 deletions(-) delete mode 100644 modelopt/torch/utils/plugins/megatron_model.py diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index 7447af5ef..83fa3c1d8 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ 
b/modelopt/torch/nas/plugins/megatron.py @@ -18,24 +18,18 @@ import types from abc import ABC from collections.abc import Callable, Sequence -from typing import Any import torch import torch.nn as nn from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.gpt import GPTModel -from megatron.core.parallel_state import ( - get_data_parallel_group, - is_pipeline_first_stage, - is_pipeline_last_stage, -) +from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage from megatron.core.tensor_parallel.layers import ( ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, ) -from megatron.core.transformer import MegatronModule from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.mlp import MLP @@ -49,24 +43,14 @@ from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.opt.dynamic import DynamicModule from modelopt.torch.opt.hparam import HPType -from modelopt.torch.opt.searcher import ConstraintsDict from modelopt.torch.trace import Symbol from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import make_divisible, random -from modelopt.torch.utils.plugins import param_num_megatron - -from ..algorithms import ( - MODULE_TYPE_TO_CONSTRAINTS_FUNC, - ConstraintEvalFunc, - ConstraintInterpolator, - ConstraintsFunc, - ConstraintsRes, -) +from modelopt.torch.utils import make_divisible + from ..hparams.concat import build_concat_hp from ..modules import _DynamicLayerNorm from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices from ..registry import DMRegistry -from ..search_space import SampleFunc from ..traced_hp import TracedHp SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"} @@ -1079,82 +1063,3 @@ def export(self) -> torch.nn.Module: ).export() self.output_layer.export() return super().export() - - -class MegatronConstraintsFunc(ConstraintsFunc): - """A Functor class to check if sub-net satisfied all provided constraints. - - We intentionally expose some attributes like `limits` s.t. we can modify it manually. 
- """ - - _sample_points_dict: dict[tuple[str, ...], dict[str, SampleFunc]] = { - ("params",): {"min": min, "centroid": random.centroid, "max": max}, - } - - def __init__( - self, - model: MegatronModule, - constraints: ConstraintsDict, - dummy_input: Any | tuple[Any, ...], - deployment: dict | None = None, - fast_eval: bool = True, - ): - """Initialize with additional data parallel group info from megatron.""" - for key in constraints: - if key != "params": - raise ValueError("Only params constraints is supported for MegatronModule!") - - self.model = model - self.dummy_input = dummy_input - self.deployment = deployment - self._fast_eval = fast_eval - - # Getting data parallel group for - self.dp_group = get_data_parallel_group() - - # initialize latency interpolator - keys_for_interpolation = ("params",) - if ConstraintsFunc.is_configurable(self.model, "depth"): - keys_for_interpolation += ("flops_min_depth",) - self._latency_interpolator = ConstraintInterpolator( - self.model, - points_funcs={k: self.constraint_eval_funcs[k] for k in keys_for_interpolation}, - value_func=self._get_true_latency, - ) - # set fast/regular mode for latency interpolator - self._latency_interpolator.collect_mode = not self.fast_eval - - # set limit at the end with setter to use sanity checks on constraints - self._limits = {} - self.limits = constraints - - @property - def constraint_eval_funcs(self) -> dict[str, ConstraintEvalFunc]: - """Get constraint eval fns.""" - return { - "params": self._get_params, - } - - def _get_params(self, _: ConstraintsRes | None = None) -> float: - """Get number of model parameters from forward pass.""" - return param_num_megatron(self.model, from_forward=True, args=self.dummy_input) - - def _get_flops(self, _: ConstraintsRes | None = None) -> float: - """Get inference FLOPs.""" - raise NotImplementedError - - def _get_flops_min_depth(self, _: ConstraintsRes | None = None) -> float: - """Get inference FLOPs with depth set to minimum.""" - raise NotImplementedError - - def _get_true_latency(self, _: ConstraintsRes | None = None) -> float: - """Get true inference latency.""" - raise NotImplementedError - - def _get_latency(self, precomputed: ConstraintsRes | None = None) -> float: - """Get inference latency from interpolator.""" - raise NotImplementedError - - -# Clear the mapping and reinsert. 
-MODULE_TYPE_TO_CONSTRAINTS_FUNC[MegatronModule] = MegatronConstraintsFunc diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 5ed8fad0d..51684fd6c 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -34,6 +34,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mamba.mamba_model import MambaModel from megatron.core.parallel_state import ( get_pipeline_model_parallel_group, get_pipeline_model_parallel_rank, @@ -103,6 +105,7 @@ "MCoreMinitronSearcher", "drop_mcore_language_model_layers", "get_mcore_minitron_config", + "get_mcore_param_count", ] @@ -300,6 +303,7 @@ def run_search(self) -> None: print_rank_0("Skipping sorting parameters...") else: sort_parameters(self.model, self.hps_to_sort, verbose=True) + registry.cleanup() if self.layer_scores: # sort layers by scores and drop the lowest ones @@ -325,8 +329,6 @@ def run_search(self) -> None: export_config, prune_depth=True, update_config=True, sorted_layers=sorted_layers ) - registry.cleanup() - def _prune( self, export_config: dict, @@ -419,7 +421,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: selected = [] for ss_config in tqdm( search_space_configs, - desc=f"Finding top {top_k} candidates fitting the constraints...", + desc=f"Finding top {top_k} (`config['top_k']`) candidates fitting the constraints...", disable=not dist.is_master(), ): self._prune(ss_config, prune_depth=False, update_config=False) @@ -451,12 +453,12 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: # 3. Validate top-k candidates using the score_func and return the best subnet for candidate in tqdm( top_k_candidates, - desc=f"Validating top {top_k} candidates on given score_func...", + desc=f"Validating top {top_k} candidates on given score_func (this will take some time)...", disable=not dist.is_master(), ): if candidate.score is None: # not restored from checkpoint self._prune(candidate.ss_config, prune_depth=False, update_config=False) - candidate.score = self.eval_score(silent=True) + candidate.score = self.eval_score(silent=False) sample(self.model, sample_func=max) # reset to max subnet self.save_search_checkpoint(verbose=False) print_rank_0( @@ -525,6 +527,30 @@ def _generate_search_space_combos( return search_space_combos +def get_mcore_param_count(model: GPTModel | MambaModel) -> float: + """Get the number of parameters in the MCore GPTModel or MambaModel (reduced across TP and PP ranks).""" + assert isinstance(model, (GPTModel, MambaModel)), "Model must be a GPTModel or MambaModel" + if isinstance(model, DynamicModule): + return _param_num_dynamic(model) + else: + return _param_num(model) + + +def _param_num(model: GPTModel | MambaModel) -> float: + """Get the number of parameters in the model (reduced across TP and PP ranks).""" + # Dont double count output_layer parameters if model.share_embeddings_and_output_weights is True + params = sum( + p.numel() + for name, p in model.named_parameters() + if not model.share_embeddings_and_output_weights or "output_layer.weight" not in name + ) + + reduced_params = torch.Tensor([params]).to(device=next(model.parameters()).device) + torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) + torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) + return reduced_params.item() + + 
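A minimal usage sketch of the new `get_mcore_param_count` helper; it assumes Megatron parallel state is already initialized and that `model` is an existing `GPTModel`/`MambaModel`:

```python
from modelopt.torch.prune.plugins.mcore_minitron import get_mcore_param_count

# `model` is assumed to exist. The count is all-reduced across TP and PP ranks, and the
# output_layer weight is skipped when embedding and output weights are tied.
total_params = get_mcore_param_count(model)
print(f"Total params: {total_params / 1e9:.2f}B")
```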
def _param_num_dynamic( model: _DynamicMCoreLanguageModel, *, layer_numbers_to_count: list[int] | None = None ) -> float: @@ -544,10 +570,12 @@ def get_param_count(mod, name) -> int: return getattr(submodule, param_name).numel() # Account for depth pruning with uneven PP and hybrid models! + # Dont double count output_layer parameters if model.share_embeddings_and_output_weights is True params = sum( get_param_count(model, name) for name, _ in model.named_parameters() if ("decoder.layers." not in name or layer_numbers_to_count is None) + and not (model.share_embeddings_and_output_weights and "output_layer.weight" in name) ) if layer_numbers_to_count is not None: for layer in model.decoder.layers: diff --git a/modelopt/torch/utils/plugins/__init__.py b/modelopt/torch/utils/plugins/__init__.py index ac1053aa2..517c59914 100644 --- a/modelopt/torch/utils/plugins/__init__.py +++ b/modelopt/torch/utils/plugins/__init__.py @@ -23,8 +23,5 @@ with import_plugin("megatron_mmlu"): from .megatron_mmlu import * -with import_plugin("megatron_model"): - from .megatron_model import * - with import_plugin("megatron_preprocess_data"): from .megatron_preprocess_data import * diff --git a/modelopt/torch/utils/plugins/megatron_generate.py b/modelopt/torch/utils/plugins/megatron_generate.py index d542d935a..d83006711 100644 --- a/modelopt/torch/utils/plugins/megatron_generate.py +++ b/modelopt/torch/utils/plugins/megatron_generate.py @@ -24,6 +24,8 @@ from megatron.core.transformer import MegatronModule from tqdm import tqdm +__all__ = ["megatron_generate", "megatron_prefill"] + def get_current_memory_info(): """Get current memory usage.""" diff --git a/modelopt/torch/utils/plugins/megatron_mmlu.py b/modelopt/torch/utils/plugins/megatron_mmlu.py index 3b997268b..b03d338c0 100644 --- a/modelopt/torch/utils/plugins/megatron_mmlu.py +++ b/modelopt/torch/utils/plugins/megatron_mmlu.py @@ -47,6 +47,8 @@ from .megatron_generate import megatron_generate +__all__ = ["megatron_mmlu"] + def _get_all_subjects(): """All subjects (anatomy, ...) can be acquired from querying all subsets and splits.""" diff --git a/modelopt/torch/utils/plugins/megatron_model.py b/modelopt/torch/utils/plugins/megatron_model.py deleted file mode 100644 index 7e4c90c21..000000000 --- a/modelopt/torch/utils/plugins/megatron_model.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""General utilities for Megatron models.""" - -from typing import Any - -import torch -from megatron.core.parallel_state import ( - get_pipeline_model_parallel_group, - get_tensor_model_parallel_group, -) -from megatron.core.transformer.module import MegatronModule - -from ..network import param_num_from_forward - -__all__ = ["param_num_megatron"] - - -def param_num_megatron( - model: MegatronModule, *, from_forward: bool = False, args: Any = None -) -> float: - """Get the number of parameters in the model (reduced across TP and PP ranks). - - Args: - model: The Megatron model. - from_forward: To get the number of params from a forward pass instead of directly counting the params. - This can helpful for MoE or dynamic modules, where the state dict might contain extra parameters that - is not actively used in the model, e.g., because of a DynamicModule that is deactivated for the - forward pass. We circumvent this issue by just counting parameters of modules that appear in a - forward pass. - args: The arguments to pass to the forward pass. Only used if from_forward is True. - - Returns: - The number of parameters in the model (reduced across TP and PP ranks). - """ - from modelopt.torch.opt.dynamic import DynamicModule - - if from_forward: - assert args is not None, "args must be provided if from_forward is True" - params = int(param_num_from_forward(model, args, unit=1.0)) - elif isinstance(model, DynamicModule): - # NOTE: model.parameters() doesnt consider active_slice so we dont get sorted or trimmed parameters! - raise NotImplementedError("DynamicModule input is not supported without from_forward.") - else: - params = sum(p.numel() for p in model.parameters()) - reduced_params = torch.Tensor([params]).to(device=next(model.parameters()).device) - torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) - torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) - return reduced_params.item() diff --git a/modelopt/torch/utils/plugins/megatron_preprocess_data.py b/modelopt/torch/utils/plugins/megatron_preprocess_data.py index ac05e44f1..94bf268bc 100644 --- a/modelopt/torch/utils/plugins/megatron_preprocess_data.py +++ b/modelopt/torch/utils/plugins/megatron_preprocess_data.py @@ -42,6 +42,8 @@ from megatron.core.datasets import indexed_dataset from transformers import AutoTokenizer +__all__ = ["megatron_preprocess_data"] + class _Encoder: tokenizer: AutoTokenizer = None From 49981e54bdcb4a1a84bc64c5e68b0d37b2a6836c Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Sun, 4 Jan 2026 11:17:56 -0800 Subject: [PATCH 5/9] Fix PP>1 score_func validation Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../torch/prune/plugins/mcore_minitron.py | 27 ++++++++----------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 51684fd6c..a1ec333d8 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -325,15 +325,12 @@ def run_search(self) -> None: export_config = self.constraints["export_config"] # Prune homogeneously - self._prune( - export_config, prune_depth=True, update_config=True, sorted_layers=sorted_layers - ) + self._prune(export_config, prune_depth=True, sorted_layers=sorted_layers) def _prune( self, export_config: dict, prune_depth: bool = True, - update_config: bool = 
True, *, sorted_layers: list[int] | None = None, ) -> None: @@ -342,7 +339,6 @@ def _prune( Args: export_config: Dictionary mapping hyperparameter names to their pruned values. prune_depth: Whether to drop layers based on sorted_layers (default: True). - update_config: Whether to update the model config with the pruned architecture (default: True). sorted_layers: Sorted list of layers (1-indexed) for depth pruning. """ # Prune homogeneously @@ -361,15 +357,14 @@ def _prune( # Update model config with pruned architecture # kv_channels can be None so we need to save original from original hidden_size and num_attention_heads - if update_config: - orig_kv_channels = self.model.config.kv_channels - if orig_kv_channels is None: - orig_kv_channels = ( - self.model.config.hidden_size // self.model.config.num_attention_heads - ) - self.model.config.kv_channels = orig_kv_channels - for hp_name, hp_value in export_config.items(): - setattr(self.model.config, hp_name, hp_value) + orig_kv_channels = self.model.config.kv_channels + if orig_kv_channels is None: + orig_kv_channels = ( + self.model.config.hidden_size // self.model.config.num_attention_heads + ) + self.model.config.kv_channels = orig_kv_channels + for hp_name, hp_value in export_config.items(): + setattr(self.model.config, hp_name, hp_value) # Reinitialize the MoE token dispatcher after pruning for m in self.model.modules(): @@ -424,7 +419,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: desc=f"Finding top {top_k} (`config['top_k']`) candidates fitting the constraints...", disable=not dist.is_master(), ): - self._prune(ss_config, prune_depth=False, update_config=False) + self._prune(ss_config, prune_depth=False) layer_ids = None if ( "num_layers" in ss_config @@ -457,7 +452,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: disable=not dist.is_master(), ): if candidate.score is None: # not restored from checkpoint - self._prune(candidate.ss_config, prune_depth=False, update_config=False) + self._prune(candidate.ss_config, prune_depth=False) candidate.score = self.eval_score(silent=False) sample(self.model, sample_func=max) # reset to max subnet self.save_search_checkpoint(verbose=False) From ca95a63d64da54987323f978ed1eb9770df6c63e Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Tue, 6 Jan 2026 08:40:45 -0800 Subject: [PATCH 6/9] Fix dropping and reverting layers for mmlu eval Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../torch/prune/plugins/mcore_minitron.py | 50 +++++++++---------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index a1ec333d8..1ecef5c4b 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -127,7 +127,7 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i assert isinstance(model, supported_model_types), ( f"Model should have one of {supported_model_types} submodule, got {model}" ) - print_rank_0(f"Dropping layers {layers_to_drop} from {n} ({type(model)}).") + print_rank_0(f"Dropping decoder layers {layers_to_drop} from model.") # get the number of layers remaining in each pp rank layers_remaining_per_pp = torch.zeros( @@ -151,25 +151,14 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i new_num_layers = sum(layers_remaining_per_pp) # 
reindex kept layers, exclude sharded state dict for dropped layers - layer_offset = sum(layers_remaining_per_pp[: get_pipeline_model_parallel_rank()]) - layer_number = layer_offset + 1 - dropped_layers = [] + layer_number = sum(layers_remaining_per_pp[: get_pipeline_model_parallel_rank()]) + 1 + kept_layers = [] for layer in model.decoder.layers: - if layer.layer_number in layers_to_drop: - layer.layer_number = -1 # should not be used - # layer.sharded_state_dict = lambda prefix, sharded_offsets, metadata: {} - dropped_layers.append(layer) - else: + if layer.layer_number not in layers_to_drop: layer.layer_number = layer_number - layer.get_transformer_layer_offset = lambda: layer_offset layer_number += 1 - - # remove dropped layers from the modulelist - model.decoder.layers = nn.ModuleList( - [layer for layer in model.decoder.layers if layer.layer_number != -1] - ) - for layer in dropped_layers: - del layer + kept_layers.append(layer) + model.decoder.layers = nn.ModuleList(kept_layers) model.config.num_layers = new_num_layers @@ -187,7 +176,7 @@ class MCoreMinitronSearcher(BaseSearcher): Available additional config options: - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.5). Only top (1 - max_width_pruning) choices will be considered. - - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.25). + - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.2). Only top (1 - max_depth_pruning) choices will be considered. - `top_k`: Number of candidates to consider for score_func validation (default: 10). """ @@ -407,6 +396,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: # 2. Perform grid-search over the search space to find subnets fitting the constraints if max_params not in self.top_k_candidates_per_constraint: + max_num_layers = self.unwrapped_model.get_hparam("num_layers").max search_space_configs = MCoreMinitronSearcher._generate_search_space_combos( hp_choices, # type: ignore[arg-type] max_width_pruning, @@ -421,10 +411,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: ): self._prune(ss_config, prune_depth=False) layer_ids = None - if ( - "num_layers" in ss_config - and ss_config["num_layers"] < self.model.config.num_layers - ): + if "num_layers" in ss_config and ss_config["num_layers"] < max_num_layers: layer_ids = sorted_layers[: ss_config["num_layers"]] candidate_params = _param_num_dynamic(self.model, layer_numbers_to_count=layer_ids) if candidate_params <= max_params: @@ -437,7 +424,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: )[:top_k] self.save_search_checkpoint(verbose=True) else: - print_rank_0(f"Using top {top_k} candidates from checkpoint") + print_rank_0(f"\nUsing top {top_k} candidates from checkpoint") top_k_candidates = self.top_k_candidates_per_constraint[max_params] print_rank_0(f"\n====================\nTop {top_k} candidates:") @@ -452,10 +439,19 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: disable=not dist.is_master(), ): if candidate.score is None: # not restored from checkpoint - self._prune(candidate.ss_config, prune_depth=False) + all_layers = self.unwrapped_model.decoder.layers + start_layer_number = all_layers[0].layer_number + + self._prune(candidate.ss_config, prune_depth=True, sorted_layers=sorted_layers) candidate.score = self.eval_score(silent=False) - sample(self.model, sample_func=max) # reset to max subnet 
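A self-contained toy of the simplified re-indexing in `drop_mcore_language_model_layers` above (the candidate-validation code a few lines below reverses the same renumbering after scoring); `_ToyLayer` is a stand-in for a real decoder layer:

```python
import torch.nn as nn

class _ToyLayer(nn.Module):  # stand-in for a Megatron transformer/Mamba layer
    def __init__(self, layer_number: int):
        super().__init__()
        self.layer_number = layer_number

layers = nn.ModuleList(_ToyLayer(i) for i in range(1, 9))  # layers 1..8 on this PP rank
layers_to_drop = [3, 6]
layer_number = 1  # would start at 1 + layers remaining on earlier PP ranks
kept_layers = []
for layer in layers:
    if layer.layer_number not in layers_to_drop:
        layer.layer_number = layer_number
        layer_number += 1
        kept_layers.append(layer)
layers = nn.ModuleList(kept_layers)
assert [layer.layer_number for layer in layers] == [1, 2, 3, 4, 5, 6]
```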
self.save_search_checkpoint(verbose=False) + + # reset to max subnet and revert dropped layers + sample(self.model, sample_func=max) + for layer in all_layers: + layer.layer_number = start_layer_number + start_layer_number += 1 + self.unwrapped_model.decoder.layers = all_layers print_rank_0( f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score" ) @@ -471,7 +467,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: def _generate_search_space_combos( search_space: dict[str, list], max_width_pruning: float = 0.5, - max_depth_pruning: float = 0.25, + max_depth_pruning: float = 0.2, ) -> list[dict[str, Any]]: """Generate all possible combinations of hyperparameters from the search space. @@ -480,7 +476,7 @@ def _generate_search_space_combos( Example: {"hidden_size": [1024, 2048, 3072, 4096], "num_layers": [1, 2, ..., 31, 32]} max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.5). Only top (1 - max_width_pruning) choices will be considered. - max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.25). + max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.2). Only top (1 - max_depth_pruning) choices will be considered. Returns: From f18f8e20a04954bc717a59d4bed135a9cbdc987d Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Tue, 6 Jan 2026 23:58:24 -0800 Subject: [PATCH 7/9] Allow skipping some hparams in NAS and further restrict search space Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../torch/prune/plugins/mcore_minitron.py | 56 +++++++++++++------ .../torch/nas_prune/minitron_common.py | 3 +- .../test_megatron_gpt_dynamic_modules.py | 11 +++- .../test_megatron_mamba_dynamic_modules.py | 3 +- .../test_mcore_gpt_minitron_pruning.py | 15 ++++- .../test_mcore_mamba_minitron_pruning.py | 7 ++- 6 files changed, 69 insertions(+), 26 deletions(-) diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 1ecef5c4b..fd78f9159 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -173,11 +173,12 @@ class CandidateSubnet: class MCoreMinitronSearcher(BaseSearcher): """Searcher for Minitron pruning algorithm. - Available additional config options: - - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.5). + Available additional config options (used when `params` constraint is provided): + - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.40). Only top (1 - max_width_pruning) choices will be considered. - - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.2). + - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.20). Only top (1 - max_depth_pruning) choices will be considered. + - `hparams_to_skip`: List of hparams to skip during the search (default: None). - `top_k`: Number of candidates to consider for score_func validation (default: 10). 
""" @@ -195,8 +196,9 @@ def default_search_config(self) -> SearchConfig: "skip_sorting": False, "scores_path": None, # Additional search config for parameter-based pruning - "max_width_pruning": 0.5, - "max_depth_pruning": 0.25, + "max_width_pruning": 0.40, + "max_depth_pruning": 0.20, + "hparams_to_skip": None, "top_k": 10, } @@ -378,6 +380,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: max_params = float(self.constraints["params"]) # type: ignore[arg-type] max_width_pruning = self.config["max_width_pruning"] max_depth_pruning = self.config["max_depth_pruning"] + hparams_to_skip = self.config["hparams_to_skip"] top_k = self.config["top_k"] print_rank_0( f"\nSearching for the best pruned architecture under {num2hrb(max_params)} params constraints..." @@ -401,6 +404,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: hp_choices, # type: ignore[arg-type] max_width_pruning, max_depth_pruning, + hparams_to_skip, ) sample(self.model, sample_func=max) # reset to max subnet (for sanity) selected = [] @@ -466,18 +470,20 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: @staticmethod def _generate_search_space_combos( search_space: dict[str, list], - max_width_pruning: float = 0.5, - max_depth_pruning: float = 0.2, + max_width_pruning: float = 0.40, + max_depth_pruning: float = 0.20, + hparams_to_skip: list[str] | None = None, ) -> list[dict[str, Any]]: """Generate all possible combinations of hyperparameters from the search space. Args: search_space: Dictionary mapping hyperparameter names to their possible sorted choices. Example: {"hidden_size": [1024, 2048, 3072, 4096], "num_layers": [1, 2, ..., 31, 32]} - max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.5). + max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.40). Only top (1 - max_width_pruning) choices will be considered. - max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.2). + max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.20). Only top (1 - max_depth_pruning) choices will be considered. + hparams_to_skip: List of hparams to skip during the search (default: None). Returns: List of configuration dictionaries, where each dictionary maps hyperparameter @@ -494,11 +500,22 @@ def _generate_search_space_combos( f"{max_depth_pruning * 100:.0f}% for depth pruning hparams" ) + if hparams_to_skip: + print_rank_0(f"Skipping {hparams_to_skip=} during search space generation...") + for hparam in hparams_to_skip: + if hparam in search_space: + search_space.pop(hparam) + else: + warn(f"Hparam {hparam} not found in search space! 
Skipping...") + filtered_ss = { - k: sorted(v)[int((1 - max_depth_pruning) * len(v)) :] - if k == "num_layers" - else sorted(v)[int((1 - max_width_pruning) * len(v)) :] + k: ( + sorted(v)[int((1 - max_depth_pruning) * len(v)) :] + if k == "num_layers" + else sorted(v)[int((1 - max_width_pruning) * len(v)) :] + ) for k, v in search_space.items() + if len(v) > 1 } ss_size = 1 @@ -586,7 +603,7 @@ def get_param_count(mod, name) -> int: default_rules={ "megatron.core.models.gpt.GPTModel": { "hidden_size_divisor": 256, - "ffn_hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 512, "num_moe_experts_divisor": 8, "num_layers_divisor": 2, }, @@ -594,7 +611,7 @@ def get_param_count(mod, name) -> int: { "megatron.core.models.mamba.MambaModel": { "hidden_size_divisor": 256, - "ffn_hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 512, "mamba_head_dim_divisor": 8, "num_moe_experts_divisor": 8, "num_layers_divisor": 2, @@ -611,20 +628,23 @@ def get_param_count(mod, name) -> int: def get_mcore_minitron_config( *, - channel_divisor: int = 256, + hidden_size_divisor: int = 256, + ffn_hidden_size_divisor: int = 512, mamba_head_dim_divisor: int = 8, num_moe_experts_divisor: int = 8, num_layers_divisor: int = 2, ) -> ModeloptBaseConfig: - """Get a MCoreMinitronConfig with the given channel divisor instead of default.""" + """Get a MCoreMinitronConfig with the given divisors instead of default.""" config = MCoreMinitronConfig() def _set_divisors(c): for k, v in c.items(): if isinstance(v, dict): _set_divisors(v) - elif k in ["hidden_size_divisor", "ffn_hidden_size_divisor"]: - c[k] = channel_divisor + elif k == "hidden_size_divisor": + c[k] = hidden_size_divisor + elif k == "ffn_hidden_size_divisor": + c[k] = ffn_hidden_size_divisor elif k == "mamba_head_dim_divisor": c[k] = mamba_head_dim_divisor elif k == "num_moe_experts_divisor": diff --git a/tests/_test_utils/torch/nas_prune/minitron_common.py b/tests/_test_utils/torch/nas_prune/minitron_common.py index 318ffdda8..a4da1771d 100644 --- a/tests/_test_utils/torch/nas_prune/minitron_common.py +++ b/tests/_test_utils/torch/nas_prune/minitron_common.py @@ -23,7 +23,8 @@ def prune_minitron(model, export_config, config, channel_divisor=64): ( "mcore_minitron", mtp.mcore_minitron.get_mcore_minitron_config( - channel_divisor=channel_divisor, + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, mamba_head_dim_divisor=4, num_moe_experts_divisor=1, num_layers_divisor=1, diff --git a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py index 412b9ad37..6771bb9a0 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py @@ -88,7 +88,11 @@ def _test_gpt_search_space( [ ( "mcore_minitron", - get_mcore_minitron_config(channel_divisor=channel_divisor, num_layers_divisor=1), + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + num_layers_divisor=1, + ), ) ], ) @@ -267,7 +271,10 @@ def _test_gpt_moe_search_space(rank, size): ( "mcore_minitron", get_mcore_minitron_config( - channel_divisor=channel_divisor, num_moe_experts_divisor=1, num_layers_divisor=1 + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + num_moe_experts_divisor=1, + num_layers_divisor=1, ), ) ], diff --git a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py 
b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index 868647eb6..aa499abf0 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -81,7 +81,8 @@ def _test_mamba_search_space(rank, size): ( "mcore_minitron", get_mcore_minitron_config( - channel_divisor=channel_divisor, + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, mamba_head_dim_divisor=mamba_head_dim_divisor, num_layers_divisor=1, ), diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 81bcae04e..5af0f0c2c 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -78,7 +78,10 @@ def _test_mcore_gpt_parameter_sorting(activation_func, rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor=channel_divisor) + model, + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, ffn_hidden_size_divisor=channel_divisor + ), ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks @@ -355,7 +358,12 @@ def _test_mcore_gpt_moe_parameter_sorting(rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor=channel_divisor, num_moe_experts_divisor=1) + model, + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + num_moe_experts_divisor=1, + ), ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks @@ -500,11 +508,12 @@ def test_mcore_gpt_pruning_moe(tmp_path): def test_generate_search_space_combos(): ss = { "hidden_size": [32, 64, 96, 128, 160], + "ffn_hidden_size": [128, 256, 384, 512, 640], "num_attention_heads": [8, 16, 24, 32], "num_layers": [1, 2, 3, 4, 5, 6, 7, 8], } ss_combos = MCoreMinitronSearcher._generate_search_space_combos( - ss, max_width_pruning=0.5, max_depth_pruning=0.25 + ss, max_width_pruning=0.5, max_depth_pruning=0.25, hparams_to_skip=["ffn_hidden_size"] ) assert len(ss_combos) == 3 * 2 * 2 assert ss_combos == [ diff --git a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index a7f036bbb..20a611569 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -78,7 +78,12 @@ def _test_mcore_mamba_parameter_sorting(rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor=channel_divisor, mamba_head_dim_divisor=4) + model, + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + mamba_head_dim_divisor=4, + ), ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks From f2ee949f2ad2f0f502685e101bb7aa6c02dfb431 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 8 Jan 2026 04:40:20 -0800 Subject: [PATCH 8/9] Add unit test for Hybrid NAS-based auto pruning Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/nas/search_space.py | 4 +- .../torch/prune/plugins/mcore_minitron.py | 5 +- 
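To make the filtering in `_generate_search_space_combos` (previous patch) concrete, a simplified standalone re-creation of its candidate enumeration; the inputs reuse the docstring's example, and this is a sketch rather than the exact implementation:

```python
from itertools import product

search_space = {
    "hidden_size": [1024, 2048, 3072, 4096],
    "num_layers": list(range(1, 33)),
}
max_width_pruning, max_depth_pruning = 0.40, 0.20
hparams_to_skip: list[str] = []  # names listed here would be removed from the search

filtered = {
    name: sorted(choices)[
        int((1 - (max_depth_pruning if name == "num_layers" else max_width_pruning)) * len(choices)) :
    ]
    for name, choices in search_space.items()
    if name not in hparams_to_skip and len(choices) > 1
}
# filtered == {"hidden_size": [3072, 4096], "num_layers": [26, 27, ..., 32]}
combos = [dict(zip(filtered, values)) for values in product(*filtered.values())]
print(len(combos))  # 2 * 7 = 14 candidate configs before the params check
```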
tests/_test_utils/torch/megatron/models.py | 4 +- .../torch/nas_prune/minitron_common.py | 4 +- .../test_mcore_gpt_minitron_pruning.py | 10 +- .../test_mcore_mamba_minitron_pruning.py | 167 +++++++++++++++++- 6 files changed, 179 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/nas/search_space.py b/modelopt/torch/nas/search_space.py index 6da4d425a..4c7d2172a 100644 --- a/modelopt/torch/nas/search_space.py +++ b/modelopt/torch/nas/search_space.py @@ -135,9 +135,7 @@ def sort_parameters(self, hps_to_sort: set[str] | None = None, verbose: bool = F hps_to_sort: A set of hparam names to sort. If not provided or empty, all hparams will be sorted. verbose: Whether to print the search space and hparam importances. """ - print_rank_0("Sorting parameters...") - if verbose: - self.print_summary() + print_rank_0("\nSorting parameters...") # get config and set to max config = self.config() diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index fd78f9159..42e7c7430 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -214,7 +214,8 @@ def default_state_dict(self) -> SearchStateDict: def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: """Sanitize the search config dict.""" config = super().sanitize_search_config(config) - config["checkpoint"] = config["scores_path"] + if config["scores_path"]: + config["checkpoint"] = config["scores_path"] config["verbose"] = True # Print for all ranks return config @@ -457,7 +458,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict: start_layer_number += 1 self.unwrapped_model.decoder.layers = all_layers print_rank_0( - f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score" + f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score\n" ) dist.barrier() diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index 76ddc5a94..a22eaaf9e 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -314,6 +314,7 @@ def get_mcore_mamba_hybrid_model( sequence_parallel: bool = False, # Mamba-specific parameters mamba_state_dim: int = 32, + mamba_num_heads: int | None = None, mamba_head_dim: int = 16, mamba_num_groups: int = 2, # MoE-specific parameters @@ -347,6 +348,7 @@ def get_mcore_mamba_hybrid_model( num_query_groups=num_query_groups, ffn_hidden_size=ffn_hidden_size, mamba_state_dim=mamba_state_dim, + mamba_num_heads=mamba_num_heads, mamba_head_dim=mamba_head_dim, mamba_num_groups=mamba_num_groups, num_moe_experts=num_moe_experts, @@ -358,7 +360,7 @@ def get_mcore_mamba_hybrid_model( **config_kwargs, ) - if not (skip_moe or "E" in Symbols.VALID): + if not (skip_moe or "E" in Symbols.VALID): # Mcore 0.16+ has MoE support warn("MoE blocks are not supported in current MambaModel. 
Skipping MoE blocks.") skip_moe = True diff --git a/tests/_test_utils/torch/nas_prune/minitron_common.py b/tests/_test_utils/torch/nas_prune/minitron_common.py index a4da1771d..87da6414c 100644 --- a/tests/_test_utils/torch/nas_prune/minitron_common.py +++ b/tests/_test_utils/torch/nas_prune/minitron_common.py @@ -16,7 +16,7 @@ import modelopt.torch.prune as mtp -def prune_minitron(model, export_config, config, channel_divisor=64): +def prune_minitron(model, constraints, config, channel_divisor=64): return mtp.prune( model, mode=[ @@ -31,7 +31,7 @@ def prune_minitron(model, export_config, config, channel_divisor=64): ), ) ], - constraints={"export_config": export_config}, + constraints=constraints, dummy_input=None, # Not used config=config, ) diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 5af0f0c2c..c06c84326 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -202,6 +202,7 @@ def forward_loop(m): export_config["hidden_size"] = pruned_hidden_size if pruned_num_layers_div != 1: export_config["num_layers"] = pruned_num_layers + constraints = {"export_config": export_config} config = { "scores_path": ckpt_path, @@ -211,7 +212,7 @@ def forward_loop(m): assert ckpt_path is None else: config["forward_loop"] = forward_loop - model, pruning_scores = prune_minitron(model, export_config, config, channel_divisor) + model, pruning_scores = prune_minitron(model, constraints, config, channel_divisor) if not skip_sorting: assert pruning_scores["layer_scores"] assert pruning_scores["activations_per_rank"] @@ -248,7 +249,7 @@ def forward_loop(m): model_rerun = _get_model(initialize_megatron=False) model_rerun.load_state_dict(sd) model_rerun, pruning_scores = prune_minitron( - model_rerun, export_config, {"scores_path": ckpt_path}, channel_divisor + model_rerun, constraints, {"scores_path": ckpt_path}, channel_divisor ) output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size) @@ -450,10 +451,11 @@ def forward_loop(m): "moe_shared_expert_intermediate_size": pruned_moe_shared_ffn, "num_moe_experts": pruned_num_moe_experts, } + constraints = {"export_config": export_config} prune_minitron( model, - export_config, + constraints, {"scores_path": ckpt_path, "forward_loop": forward_loop}, channel_divisor, ) @@ -491,7 +493,7 @@ def forward_loop(m): # Assert re-pruning from scores_path works without running the forward loop again model_rerun = _get_model(initialize_megatron=False) model_rerun.load_state_dict(sd) - prune_minitron(model_rerun, export_config, {"scores_path": ckpt_path}, channel_divisor) + prune_minitron(model_rerun, constraints, {"scores_path": ckpt_path}, channel_divisor) output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size) assert torch.allclose(output, output_rerun, atol=1e-5) diff --git a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index 20a611569..dbcbff4b7 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -14,8 +14,11 @@ # limitations under the License. 
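The hybrid NAS test added below drives the params-constrained search end to end. For orientation, a condensed sketch of the kind of constraints/config it builds (the callables here are assumed placeholders, not the test's actual implementations):

```python
def forward_loop(model):  # assumed placeholder: run a few calibration batches
    ...

def score_func(model):  # assumed placeholder: return a scalar to maximize (e.g. MMLU)
    return 0.0

constraints = {"params": 6e9}  # upper bound on total parameter count
config = {
    "forward_loop": forward_loop,
    "score_func": score_func,
    "scores_path": "modelopt_minitron_scores.pth",  # cache scores and search state
    "max_width_pruning": 0.5,  # consider only the top half of each width hparam's choices
    "max_depth_pruning": 0.5,  # likewise for num_layers
    "hparams_to_skip": ["num_attention_heads"],
    "top_k": 10,  # candidates validated with score_func
}
```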
+import contextlib +import io from functools import partial +import pytest import torch from _test_utils.import_helper import skip_if_no_megatron @@ -29,6 +32,7 @@ ) from _test_utils.torch.misc import compare_outputs, set_seed from _test_utils.torch.nas_prune.minitron_common import prune_minitron +from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols from megatron.core.ssm.mamba_layer import MambaLayer from megatron.core.transformer.identity_op import IdentityOp @@ -37,6 +41,7 @@ ImportanceEstimatorRegistry, _convert_model_to_dynamic_space, get_mcore_minitron_config, + get_mcore_param_count, ) SEED = 1234 @@ -167,7 +172,7 @@ def _get_model(initialize_megatron=True): mamba_num_heads = mamba_layer.mixer.nheads def forward_loop(m): - for _ in range(5): + for _ in range(2): run_mcore_inference_with_dummy_input(m, batch_size, hidden_size) # Traditional GPT pruning parameters @@ -191,9 +196,10 @@ def forward_loop(m): "moe_shared_expert_intermediate_size": pruned_ffn_hidden_size, "num_moe_experts": pruned_num_moe_experts, } + constraints = {"export_config": export_config} prune_minitron( model, - export_config, + constraints, {"forward_loop": forward_loop, "scores_path": ckpt_path}, channel_divisor, ) @@ -225,7 +231,7 @@ def forward_loop(m): # Assert re-pruning from scores_path works without running the forward loop again model = _get_model(initialize_megatron=False) - prune_minitron(model, export_config, {"scores_path": ckpt_path}, channel_divisor) + prune_minitron(model, constraints, {"scores_path": ckpt_path}, channel_divisor) def test_mcore_mamba_hybrid_pruning(tmp_path): @@ -234,3 +240,158 @@ def test_mcore_mamba_hybrid_pruning(tmp_path): job=partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "modelopt_minitron_scores.pth"), backend="nccl", ) + + +def _test_mcore_mamba_hybrid_pruning_nas(ckpt_path, rank, size): + channel_divisor = 4 + + # TODO: MoE in MambaModel requires Mcore 0.16+ + num_layers = 4 # Atleast one of "M, *, -, E" blocks + hybrid_pattern = "M*-M" # "ME*-" + hidden_size = 16 + ffn_hidden_size = 32 + num_attention_heads = 16 + num_query_groups = 4 + mamba_state_dim = 4 + mamba_num_heads = 16 + mamba_head_dim = 16 + mamba_num_groups = 2 + num_moe_experts = None + moe_ffn_hidden_size = None + moe_shared_expert_intermediate_size = None + # num_moe_experts = 8 + # moe_ffn_hidden_size = 16 + # moe_shared_expert_intermediate_size = 16 + vocab_size = 32 + batch_size = 2 + + model = get_mcore_mamba_hybrid_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=size, + initialize_megatron=True, + num_layers=num_layers, + hybrid_override_pattern=hybrid_pattern, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + mamba_state_dim=mamba_state_dim, + mamba_num_heads=mamba_num_heads, + mamba_head_dim=mamba_head_dim, + mamba_num_groups=mamba_num_groups, + moe_ffn_hidden_size=moe_ffn_hidden_size, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + num_moe_experts=num_moe_experts, + vocab_size=vocab_size, + ).cuda() + + param_count = get_mcore_param_count(model) + assert param_count == 31776.0, param_count + + def forward_loop(m): + for _ in range(2): + run_mcore_inference_with_dummy_input(m, batch_size, hidden_size) + + def score_func(m): + c = m.config + return ( + c.num_layers + + c.hidden_size + + c.ffn_hidden_size + + c.mamba_num_heads + + c.mamba_head_dim + + c.num_attention_heads + # + c.num_moe_experts + # + c.moe_ffn_hidden_size + # + 
c.moe_shared_expert_intermediate_size + ) + + constraints = {"params": int(param_count * 0.7)} + config = { + "forward_loop": forward_loop, + "scores_path": ckpt_path, + "score_func": score_func, + "max_width_pruning": 0.5, + "max_depth_pruning": 0.5, + "hparams_to_skip": ["num_attention_heads"], + "top_k": 10, + } + + # Capture stdout to assert search space output + stdout_capture = io.StringIO() + with contextlib.redirect_stdout(stdout_capture): + model, searcher_state = prune_minitron(model, constraints, config, channel_divisor) + + # Assert expected search space output is present + captured_output = stdout_capture.getvalue() + print(captured_output) + if rank == 0: + assert "Search space for num_layers: [3, 4]" in captured_output + assert "Search space for hidden_size: [12, 16]" in captured_output + assert "Search space for mamba_num_heads: [10, 12, 14, 16]" in captured_output + assert "Search space for mamba_head_dim: [12, 16]" in captured_output + assert "Search space for ffn_hidden_size: [20, 24, 28, 32]" in captured_output + assert "Total search space in consideration: 128" in captured_output + + # NOTE: Slight variation in layer ordering for Attention and MLP depending on PP configuration + # This affects param counts when num_layers is pruned + sorted_layers = [ + layer + for layer, _ in sorted( + searcher_state["layer_scores"].items(), key=lambda x: x[1], reverse=True + ) + ] + # fmt: off + if sorted_layers == [1, 4, 2, 3]: + expected_top_k = [ + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 22196.0, 94.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 28}, 22068.0, 90.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 24}, 21940.0, 86.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21916.0, 94.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 28}, 21820.0, 90.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21812.0, 82.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 24}, 21724.0, 86.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 20}, 21628.0, 82.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 10, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21180.0, 94.0], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21140.0, 81.0], # noqa: E501 + ] + elif sorted_layers == [1, 4, 3, 2]: + expected_top_k = [ + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 22196.0, 94.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 28}, 22068.0, 90.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 24}, 21940.0, 86.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21916.0, 94.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, 
"mamba_head_dim": 16, "ffn_hidden_size": 28}, 21820.0, 90.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21812.0, 82.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 24}, 21724.0, 86.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 20}, 21628.0, 82.0], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 21524.0, 93.0], # noqa: E501 + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21412.0, 93.0], # noqa: E501 + ] + else: + raise RuntimeError(f"FIXME: Non deterministic test, assertions may fail: {sorted_layers}") + # fmt: on + + assert get_mcore_param_count(model) == 22196.0 + + top_k = searcher_state["top_k_candidates_per_constraint"][constraints["params"]] + assert len(top_k) == 10 + for actual, (ss_config, params, score) in zip(top_k, expected_top_k): + assert actual.ss_config == ss_config, (actual.ss_config, ss_config) + assert actual.params == params, (actual.params, params) + assert actual.score == score, (actual.score, score) + + +def test_mcore_mamba_hybrid_pruning_nas(tmp_path): + set_seed(SEED) + if torch.cuda.device_count() > 4: + pytest.skip("Skipping test for more than 4 GPUs") + if "E" in Symbols.VALID: + pytest.skip("TODO: Update test for MoE in Mamba (Mcore 0.16+)") + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_mcore_mamba_hybrid_pruning_nas, tmp_path / "modelopt_minitron_scores.pth" + ), + backend="nccl", + ) From e41b10822287b0dc837b47fd06da7ea8057f9e2f Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 8 Jan 2026 10:29:54 -0800 Subject: [PATCH 9/9] Update pruning readme Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/pruning/README.md | 192 ++++++++++++++++++++++++++----------- 1 file changed, 136 insertions(+), 56 deletions(-) diff --git a/examples/pruning/README.md b/examples/pruning/README.md index c9441dd0d..d9c17cdc3 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -17,27 +17,36 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar | Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | | | Getting Started | Learn how to use the pruning API | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/3_pruning.html)\] | | Support Matrix | View the support matrix to see available pruning algorithms and their compatibility with different models and frameworks | \[[Link](#support-matrix)\] | | -| Pruning Guidelines | Guidelines for choosing how and how much to prune for best results | \[[Link](#pruning-guidelines)\] | | | Examples | Examples of different pruning methods | \[[Link](#examples)\] | | +| Pruning Guidelines | Guidelines for choosing how and how much to prune for best results | \[[Link](#pruning-guidelines)\] | | | Resources | Extra links to relevant resources | \[[Link](#resources)\] | | ## Pre-Requisites -For Minitron pruning for Megatron-LM / NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.09`) which has all the dependencies installed. 
+For Minitron pruning for Megatron-LM / NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.11`) which has all the dependencies installed. Make sure to upgrade Model Optimizer to the latest version using `pip`.
 
 For FastNAS pruning for PyTorch Computer Vision models, no additional dependencies are required.
 
-For GradNAS pruning for Hugging Face BERT / GPT-J, no additional dependencies are requisred.
+For GradNAS pruning for Hugging Face BERT / GPT-J, no additional dependencies are required.
 
 ## Getting Started
 
-As part of the pruning process, you will need to set up the training and/or validation data loaders, and optionally define a validation score function (FastNAS) or loss function (GradNAS) and specify the desired pruning constraints (See [Support Matrix](#support-matrix) for available pruning constraints).
+As part of the pruning process, you will need to set up the training and/or validation data loaders, and optionally define a validation score function (Minitron, FastNAS) or loss function (GradNAS) and specify the desired pruning constraints (see the [Support Matrix](#support-matrix) for available pruning constraints).
+
+To prune your model, you can simply call the `mtp.prune` API and save the pruned model. If the model is pruned using Minitron, you can use your standard saving and loading functions since it is homogeneous pruning; for FastNAS or GradNAS, you need to use `mto.save` and `mto.restore` to save and restore the heterogeneously pruned model.
+
+### Minitron
+
+Minitron pruning supports two modes:
+
+1. **Manual Pruning**: Manually specify the target dimensions for each pruning axis (e.g., `constraints = {"export_config": {"hidden_size": 3072, "ffn_hidden_size": 9216}}`)
+2. **NAS-based Auto Pruning (New)**: Specify a target parameter count (e.g., `constraints = {"params": 6e9}`) and let the algorithm automatically search for the best architecture that maximizes a user-defined score function (e.g., MMLU, negative validation loss, etc.)
 
-To prune your model, you can simply call the `mtp.prune` API and save the pruned model. If the model is pruned using FastNAS or GradNAS, you need to use `mto.save` and `mto.restore` to save and restore the pruned model; while for Minitron pruning, you can use your standard saving and loading functions since it is a homogeneous pruning.
+Please see example snippets of both modes of Minitron pruning on a Megatron-Core GPT model below. For end-to-end example scripts (Megatron-LM / NeMo framework), please refer to the examples below.
 
-Please see an example snippet of Minitron pruning for Megatron-Core GPT model below (for other algorithms, please refer to the examples below).
+#### Common Setup
 
 ```python
 import modelopt.torch.prune as mtp
@@ -45,11 +54,11 @@ from megatron.core.models.gpt import GPTModel
 from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec
 from megatron.core.transformer.transformer_config import TransformerConfig
 
-# Load the Megatron-Core GPTModel with ModelOpt transformer layer spec
-config = TransformerConfig(...)
+# Load the Megatron-Core GPTModel / MambaModel with ModelOpt transformer layer spec
+model_config = TransformerConfig(...)
 model = GPTModel(
-    config=config,
-    transformer_layer_spec=get_gpt_modelopt_spec(config, remap_te_layernorm=True),
+    config=model_config,
+    transformer_layer_spec=get_gpt_modelopt_spec(model_config, remap_te_layernorm=True),
     ...
 )
@@ -60,41 +69,141 @@ from megatron.training.training import evaluate_and_print_results
 
 def forward_loop(_):
     evaluate_and_print_results(prefix, forward_step, train_iterator, model, ...)
 
-
-# Specify the pruning constraints (Check Support Matrix for available pruning dimensions)
-export_config = {
-    "hidden_size": 3072,
-    "ffn_hidden_size": 9216,
-}
-
-
 # Run the pruning process (if model is a list then pass model[0] to the prune API)
-# Save minitron scores at scores_path so we can re-run pruning with different export configs without running the forward loop again
-# NOTE: Skip scores_path on re-running if you want to change the dataset and re-calibrate
+# Save Minitron scores at the checkpoint path so we can re-run pruning with different constraints without running the forward loop again
+# NOTE: Skip the checkpoint on re-running if you want to change the dataset and re-calibrate
 model, pruning_scores = mtp.prune(
     model,
     mode="mcore_minitron",
-    constraints={"export_config": export_config},
+    constraints=constraints,
     dummy_input=None,  # Not used
-    config={"forward_loop": forward_loop, "scores_path": "modelopt_minitron_scores.pth"},
+    config=config,
 )
 ```
 
-If your model parameters are already sorted, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`.
-
 > [!Note]
 > Fine-tuning / distillation is required after pruning to recover the accuracy. Please refer to [end-to-end pruning and distillation tutorial](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) for more details.
 
+#### 1. Manual Pruning
+
+This mode can be useful when you know the exact dimensions you want to prune to (e.g., fitting a specific latency / memory budget).
+
+```python
+# Specify the pruning constraints (Check Support Matrix for available pruning dimensions)
+constraints = {"export_config": {"hidden_size": 3072, "ffn_hidden_size": 9216}}
+config = {"forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores.pth"}
+
+mtp.prune(...)
+```
+
+**Under the Hood:**
+
+1. **Importance Scoring**: Runs forward passes on calibration data (512-1024 samples) to compute activation magnitudes for each neuron/head/layer (takes ~5 minutes for an 8B model)
+2. **Ranking**: Ranks all parameters within each pruning dimension (e.g., all hidden dimensions, all attention heads) by their importance scores
+3. **Pruning**: Removes the least important parameters to meet the specified target dimensions in `export_config`
+4. **Weight Slicing**: Slices the model weights according to the pruned architecture (homogeneous pruning - all layers pruned uniformly)
+
+> [!TIP]
+> Check out the [Pruning Guidelines](#pruning-guidelines) section for more details on how to choose the best pruning strategy and distillation hyperparameters.
+
+#### 2. NAS-based Auto Pruning
+
+This mode can be useful when you don't know the exact dimensions you want to prune to and want the algorithm to search for the best architecture that maximizes a user-defined score function, at the cost of longer runtime.
+
+```python
+# Define the score function to maximize (e.g., MMLU, negative validation loss, etc.)
+# The algorithm will search for the best architecture that maximizes this score
+from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
+
+def score_func(m):
+    return megatron_mmlu(m, tokenizer, percentage=0.05)  # 5% sampled data for faster eval
+
+# Specify target parameter count and configure the auto pruning algorithm
+constraints = {"params": 6e9}  # Prune to 6B parameters
+config = {
+    "forward_loop": forward_loop,
+    "checkpoint": "/path/to/cache/pruning/scores.pth",
+    "score_func": score_func,
+    # Optional: Configure search space constraints (showing defaults)
+    "max_width_pruning": 0.4,  # Maximum 40% per width pruning hparam
+    "max_depth_pruning": 0.2,  # Maximum 20% per depth pruning hparam (num_layers)
+    "hparams_to_skip": [],  # Disable pruning of specific hparams, e.g., ["num_attention_heads"]
+    "top_k": 10,  # Number of top architectures to evaluate (use 20 for better results at the cost of 2x time)
+}
+
+mtp.prune(...)
+```
+
+**Under the Hood:**
+
+1. **Importance Scoring**: Same as manual pruning - computes activation magnitudes for all parameters (takes ~5 minutes for an 8B model)
+2. **Search Space Construction**: Generates a search space of possible architectures based on the search space config and other config options (`max_width_pruning`, `max_depth_pruning`, `hparams_to_skip`)
+3. **Architecture Search**: Finds candidate architectures that meet the parameter constraint and evaluates the `top_k` of them (ranked by number of parameters) using `score_func`, e.g., MMLU, negative validation loss, etc. (takes ~10 mins per candidate for an 8B model)
+4. **Best Architecture Selection**: Returns the architecture (best `export_config`) with the highest actual score from the top-K evaluated architectures
+5. **Weight Slicing**: Slices the model weights according to the best pruned architecture found
+
+> [!Note]
+> As per the [original paper](https://arxiv.org/pdf/2407.14679), ideally a short Knowledge Distillation run on ~2B tokens should be performed for each of the top-K candidate architectures before evaluating the score function. This takes much longer, requires splitting the pruning process into multiple stages, and needs a lot more compute, but it can lead to a better pruned model. If you are interested in doing this, you can take each top-K candidate's `export_config` from the pruning logs, export the models separately, and perform Knowledge Distillation on each of them before evaluating the score function.
+
+#### Advanced Configuration
+
+For finer control over the search space (e.g., granularity of pruning choices), you can configure the divisors:
+
+```python
+# Configure search space granularity (showing defaults)
+ss_config = mtp.mcore_minitron.get_mcore_minitron_config(
+    hidden_size_divisor=256,
+    ffn_hidden_size_divisor=512,
+    mamba_head_dim_divisor=8,
+    num_moe_experts_divisor=8,
+    num_layers_divisor=2,
+)
+
+# Use the custom search space config
+mtp.prune(model, mode=[("mcore_minitron", ss_config)], ...)
+```
+
+If your model parameters are already sorted and you just want to prune the weights, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`, as shown in the sketch below.
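+For example, a minimal sketch of such a re-run (hypothetical target dimensions; this assumes the model weights were already sorted by an earlier pruning pass, so no forward loop or score function is needed):
+
+```python
+# Re-prune an already-sorted model: only the weight-slicing step is performed
+model, _ = mtp.prune(
+    model,
+    mode="mcore_minitron",
+    constraints={"export_config": {"hidden_size": 3072, "ffn_hidden_size": 9216}},  # example values
+    dummy_input=None,  # Not used
+    config={"skip_sorting": True},
+)
+```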
+ ## Support Matrix | **Algorithm** | **Model** | **Pruning Constraints** | | :---: | :---: | :---: | -| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid LLM Models1 | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values | -| FastNAS | Computer Vision models | flops, parameters | -| GradNAS | HuggingFace BERT, GPT-J | flops, parameters | +| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid LLM Models1 | **Manual:** `export_config` with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) pruned values
**Auto:** `params` (requires `score_func` in config) | +| FastNAS | Computer Vision models | `flops`, `params` | +| GradNAS | HuggingFace BERT, GPT-J | `flops`, `params` | > *1.Only Pipeline Parallel models are supported. Hugging Face models can be converted to Megatron-LM/NeMo format and used subsequently.* +## Examples + +### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano) + +Checkout the Minitron pruning example for the [Megatron-LM Framework](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt#-pruning) or [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html) which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama-3.1-8B, Qwen3-8B, Nemotron-Nano-9B-v2, Nemotron-3-Nano-30B-A3B, etc. +Both frameworks support importing from a Hugging Face pretrained checkpoint. + +You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen3-8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial. + +Some of the models pruned using Minitron method followed by distillation and post-training are: + +- [Minitron Collection on Hugging Face](https://huggingface.co/collections/nvidia/minitron) +- [NVIDIA-Nemotron-Nano-9B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-9B-v2) + +### FastNAS Pruning for PyTorch Computer Vision Models + +Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search). + +You can also take a look at FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory +which showcases the usage of FastNAS for pruning a ResNet 20 model for the CIFAR-10 dataset. The notebook +also shows how to profile the model to understand the search space of possible pruning options and demonstrates +how to save and restore pruned models. + +### GradNAS Pruning for HuggingFace Language Models (e.g. BERT) + +Checkout the BERT pruning example in [chained_optimizations](../chained_optimizations/README.md) directory +which showcases the usage of GradNAS for pruning BERT model for Question Answering followed by fine-tuning +with distillation and quantization. The example also demonstrates how to save and restore pruned models. + ## Pruning Guidelines ### Minitron @@ -173,35 +282,6 @@ After pruning, distillation is required to recover model accuracy. Below are rec > [!TIP] > If you know the maximum learning rate used during the original training, a good rule of thumb for knowledge distillation is to use **1/5th of that maximum LR** when compressing by ~50%. -## Examples - -### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano) - -Checkout the Minitron pruning example for the [Megatron-LM Framework](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt#-pruning) or [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html) which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama 3.1 8B, Qwen 3 8B, Nemotron Nano 12B v2, etc. 
-Both frameworks support importing from a Hugging Face pretrained checkpoint. - -You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial. - -Some of the models pruned using Minitron method followed by distillation and post-training are: - -- [Minitron Collection on Hugging Face](https://huggingface.co/collections/nvidia/minitron) -- [NVIDIA-Nemotron-Nano-9B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-9B-v2) - -### FastNAS Pruning for PyTorch Computer Vision Models - -Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search). - -You can also take a look at FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory -which showcases the usage of FastNAS for pruning a ResNet 20 model for the CIFAR-10 dataset. The notebook -also shows how to profile the model to understand the search space of possible pruning options and demonstrates -how to save and restore pruned models. - -### GradNAS Pruning for HuggingFace Language Models (e.g. BERT) - -Checkout the BERT pruning example in [chained_optimizations](../chained_optimizations/README.md) directory -which showcases the usage of GradNAS for pruning BERT model for Question Answering followed by fine-tuning -with distillation and quantization. The example also demonstrates how to save and restore pruned models. - ## Resources - 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146)