diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index e3a909d22..f16162083 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -17,6 +17,7 @@ In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/ ```bash pip install -e .[hf,puzzletron] +pip install -r requirements.txt ``` - For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use s single GPU. @@ -231,6 +232,24 @@ vllm bench latency --model path/to/model --load-format safetensors --trust-remot vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 --load-format safetensors --trust-remote-code ``` +## Knowledge Distillation + +To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. For this, we will use [NeMo framework](https://github.com/NVIDIA-NeMo/NeMo) with the [nemo:25.07](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.07) container. + +First, convert the HF model to NeMo format: + +```bash +python -m nemo_export/convert_hf_to_nemo --input-ckpt-path path/to/HF-model --output-ckpt-path path/to/save/model-nemo +``` + +Now you can utilize all the training features available in NeMo, including distillation. Please refer to the [NeMo distillation documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html). + +[Optional] Once distillation is complete, you can convert the distilled model back to the HuggingFace format. + +```bash +python -m nemo_export/convert_nemo_to_hf --input-ckpt-path path/to/nemo-model --output-ckpt-path path/to/save/model-HF +``` + ## Advanced Usage Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios. diff --git a/examples/puzzletron/nemo_export/convert_hf_to_nemo.py b/examples/puzzletron/nemo_export/convert_hf_to_nemo.py new file mode 100644 index 000000000..0cf16b448 --- /dev/null +++ b/examples/puzzletron/nemo_export/convert_hf_to_nemo.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from pathlib import Path +from typing import Any + +from nemo.collections import llm + +from modelopt.torch.puzzletron.export.MCore.llama_nemotron import ( + PuzzletronLlamaNemotronModel, + PuzzletronNemotronModelConfig, +) + + +def convert_model( + hf_model_path_local: str, output_path_nemo_local: str, overwrite: bool = False +) -> Any: + """Convert a Puzzletron HuggingFace model to NeMo format. + + Args: + hf_model_path_local: Path to the input Puzzletron HuggingFace model directory + output_path_nemo_local: Path where the converted Puzzletron NeMo model will be saved + overwrite: Whether to overwrite existing output directory + """ + + model = PuzzletronLlamaNemotronModel(config=PuzzletronNemotronModelConfig) + # NOTE: API call to import_ckpt is here: https://github.com/NVIDIA-NeMo/NeMo/blob/294ddff187f68c055d87ffe9400e65975b38693d/nemo/collections/llm/api.py#L888 + print( + f"calling import_ckpt with model: {model}, " + f"source: {hf_model_path_local}, " + f"output_path: {output_path_nemo_local}, " + f"overwrite: {overwrite}" + ) + nemo2_path = llm.import_ckpt( + model=model, + source="hf://" + hf_model_path_local, + output_path=Path(output_path_nemo_local), + overwrite=overwrite, + ) + + print(f"Model saved to {nemo2_path}") + return nemo2_path + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert Puzzletron HuggingFace model to NeMo format" + ) + parser.add_argument( + "--input-ckpt-path", + "-i", + type=str, + required=True, + help="Path to the input Puzzletron HuggingFace model directory", + ) + parser.add_argument( + "--output-ckpt-path", + "-o", + type=str, + required=True, + help="Path where the converted Puzzletron NeMo model will be saved", + ) + parser.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Whether to overwrite existing output directory (default: False)", + ) + + args = parser.parse_args() + + # Validate input path + if not os.path.exists(args.input_ckpt_path): + raise FileNotFoundError(f"Input model path does not exist: {args.input_ckpt_path}") + + # Create output directory if it doesn't exist + os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True) + + print(f"Converting model from {args.input_ckpt_path} to {args.output_ckpt_path}") + convert_model(args.input_ckpt_path, args.output_ckpt_path, args.overwrite) + + +if __name__ == "__main__": + main() diff --git a/examples/puzzletron/nemo_export/convert_nemo_to_hf.py b/examples/puzzletron/nemo_export/convert_nemo_to_hf.py new file mode 100644 index 000000000..4645ae5b4 --- /dev/null +++ b/examples/puzzletron/nemo_export/convert_nemo_to_hf.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from pathlib import Path +from typing import Any + +from nemo.collections import llm + +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import copy_deci_lm_hf_code + + +def convert_model( + nemo_model_path_local: str, output_path_hf_local: str, overwrite: bool = False +) -> Any: + """Convert a NeMo model to HuggingFace format. + + Args: + nemo_model_path_local: Path to the input NeMo model file (.nemo) + output_path_hf_local: Path where the converted HuggingFace model will be saved + overwrite: Whether to overwrite existing output directory + """ + + # NOTE: API call to export_ckpt is here: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/llm/api.py#L987 + print( + f"calling export_ckpt with path: {nemo_model_path_local}, " + f"target: hf, output_path: {output_path_hf_local}, " + f"target_model_name: PuzzletronLlamaNemotronModel, " + f"overwrite: {overwrite}" + ) + + hf_path = llm.export_ckpt( + path=nemo_model_path_local, + target="hf", + output_path=Path(output_path_hf_local), + target_model_name="PuzzletronLlamaNemotronModel", + overwrite=overwrite, + ) + + copy_deci_lm_hf_code(hf_path) + + print(f"Model saved to {hf_path}") + return hf_path + + +def main() -> None: + parser = argparse.ArgumentParser(description="Convert NeMo model to HuggingFace format") + parser.add_argument( + "--input-ckpt-path", + "-i", + type=str, + required=True, + help="Path to the input NeMo model checkpoint", + ) + parser.add_argument( + "--output-ckpt-path", + "-o", + type=str, + required=True, + help="Path where the converted Puzzletron HuggingFace model will be saved", + ) + parser.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Whether to overwrite existing output directory (default: False)", + ) + + args = parser.parse_args() + + # Validate input path + if not os.path.exists(args.input_ckpt_path): + raise FileNotFoundError(f"Input model path does not exist: {args.input_ckpt_path}") + + # Create output directory if it doesn't exist + os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True) + + print(f"Converting model from {args.input_ckpt_path} to {args.output_ckpt_path}") + convert_model(args.input_ckpt_path, args.output_ckpt_path, args.overwrite) + + +if __name__ == "__main__": + main() diff --git a/examples/puzzletron/requirements.txt b/examples/puzzletron/requirements.txt new file mode 100644 index 000000000..fe63c413b --- /dev/null +++ b/examples/puzzletron/requirements.txt @@ -0,0 +1 @@ +lm-eval==0.4.9 diff --git a/modelopt/torch/puzzletron/export/MCore/llama_nemotron.py b/modelopt/torch/puzzletron/export/MCore/llama_nemotron.py new file mode 100644 index 000000000..d4292322f --- /dev/null +++ b/modelopt/torch/puzzletron/export/MCore/llama_nemotron.py @@ -0,0 +1,1015 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# based on https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/llm/gpt/model/llama_nemotron.py + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, Callable, Dict, Optional, Union + +import torch +import torch.nn.functional as F +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel, torch_dtype_from_mcore_config +from nemo.collections.llm.gpt.model.llama import ( + Llama3Config, + Llama31Config, + Llama31Config70B, + LlamaConfig, + apply_rope_scaling, +) +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown +from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME +from nemo.lightning.io.pl import ckpt_to_weights_subdir +from nemo.lightning.io.state import TransformFns +from nemo.lightning.pytorch.utils import dtype_from_hf, dtype_from_str +from nemo.utils import logging +from nemo.utils.import_utils import safe_import +from torch import nn + +from modelopt.torch.puzzletron.tools.logger import mprint + +# from nemo.collections.llm.gpt.model.llama_nemotron import Llama33NemotronSuper49BConfig + + +_, HAVE_TE = safe_import("transformer_engine") +from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( + get_gpt_heterogeneous_layer_spec, +) +from megatron.core.transformer.heterogeneous.heterogeneous_config import ( + HeterogeneousTransformerConfig, +) +from megatron.core.transformer.spec_utils import ModuleSpec + +if TYPE_CHECKING: + from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + from peft import AutoPeftModelForCausalLM, PeftConfig + from transformers import GenerationConfig, LlamaForCausalLM + from transformers import LlamaConfig as HFLlamaConfig + + from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig + +from modelopt.torch.puzzletron.export.MCore.llama_nemotron_utils import ( + _build_puzzletron_mappings_and_transforms, + _config_to_dict, + convert_attention_config_from_cfg_object, + convert_mlp_config_from_cfg_object, + convert_nemo_config_to_hf_decilm_config, + dtype_from_dict, + merge_qkv_for_puzzletron, + split_qkv_for_puzzletron, +) +from modelopt.torch.puzzletron.export.MCore.puzzletron_layer_specs import ( + PuzzletronHeterogeneousTransformerConfig, + get_gpt_heterogeneous_layer_spec_puzzletron, +) + + +def heterogeneous_layer_spec_puzzletron( + config: PuzzletronHeterogeneousTransformerConfig, +) -> ModuleSpec: + return get_gpt_heterogeneous_layer_spec_puzzletron(config, use_transformer_engine=HAVE_TE) + + +# Refactored to inherit directly from GPTConfig instead of Llama31Config70B +# This makes it easier to understand what attributes are set through the hierarchy +@dataclass +class PuzzletronNemotronModelConfig(GPTConfig, PuzzletronHeterogeneousTransformerConfig): + """Configuration for Puzzletron Nemotron models. + + DESIGN RATIONALE: + ================ + Refactored from original inheritance (Llama31Config70B + PuzzletronHeterogeneousTransformerConfig) + to explicit attribute definition for clarity and maintainability. Maintains identical behavior + to the original Llama hierarchy while enabling future flexibility. + + ATTRIBUTE ORGANIZATION: + ====================== + Explicitly defines attributes from the Llama hierarchy: + Llama31Config70B → Llama31Config → Llama3Config → LlamaConfig → GPTConfig + + FUTURE DEVELOPMENT: + ================== + Attributes can be freely modified/removed for future Puzzletron models. + In this case the tests in test_puzzletron_nemotron_config_inheritance.py will need to be updated. + Current explicit definition is for clarity during transition period. + """ + + # Override attributes from PuzzletronHeterogeneousTransformerConfig with Llama hierarchy values + # These ensure we maintain the same behavior as the original Llama31Config70B inheritance + + # ===== LlamaConfig attributes ===== + # Core model architecture + # NOTE: Default is F.silu, but this is overridden during instantiation to match all blocks + # See instantiate_nemo_config_from_adapted_dict() which enforces same activation across blocks + activation_func: Callable = F.silu + normalization: str = "RMSNorm" + gated_linear_unit: bool = True + position_embedding_type: str = "rope" + add_bias_linear: bool = False + # seq_length: int = 4096 # (will be overridden by Llama31Config70B) + attention_dropout: float = 0.0 + hidden_dropout: float = 0.0 + share_embeddings_and_output_weights: bool = False + # Fusion settings + bias_activation_fusion: bool = True + masked_softmax_fusion: bool = True + persist_layer_norm: bool = True + bias_dropout_fusion: bool = True + apply_rope_fusion: bool = True + use_transformer_engine_op_fuser: Optional[bool] = None + + # ===== Llama3Config attributes ===== + num_query_groups: int = 8 + # init_method_std: float = 0.01 # (will be overridden by Llama31Config) + layernorm_epsilon: float = 1.0e-05 + rotary_percent: float = 1.0 + + # ===== Llama31Config attributes ===== + scale_factor: float = 8.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + old_context_len: int = 8192 + init_method_std: float = 0.02 # (overrides Llama3Config) + + # ===== Llama31Config70B attributes ===== + # Core model architecture (70B-specific) + rotary_base: int = 500_000 + seq_length: int = 131072 # (overrides LlamaConfig) + num_layers: int = 80 # + hidden_size: int = 8192 # + ffn_hidden_size: int = 28672 # + num_attention_heads: int = 64 # + kv_channels: int = 128 # (derived from hidden_size // num_attention_heads) + make_vocab_size_divisible_by: int = 128 # + + # ===== PuzzletronHeterogeneousTransformerConfig attributes ===== + # Actual new PuzzleNemotronModelConfig attributes + heterogeneous_layers_config_path: Optional[str] = None + heterogeneous_layers_config_encoded_json: Optional[str] = None + transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = ( + heterogeneous_layer_spec_puzzletron + ) + + # HF-specific metadata for lossless round-trip conversion (HF → NeMo → HF) + # Stores HF config fields that don't have direct NeMo equivalents + source_hf_config_metadata: Optional[Dict[str, Any]] = None + + # NOTE: How activation_func is handled for Puzzletron models + # ============================================================== + # Puzzletron models can define activation functions per-block, but MCore's validation + # only checks the global activation_func (not per-block activations). + # See: https://github.com/NVIDIA/Megatron-LM/blob/268fda08592528b7bc1a21aadaed259980ca8efb/megatron/core/transformer/transformer_config.py#L1043-L1061 + # + # Current approach (enforced in instantiate_nemo_config_from_adapted_dict): + # - All blocks must use the SAME activation function (None allowed for no-op blocks) + # - The global activation_func is set to match the blocks' shared activation + # - This ensures MCore's global validation passes correctly + # + # Rationale: + # 1. MCore validates global activation_func during __post_init__() (lines 1043-1061) + # 2. NeMo calls __post_init__() AGAIN during trainer.strategy.connect(model) + # See: https://github.com/NVIDIA/NeMo/blob/2e19aebd8c8fa9ff7ce9b5076ce130404713443c/nemo/lightning/_strategy_lib.py#L172-L175 + # 3. At runtime, MCore uses per-block activations from get_config_for_layer() + # See: https://github.com/NVIDIA/Megatron-LM/blob/268fda08592528b7bc1a21aadaed259980ca8efb/megatron/core/transformer/transformer_block.py#L308-L319 + # + # For heterogeneous activations across blocks, MCore would need to update their + # validation logic to support per-block validation (e.g., in get_config_for_layer() or MLP.__init__) + + # ===== Llama31Config method ===== + def configure_model( + self, tokenizer, pre_process=None, post_process=None, vp_stage=None + ) -> "MCoreGPTModel": + """Configure and instantiate a Megatron Core Llama 3.1 model. + + NOTE: This method is originally from Llama31Config and is explicitly included here + for consistency and clarity. It maintains the same behavior as the original + Llama hierarchy inheritance approach. + + Extends the base configuration with Llama 3.1 specific RoPE scaling. + This method applies RoPE scaling for extended context length support. + """ + model = super().configure_model(tokenizer, pre_process, post_process, vp_stage) + # Apply rope scaling for Llama3.1 model + model.rotary_pos_emb.inv_freq = apply_rope_scaling( + model.rotary_pos_emb.inv_freq, + factor=self.scale_factor, + low_freq_factor=self.low_freq_factor, + high_freq_factor=self.high_freq_factor, + old_context_len=self.old_context_len, + ) + return model + + @classmethod + def from_dict_with_preprocessing(cls, config_dict): + # Potentially adapt the config_dict before instantiation + instance = cls(**config_dict) + # Potentially adapt the config after instantiation + return instance + + # static method + @staticmethod + def create_adapted_config_dict_from_puzzletron_config(cfg): + # TODO: consider doing do this without conversion to dictionary in the future (instead have an adapted config object) + # Create an empty config object of the same class as cfg + adapted_cfg_dict = dict() + orig_cfg_dict = vars(cfg) + + # Extract first set of values from the original config + adapted_cfg_dict["head_dim"] = orig_cfg_dict["head_dim"] + adapted_cfg_dict["num_attention_heads"] = orig_cfg_dict["num_attention_heads"] + # Handle rope_scaling - can be None, missing, or a dict + adapted_cfg_dict["rope_scaling"] = orig_cfg_dict.get("rope_scaling") or {} + + block_conf = { + "block_configs": [ + { + "attention": convert_attention_config_from_cfg_object( + orig_cfg_dict["block_configs"][i].attention, + adapted_cfg_dict["num_attention_heads"], + adapted_cfg_dict["head_dim"], + ), + "mlp": { + **convert_mlp_config_from_cfg_object( + orig_cfg_dict["block_configs"][i].ffn, + ( + orig_cfg_dict["block_configs"][i].parallel_blocks + if hasattr(orig_cfg_dict["block_configs"][i], "parallel_blocks") + else None + ), + ), + # Store the per-block activation function as a string (for JSON serialization) + "hidden_act": ( + orig_cfg_dict["block_configs"][i].ffn.hidden_act + if not ( + orig_cfg_dict["block_configs"][i].ffn.no_op + or orig_cfg_dict["block_configs"][i].ffn.replace_with_linear + ) + else None + ), + }, + } + for i in range(len(orig_cfg_dict["block_configs"])) + ] + } + if orig_cfg_dict["o_proj_bias"] != orig_cfg_dict["attention_bias"]: + raise NotImplementedError("o_proj_bias is not fully supported") + if orig_cfg_dict["position_embedding_type"] not in ["rope", "yarn"]: + # this one is not supported by MCore + raise ValueError( + f"only rope and yarn are supported, got {orig_cfg_dict['position_embedding_type']}" + ) + + # Handle dtype (new format uses 'dtype', old format uses 'torch_dtype') + # Check 'dtype' first, then fall back to 'torch_dtype' + if "dtype" in orig_cfg_dict and orig_cfg_dict["dtype"] is not None: + mprint(f"DEBUG: dtype found in config: {orig_cfg_dict['dtype']}") + adapted_cfg_dict["torch_dtype"] = orig_cfg_dict["dtype"] + elif "torch_dtype" in orig_cfg_dict and orig_cfg_dict["torch_dtype"] is not None: + mprint(f"DEBUG: torch_dtype found in config: {orig_cfg_dict['torch_dtype']}") + adapted_cfg_dict["torch_dtype"] = orig_cfg_dict["torch_dtype"] + else: + mprint( + f"WARNING: neither dtype nor torch_dtype found in config (or both are None), setting to bfloat16" + ) + adapted_cfg_dict["torch_dtype"] = "bfloat16" + + # TODO: check how config keys such as position_embedding_type are handled (since they're not passed to the constructor) + adapted_cfg_dict["heterogeneous_layers_config_path"] = None + adapted_cfg_dict["block_configs"] = block_conf["block_configs"] + adapted_cfg_dict["heterogeneous_layers_config_encoded_json"] = json.dumps( + block_conf, ensure_ascii=False + ) + adapted_cfg_dict["transformer_layer_spec"] = heterogeneous_layer_spec_puzzletron + adapted_cfg_dict["vocab_size"] = orig_cfg_dict["vocab_size"] + adapted_cfg_dict["num_layers"] = len(orig_cfg_dict["block_configs"]) + adapted_cfg_dict["hidden_size"] = orig_cfg_dict["hidden_size"] + # adapted_cfg_dict['num_attention_heads'] = cfg["num_attention_heads"] + adapted_cfg_dict["kv_channels"] = adapted_cfg_dict["head_dim"] + adapted_cfg_dict["scale_factor"] = float( + adapted_cfg_dict["rope_scaling"].get("factor", 8.0) + ) + adapted_cfg_dict["rotary_base"] = int(orig_cfg_dict.get("rope_theta", 500_000)) + adapted_cfg_dict["seq_length"] = int(orig_cfg_dict.get("max_position_embeddings", 131072)) + adapted_cfg_dict["init_method_std"] = float(orig_cfg_dict.get("initializer_range", 0.02)) + adapted_cfg_dict["layernorm_epsilon"] = float(orig_cfg_dict.get("rms_norm_eps", 1e-5)) + adapted_cfg_dict["share_embeddings_and_output_weights"] = bool( + orig_cfg_dict.get("tie_word_embeddings", False) + ) + # adapted_cfg_dict["make_vocab_size_divisible_by"] = 128 + + # Preserve HF-specific config fields that don't have NeMo equivalents + # This enables lossless round-trip conversion HF → NeMo → HF + source_hf_config_metadata = {} + + # eos_token_id: HF can have multiple EOS tokens [128001, 128008, 128009] + # but NeMo tokenizer only supports single eos_id (uses the last one) + if "eos_token_id" in orig_cfg_dict: + source_hf_config_metadata["eos_token_id"] = orig_cfg_dict["eos_token_id"] + + # auto_map: HF-specific field for custom model class loading via trust_remote_code + # Not relevant to NeMo but needed for HF model.from_pretrained() to work + if "auto_map" in orig_cfg_dict: + source_hf_config_metadata["auto_map"] = orig_cfg_dict["auto_map"] + + # dtype: HF uses 'dtype' field, NeMo uses 'torch_dtype', preserve both + if "dtype" in orig_cfg_dict: + source_hf_config_metadata["dtype"] = orig_cfg_dict["dtype"] + + # Store as direct config attribute (will be serialized by NeMo automatically) + adapted_cfg_dict["source_hf_config_metadata"] = ( + source_hf_config_metadata if source_hf_config_metadata else None + ) + + return adapted_cfg_dict + + +class PuzzletronLlamaNemotronModel(GPTModel): + """Llama-Nemotron model implementation based on the GPT model architecture. + + This class provides a high-level interface for Llama-Nemotron models, + implementing the specific architecture and settings needed for Llama-Nemotron models. + """ + + def __init__( + self, + config: Annotated[ + Optional[PuzzletronNemotronModelConfig] | type[PuzzletronNemotronModelConfig], + Config[PuzzletronNemotronModelConfig], + ] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__( + config or PuzzletronNemotronModelConfig(), + optim=optim, + tokenizer=tokenizer, + model_transform=model_transform, + ) + + +def instantiate_nemo_config_from_adapted_dict( + adapted_cfg_dict: dict, + generation_config: Optional["GenerationConfig"] = None, +) -> PuzzletronNemotronModelConfig: + """ + Instantiate PuzzletronNemotronModelConfig from adapted config dict. + + This function is shared by the importer and tests to ensure consistency. + + Args: + adapted_cfg_dict: Dict created by create_adapted_config_dict_from_puzzletron_config + generation_config: Optional generation config to attach + + Returns: + PuzzletronNemotronModelConfig instance + """ + + # Helper function for vocab size divisibility + def make_vocab_size_divisible_by(vocab_size: int) -> int: + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + # Keys used for PuzzletronNemotronModelConfig instantiation + INSTANTIATION_KEYS = { + "heterogeneous_layers_config_encoded_json", + "transformer_layer_spec", + "num_layers", + "hidden_size", + "num_attention_heads", + "kv_channels", + "scale_factor", + "init_method_std", + "layernorm_epsilon", + "seq_length", + "rotary_base", + "vocab_size", + "share_embeddings_and_output_weights", + "source_hf_config_metadata", + } + + # Keys that are metadata or derived (not directly passed to constructor) + metadata_keys = set(adapted_cfg_dict.keys()) - INSTANTIATION_KEYS + + mprint(f"DEBUG: Keys used for instantiation: {sorted(INSTANTIATION_KEYS)}") + mprint(f"DEBUG: Metadata keys (not used for direct instantiation): {sorted(metadata_keys)}") + for key in sorted(metadata_keys): + value = adapted_cfg_dict[key] + if isinstance(value, (list, dict)): + mprint(f" - {key}: {type(value).__name__} with {len(value)} items") + elif callable(value): + mprint(f" - {key}: {value.__name__ if hasattr(value, '__name__') else 'callable'}") + else: + mprint(f" - {key}: {value}") + + model_dtype = dtype_from_dict(adapted_cfg_dict) + + # Determine the unique activation_func from all blocks + # MCore validates the global activation_func, so we need to set it to match all blocks + heterogeneous_config = json.loads(adapted_cfg_dict["heterogeneous_layers_config_encoded_json"]) + block_list = heterogeneous_config.get("block_configs", []) + + # Assert that block_configs exists and is not empty + assert block_list, ( + "No block_configs found in heterogeneous_layers_config_encoded_json. " + "The JSON structure must contain a 'block_configs' list with at least one block." + ) + + activation_funcs = [] + + for i, block in enumerate(block_list): + # Extract hidden_act from MLP config (if present) + if "mlp" in block and "hidden_act" in block["mlp"]: + hidden_act_str = block["mlp"]["hidden_act"] + + # Track None/null values (used for no-op blocks) + if hidden_act_str is None: + activation_funcs.append(None) + continue + + # For now, only support silu and gelu activations + # See: https://github.com/NVIDIA/Megatron-LM/blob/268fda08592528b7bc1a21aadaed259980ca8efb/megatron/core/transformer/transformer_config.py#L1043-L1048 + if hidden_act_str == "silu": + activation_funcs.append(F.silu) + elif hidden_act_str == "gelu": + activation_funcs.append(F.gelu) + else: + raise NotImplementedError( + f"Unsupported activation function: '{hidden_act_str}' in block {i}. " + f"Only 'silu', 'gelu', and None/null are currently supported. " + f"MCore's bias_activation_fusion only validates these activation functions." + ) + # If no hidden_act key or no MLP, we treat it as None + else: + activation_funcs.append(None) + + # Separate None and not-None activations + not_none_activations = [f for f in activation_funcs if f is not None] + + # Check that all not-None activation functions are the same + unique_not_none = {id(f) for f in not_none_activations} + + if len(unique_not_none) == 0: + # No activation functions found (all blocks are no-op or have None) + # Default to F.silu to pass MCore validation + global_activation_func = F.silu + mprint( + "WARNING: No not-None activation functions found in blocks, defaulting global activation_func to F.silu" + ) + elif len(unique_not_none) == 1: + # All not-None blocks use the same activation function (safe) + global_activation_func = not_none_activations[0] + func_name = ( + global_activation_func.__name__ + if hasattr(global_activation_func, "__name__") + else str(global_activation_func) + ) + none_count = activation_funcs.count(None) + total_count = len(activation_funcs) + mprint( + f"INFO: All {total_count - none_count} not-None blocks use the same activation function: {func_name} ({none_count} None/no-op blocks)" + ) + else: + # Multiple different not-None activation functions found (currently not supported/tested) + func_names = [f.__name__ if hasattr(f, "__name__") else "None" for f in activation_funcs] + unique_func_names = set(f.__name__ for f in not_none_activations) + assert False, ( + f"Puzzletron blocks must all use the same activation function (None allowed for no-op blocks). " + f"Found {len(unique_not_none)} different not-None activation functions across blocks: {unique_func_names}. " + f"Block activations: {func_names}. " + f"MCore's validation only checks the global activation_func, which would not match heterogeneous activations. " + f"Either make all blocks use the same activation, or update MCore to support per-block validation." + ) + + return PuzzletronNemotronModelConfig( + heterogeneous_layers_config_encoded_json=adapted_cfg_dict[ + "heterogeneous_layers_config_encoded_json" + ], + heterogeneous_layers_config_path=None, # We directly load the block config as json + transformer_layer_spec=adapted_cfg_dict["transformer_layer_spec"], + activation_func=global_activation_func, # Set to match all blocks + num_layers=adapted_cfg_dict["num_layers"], + hidden_size=adapted_cfg_dict["hidden_size"], + num_attention_heads=adapted_cfg_dict["num_attention_heads"], + kv_channels=adapted_cfg_dict["kv_channels"], + scale_factor=adapted_cfg_dict["scale_factor"], + init_method_std=adapted_cfg_dict["init_method_std"], + layernorm_epsilon=adapted_cfg_dict["layernorm_epsilon"], + seq_length=adapted_cfg_dict["seq_length"], + rotary_base=adapted_cfg_dict["rotary_base"], + make_vocab_size_divisible_by=make_vocab_size_divisible_by(adapted_cfg_dict["vocab_size"]), + vocab_size=adapted_cfg_dict["vocab_size"], + share_embeddings_and_output_weights=adapted_cfg_dict["share_embeddings_and_output_weights"], + # HF-specific metadata for lossless round-trip conversion + source_hf_config_metadata=adapted_cfg_dict.get("source_hf_config_metadata"), + fp16=(model_dtype == torch.float16), + bf16=(model_dtype == torch.bfloat16), + params_dtype=model_dtype, + generation_config=generation_config, + ) + + +@io.model_importer(PuzzletronLlamaNemotronModel, "hf") +class PuzzletronHFLlamaNemotronImporter( + io.ModelConnector["LlamaForCausalLM", PuzzletronLlamaNemotronModel] +): + """Importer for converting Hugging Face Llama-Nemotron models to NeMo format. + + This class handles the conversion of Hugging Face's LlamaForCausalLM models + to NeMo's PuzzletronLlamaNemotronModel format, including weight mapping and configuration translation. + """ + + # Base mapping using standard LLaMA weight names + # Layernorm wildcards are replaced with per-layer mappings in convert_state() + # TODO: MoE and Mamba layer conversions have not been tested yet + default_mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + def init(self) -> PuzzletronLlamaNemotronModel: + """Initialize a NeMo LlamaModel instance. + + Returns: + LlamaModel: Initialized NeMo Llama model with the appropriate configuration + and tokenizer. + """ + config = self.config + mprint(f"DEBUG: NeMo config dtype settings:") + mprint(f" - config.bf16: {config.bf16}") + mprint(f" - config.fp16: {config.fp16}") + mprint(f" - config.params_dtype: {config.params_dtype}") + return PuzzletronLlamaNemotronModel(config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + """Apply the conversion from HF to NeMo format. + + Args: + output_path: Path where the converted model will be saved + + Returns: + Path: Path to the saved NeMo model + """ + from transformers import AutoModelForCausalLM + + logging.info(f"Load Puzzletron HF model {str(self)}") + source = AutoModelForCausalLM.from_pretrained( + str(self), trust_remote_code=True, torch_dtype="auto" + ) + logging.info("Initialize NeMo Puzzletron Llama Nemotron model") + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + mprint( + f"Converted Llama-Nemotron model to Nemo, model saved to {output_path} in {source.dtype}." + ) + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source: Any, target: Any) -> Any: + """Convert state dict from HF format to NeMo format. + + Maps the weights from the HF model to the NeMo model according to + the appropriate mapping scheme. + + Args: + source: Source HF model + target: Target NeMo model + + Returns: + The result of applying the transforms + """ + mapping = self.default_mapping.copy() + + if target.config.normalization == "LayerNorm": + mapping["model.norm.bias"] = "decoder.final_layernorm.bias" + if getattr(source.config, "tie_word_embeddings", False): + del mapping["lm_head.weight"] + + # Puzzletron models must have block_configs for heterogeneous layer support + assert hasattr(source.config, "block_configs"), "Puzzletron models must have block_configs" + + # Build per-layer specific mappings for heterogeneous support + attn_mapping, ffn_mapping, mamba_mapping, moe_mapping, transform_specs = ( + _build_puzzletron_mappings_and_transforms(source.config) + ) + + # Remove layernorm wildcards from default_mapping - these will be replaced with + # specific per-layer mappings based on each layer's architecture. + for pattern in [ + "model.layers.*.input_layernorm.weight", + "model.layers.*.post_attention_layernorm.weight", + ]: + if pattern in mapping: + del mapping[pattern] + + # Add all layer-specific mappings + mapping.update(**attn_mapping) + mapping.update(**ffn_mapping) + mapping.update(**mamba_mapping) + mapping.update(**moe_mapping) + + # Create transforms from specification + transforms = [] + + # Helper to create merge_qkv closure with proper layer index capture + def make_merge_qkv_fn(layer_idx): + def merge_qkv_fn(ctx, q, k, v): + return merge_qkv_for_puzzletron(ctx, q, k, v, idx=layer_idx) + + return merge_qkv_fn + + for spec in transform_specs: + if spec["transform_function"] == "merge_qkv_for_puzzletron": + # Fixed: proper closure to avoid variable capture issues + layer_idx = spec["kwargs"]["idx"] + transforms.append( + io.state_transform( + source_key=spec["source_key"], + target_key=spec["target_key"], + fn=make_merge_qkv_fn(layer_idx), + ) + ) + elif spec["transform_function"] == "merge_fc1_for_moe": + transforms.append( + io.state_transform( + source_key=spec["source_key"], + target_key=spec["target_key"], + fn=TransformFns.merge_fc1, + ) + ) + + # Add standard FC1 merge transform + transforms.append( + io.state_transform( + source_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + target_key="decoder.layers.*.mlp.linear_fc1.weight", + fn=TransformFns.merge_fc1, + ) + ) + return io.apply_transforms(source, target, mapping=mapping, transforms=transforms) + + @property + def tokenizer(self) -> "AutoTokenizer": + """Get the tokenizer for the HF model. + + Returns: + AutoTokenizer: Tokenizer instance initialized from the HF model's tokenizer + """ + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(self.save_hf_tokenizer_assets(str(self)), trust_remote_code=True) + + @property + def config(self) -> PuzzletronNemotronModelConfig: + """Create a NeMo LlamaNemotronConfig from the HF model config. + + Translates the HF configuration parameters to the equivalent NeMo + configuration. + + Returns: + PuzzletronNemotronModelConfig: Puzzletron NeMo configuration for Llama models + """ + from transformers import AutoConfig, GenerationConfig + + source = AutoConfig.from_pretrained(str(self), trust_remote_code=True) + + # Validate that this is a proper Puzzletron-Nemotron checkpoint + assert getattr(source, "rope_scaling", None), ( + "Llama-Nemotron model should have rope scaling" + ) + assert getattr(source, "block_configs", None) is not None, ( + "Puzzletron-Nemotron model should be heterogeneous and have block configs" + ) + + adapted_cfg_dict = ( + PuzzletronNemotronModelConfig.create_adapted_config_dict_from_puzzletron_config(source) + ) + + try: + generation_config = GenerationConfig.from_pretrained(str(self)) + except Exception: + generation_config = None + + output = instantiate_nemo_config_from_adapted_dict( + adapted_cfg_dict, generation_config=generation_config + ) + return output + + +@io.model_exporter(PuzzletronLlamaNemotronModel, "hf") +class PuzzletronHFLlamaNemotronExporter( + io.ModelConnector[PuzzletronLlamaNemotronModel, "LlamaForCausalLM"] +): + """Exporter for converting NeMo Puzzletron Llama-Nemotron models to Hugging Face format. + + This class handles the conversion of NeMo's PuzzletronLlamaNemotronModel to Hugging Face's + LlamaForCausalLM format, including weight mapping and configuration translation. + It supports heterogeneous model architectures with Puzzletron-specific configurations. + + The exporter performs the following key operations: + 1. Initializes a Hugging Face model with appropriate configuration + 2. Maps weights from NeMo format to Hugging Face format + 3. Handles special cases for heterogeneous architectures with Mamba, MoE, and other custom layers + 4. Saves the converted model and tokenizer to the specified output path + + Attributes: + tokenizer: The tokenizer associated with the NeMo model + config: The configuration for the Hugging Face model + + Methods: + init: Initialize a Hugging Face model instance + apply: Convert and save the model to Hugging Face format + convert_state: Convert model weights from NeMo to Hugging Face format + """ + + # Base mapping for NeMo -> HF conversion (reversed from importer) + # Layernorm wildcards are replaced with per-layer mappings in convert_state() + default_mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + @property + def config(self) -> "DeciLMConfig": + """Create a HF DeciLMConfig from the NeMo model config. + + This method constructs a DeciLMConfig for Puzzletron models by parsing the + heterogeneous_layers_config_encoded_json from the NeMo config and mapping + the fields to the HF DeciLM format. + + Returns: + DeciLMConfig: HF configuration for Puzzletron DeciLM models + """ + # Load the NeMo config + source_config = io.load_context(str(self), subpath="model.config") + + # Get preserved HF config metadata (stored as direct attribute) + # This enables lossless round-trip conversion HF → NeMo → HF + source_hf_config_metadata = getattr(source_config, "source_hf_config_metadata", None) or {} + + # Get EOS token ID(s) - prefer preserved value from source HF config metadata + # (HF supports multiple EOS tokens, NeMo tokenizer only has single eos_id) + eos_token_id = source_hf_config_metadata.get("eos_token_id", self.tokenizer.eos_id) + + # Use the shared conversion function + return convert_nemo_config_to_hf_decilm_config( + nemo_config=source_config, + vocab_size=self.tokenizer.vocab_size, + eos_token_id=eos_token_id, + bos_token_id=self.tokenizer.bos_id, + pad_token_id=getattr(self.tokenizer, "pad_id", None), + ) + + def init(self, dtype=torch.bfloat16, from_config=False, model_name=None) -> "LlamaForCausalLM": + """Initialize a Hugging Face LlamaForCausalLM model instance. + + This method creates a new Hugging Face model instance with the appropriate configuration + and data type. Puzzletron models always use from_config=True and create a DeciLMForCausalLM. + + Args: + dtype (torch.dtype, optional): Data type for model parameters. Defaults to torch.bfloat16. + from_config (bool, optional): Whether to initialize from config or load from pretrained. + For Puzzletron models, this should always be True. Defaults to False. + model_name (str, optional): Name of the pretrained model to load. Not used for Puzzletron + models since we generate the config dynamically. Defaults to None. + + Returns: + DeciLMForCausalLM: Initialized Hugging Face DeciLM model instance + + Raises: + ValueError: If model_name is provided (not supported for Puzzletron models) + """ + from transformers.modeling_utils import no_init_weights + + from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( + DeciLMForCausalLM, + ) + + with no_init_weights(): + if from_config: + # Puzzletron models: create DeciLMForCausalLM from self.config property + model = DeciLMForCausalLM(self.config) + model = model.to(dtype=dtype) + return model + else: + # Puzzletron models don't support loading from pretrained HF model cards + raise ValueError( + "Puzzletron models do not have official HF model cards. " + "Use from_config=True to create models from NeMo config." + ) + + def apply(self, output_path: Path, target_model_name=None) -> Path: + """Convert and save a NeMo Puzzletron Llama-Nemotron model to Hugging Face format. + + This method performs the complete conversion process: + 1. Loads the NeMo model checkpoint + 2. Creates the Hugging Face model from config + 3. Converts and transfers the weights + 4. Saves the converted model and tokenizer + + Args: + output_path (Path): Directory path where the converted model will be saved + target_model_name (str, optional): Not used for Puzzletron models. Kept for API compatibility. + + Returns: + Path: Path to the saved Hugging Face model directory + """ + logging.info("Loading Puzzletron Llama-Nemotron NeMo checkpoint..") + source, _ = self.nemo_load(str(self)) + + # Puzzletron models always use from_config=True to generate DeciLMConfig dynamically + target = self.init( + torch_dtype_from_mcore_config(source.config), + from_config=True, + model_name=None, + ) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.tokenizer.save_pretrained(output_path) + + # Copy custom Python files needed for Puzzletron models + from modelopt.torch.puzzletron.export.MCore.llama_nemotron_utils import ( + copy_puzzletron_python_files_to_decilm_checkpoint, + embed_chat_template_in_tokenizer_config, + ) + + copy_puzzletron_python_files_to_decilm_checkpoint(str(output_path)) + + # Fix tokenizer: embed chat_template from .jinja file into tokenizer_config.json + # NeMo's HF → NeMo import extracts chat_template to .jinja but doesn't preserve + # it in tokenizer_config.json. We restore it here for accuracy parity. + embed_chat_template_in_tokenizer_config(str(self), str(output_path)) + + return output_path + + def convert_state(self, source: Any, target: Any) -> Any: + """Convert state dict from NeMo format to HF format. + + Maps the weights from the NeMo model to the HF model according to + the appropriate mapping scheme for Puzzletron models. + + This method follows the same pattern as the importer but with reversed mappings: + 1. Start with default mapping + 2. Remove layernorm wildcards (will be replaced with per-layer mappings) + 3. Build per-layer specific mappings using helper function and reverse them + 4. Create transforms for weight conversions + + Args: + source: Source NeMo model + target: Target HF model + + Returns: + The target model with weights transferred from source + """ + mapping = self.default_mapping.copy() + + # Handle LayerNorm bias if present + if source.config.normalization == "LayerNorm": + mapping["decoder.final_layernorm.bias"] = "model.norm.bias" + + # Handle tied embeddings + if getattr(source.config, "share_embeddings_and_output_weights", False): + # Remove output_layer mapping if embeddings are tied + if "output_layer.weight" in mapping: + del mapping["output_layer.weight"] + + # Build per-layer specific mappings for heterogeneous support + attn_mapping, ffn_mapping, mamba_mapping, moe_mapping, transform_specs = ( + _build_puzzletron_mappings_and_transforms(source.config) + ) + + # Remove layernorm wildcards from default_mapping - these will be replaced with + # specific per-layer mappings based on each layer's architecture. + for pattern in [ + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + ]: + if pattern in mapping: + del mapping[pattern] + + # For exporter: reverse all mappings (HF -> NeMo becomes NeMo -> HF) + attn_mapping = {v: k for k, v in attn_mapping.items()} + ffn_mapping = {v: k for k, v in ffn_mapping.items()} + mamba_mapping = {v: k for k, v in mamba_mapping.items()} + moe_mapping = {v: k for k, v in moe_mapping.items()} + + # Add all layer-specific mappings + mapping.update(**attn_mapping) + mapping.update(**ffn_mapping) + mapping.update(**mamba_mapping) + mapping.update(**moe_mapping) + + # Create transforms from specifications (reversed for exporter) + transforms = [] + + # Helper to create split_qkv closure with proper layer index capture + def make_split_qkv_fn(layer_idx): + def split_qkv_fn(ctx, qkv): + return split_qkv_for_puzzletron(ctx, qkv, idx=layer_idx) + + return split_qkv_fn + + for spec in transform_specs: + if spec["transform_function"] == "merge_qkv_for_puzzletron": + # For exporter: split QKV (NeMo -> HF) + layer_idx = spec["kwargs"]["idx"] + transforms.append( + io.state_transform( + source_key=spec["target_key"], # NeMo key + target_key=spec["source_key"], # HF key + fn=make_split_qkv_fn(layer_idx), + ) + ) + elif spec["transform_function"] == "merge_fc1_for_moe": + # For exporter: split FC1 for MoE (NeMo -> HF) + transforms.append( + io.state_transform( + source_key=spec["target_key"], # NeMo key + target_key=spec["source_key"], # HF key + fn=TransformFns.split_fc1, + ) + ) + + # Add standard transforms for FC1 splitting and padding pruning + transforms.extend( + [ + io.state_transform( + source_key="decoder.layers.*.mlp.linear_fc1.weight", + target_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + fn=TransformFns.split_fc1, + ), + io.state_transform( + source_key="embedding.word_embeddings.weight", + target_key="model.embed_tokens.weight", + fn=TransformFns.prune_padding, + ), + io.state_transform( + source_key="output_layer.weight", + target_key="lm_head.weight", + fn=TransformFns.prune_padding, + ), + ] + ) + + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=transforms, + ) + + @property + def tokenizer(self) -> "TokenizerSpec": + """Get the tokenizer from the NeMo model. + + Returns: + TokenizerSpec: Tokenizer from the NeMo model + """ + return io.load_context(str(self), subpath="model").tokenizer + + +__all__ = [ + "PuzzletronLlamaNemotronModel", +] diff --git a/modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py b/modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py new file mode 100644 index 000000000..8d01ec953 --- /dev/null +++ b/modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py @@ -0,0 +1,729 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from megatron.core.transformer.spec_utils import ModuleSpec +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import ( + AutoTokenizer as NemoAutoTokenizer, +) +from nemo.collections.llm.gpt.model.base import GPTModel +from nemo.collections.llm.gpt.model.llama_nemotron import ( + HFLlamaNemotronImporter, + PuzzletronNemotronModelConfig, +) +from nemo.lightning import io, teardown +from nemo.lightning.io.state import TransformFns +from nemo.lightning.pytorch.utils import dtype_from_str +from nemo.utils.import_utils import safe_import +from transformers import AutoModelForCausalLM, AutoTokenizer + +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.export.MCore.puzzletron_layer_specs import ( + PuzzletronAttentionConfig, + PuzzletronHeterogeneousTransformerConfig, + PuzzletronMLPConfig, + get_gpt_heterogeneous_layer_spec_puzzletron, +) + + +def convert_attention_config_from_cfg_object(attention_config, num_attention_heads, head_dim): + for unsupported_key in [ + "llama4", + "num_sink_tokens", + "sparsify", + "unshifted_sink", + "use_prefill_window_in_sink_attention", + ]: + if hasattr(attention_config, unsupported_key) and getattr( + attention_config, unsupported_key + ) not in [ + None, + False, + ]: + # + # if attention_config.get(unsupported_key, None) not in [None, False]: + raise NotImplementedError(f"{unsupported_key} is not supported") + window_size = attention_config.window_size if hasattr(attention_config, "window_size") else None + if window_size is not None: + window_size = (window_size, 0) + is_mamba = attention_config.mamba if hasattr(attention_config, "mamba") else False + n_heads_in_group = ( + attention_config.n_heads_in_group if hasattr(attention_config, "n_heads_in_group") else 1 + ) + if n_heads_in_group is None: + n_heads_in_group = 1 + return asdict( + PuzzletronAttentionConfig( + no_op=attention_config.no_op if hasattr(attention_config, "no_op") else False, + replace_with_linear=( + attention_config.replace_with_linear + if hasattr(attention_config, "replace_with_linear") + else False + ), + num_attention_heads=num_attention_heads, + num_query_groups=num_attention_heads // n_heads_in_group, + kv_channels=head_dim, + window_size=window_size, + multi_latent_attention=False, + is_mamba=is_mamba, + mamba_state_dim=( + attention_config.mamba.state_dim + if is_mamba and hasattr(attention_config.mamba, "state_dim") + else 128 + ), + mamba_head_dim=( + attention_config.mamba.head_dim + if is_mamba and hasattr(attention_config.mamba, "head_dim") + else 64 + ), + mamba_num_groups=( + attention_config.mamba.num_groups + if is_mamba and hasattr(attention_config.mamba, "num_groups") + else 8 + ), + mamba_num_heads=( + attention_config.mamba.num_heads + if is_mamba and hasattr(attention_config.mamba, "num_heads") + else None + ), + ) + ) + + +def convert_mlp_config_from_cfg_object(mlp_config, parallel_blocks): + """Convert MLP config from HF format to NeMo format. + + Args: + mlp_config: The MLP configuration object from HF + parallel_blocks: Parallel blocks configuration (not currently supported) + """ + if parallel_blocks is not None: + raise NotImplementedError("parallel_blocks is not supported") + if not hasattr(mlp_config, "gated") or mlp_config.gated is False: + raise NotImplementedError("notgated MLP is not supported") + + # Validate this block's activation function + if not hasattr(mlp_config, "hidden_act"): + raise ValueError(f"MLP config must have hidden_act attribute") + # if mlp_config.hidden_act != block_hidden_act: + # raise ValueError(f"MLP config hidden_act mismatch: config has {mlp_config.hidden_act}, expected {block_hidden_act}") + + if hasattr(mlp_config, "sparsify") and mlp_config.sparsify is not None: + raise NotImplementedError("sparsify is not supported") + is_moe = hasattr(mlp_config, "moe") and mlp_config.moe is not None + # Note: hidden_act is validated above but not stored in PuzzletronMLPConfig + # It will be used at the call site for the NeMo model config + return asdict( + PuzzletronMLPConfig( + no_op=mlp_config.no_op if hasattr(mlp_config, "no_op") else False, + replace_with_linear=mlp_config.replace_with_linear + if hasattr(mlp_config, "replace_with_linear") + else False, + ffn_hidden_size=mlp_config.intermediate_size + if hasattr(mlp_config, "intermediate_size") + else None, + num_moe_experts=( + mlp_config.moe.num_local_experts + if is_moe and hasattr(mlp_config.moe, "num_local_experts") + else None + ), + moe_shared_expert_intermediate_size=( + mlp_config.moe.shared_expert_intermediate_dim + if is_moe and hasattr(mlp_config.moe, "shared_expert_intermediate_dim") + else None + ), + moe_ffn_hidden_size=( + mlp_config.moe.expert_intermediate_dim + if is_moe and hasattr(mlp_config.moe, "expert_intermediate_dim") + else None + ), + moe_router_topk=( + mlp_config.moe.num_experts_per_tok + if is_moe and hasattr(mlp_config.moe, "num_experts_per_tok") + else 2 + ), + ) + ) + + +def convert_nemo_config_to_hf_decilm_config( + nemo_config: "PuzzletronNemotronModelConfig", + vocab_size: int, + eos_token_id: Union[int, List[int], None] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, +) -> "DeciLMConfig": + """Convert a NeMo PuzzletronNemotronModelConfig to HF DeciLMConfig. + + This function extracts the conversion logic from the exporter so it can be + used in unit tests without requiring file I/O. + + Args: + nemo_config: The NeMo config to convert + vocab_size: Vocabulary size for the HF config + eos_token_id: EOS token ID(s). Can be int or list of ints. + bos_token_id: BOS token ID + pad_token_id: PAD token ID + + Returns: + DeciLMConfig: The equivalent HF config + """ + + # Get preserved HF config metadata (stored as direct attribute) + # This enables lossless round-trip conversion HF → NeMo → HF + source_hf_config_metadata = getattr(nemo_config, "source_hf_config_metadata", None) or {} + + # Parse the heterogeneous layers config from JSON + block_configs = [] + + if ( + hasattr(nemo_config, "heterogeneous_layers_config_encoded_json") + and nemo_config.heterogeneous_layers_config_encoded_json + ): + try: + heterogeneous_config = json.loads(nemo_config.heterogeneous_layers_config_encoded_json) + raw_block_configs = heterogeneous_config.get("block_configs", []) + + for i, raw_block_config in enumerate(raw_block_configs): + attn_block = raw_block_config.get("attention", {}) + mlp_block = raw_block_config.get("mlp", {}) + + # Configure attention + attention_config = { + "no_op": attn_block.get("no_op", False), + "replace_with_linear": attn_block.get("replace_with_linear", False), + "sparsify": attn_block.get("sparsify", None), + "n_heads_in_group": attn_block.get( + "num_attention_heads", nemo_config.num_attention_heads + ) + // attn_block.get("num_query_groups", nemo_config.num_query_groups), + "window_length": attn_block.get("window_size", None), + "num_sink_tokens": attn_block.get("num_sink_tokens", None), + "use_prefill_window_in_sink_attention": attn_block.get( + "use_prefill_window_in_sink_attention", False + ), + "unshifted_sink": attn_block.get("unshifted_sink", False), + } + + # Handle Mamba: convert from NeMo flat structure to HF nested structure + if attn_block.get("is_mamba", False): + attention_config["mamba"] = { + "state_dim": attn_block.get("mamba_state_dim", 128), + "num_heads": attn_block.get( + "mamba_num_heads", nemo_config.num_attention_heads + ), + "head_dim": attn_block.get("mamba_head_dim", 64), + "num_groups": attn_block.get("mamba_num_groups", 8), + } + else: + attention_config["mamba"] = None + + # Handle Llama4: pass through as dict if present + attention_config["llama4"] = attn_block.get("llama4", None) + + # Configure FFN + ffn_config = { + "no_op": mlp_block.get("no_op", False), + "replace_with_linear": mlp_block.get("replace_with_linear", False), + "sparsify": mlp_block.get("sparsify", None), + "intermediate_size": mlp_block.get( + "ffn_hidden_size", nemo_config.ffn_hidden_size + ), + "gated": True, # Puzzletron uses gated activations + # Use the activation function name extracted from this block's config + "hidden_act": mlp_block.get("hidden_act", None), + } + + # Handle MoE: convert from NeMo flat structure to HF nested structure + num_moe_experts = mlp_block.get("num_moe_experts", None) + if num_moe_experts is not None: + ffn_config["moe"] = { + "num_local_experts": num_moe_experts, + "num_experts_per_tok": mlp_block.get("moe_router_topk", 1), + "expert_intermediate_dim": mlp_block.get("moe_ffn_hidden_size", 8192), + "shared_expert_intermediate_dim": mlp_block.get( + "moe_shared_expert_intermediate_size", 8192 + ), + } + else: + ffn_config["moe"] = None + + block_configs.append({"attention": attention_config, "ffn": ffn_config}) + except (json.JSONDecodeError, KeyError) as e: + raise ValueError(f"Could not parse heterogeneous config JSON: {e}") + else: + raise ValueError("No block configs found in source configuration") + + # Create rope scaling config + rope_scaling = { + "factor": nemo_config.scale_factor, + "low_freq_factor": getattr(nemo_config, "low_freq_factor", 1.0), + "high_freq_factor": getattr(nemo_config, "high_freq_factor", 4.0), + "original_max_position_embeddings": getattr(nemo_config, "old_context_len", 8192), + "rope_type": "llama3", + } + + # Get EOS token ID(s) - prefer preserved value from source HF config metadata or provided value + if eos_token_id is None: + eos_token_id = source_hf_config_metadata.get("eos_token_id", None) + + # Create DeciLM config + hf_config = DeciLMConfig( + block_configs=block_configs, + hidden_size=nemo_config.hidden_size, + max_position_embeddings=nemo_config.seq_length, + num_attention_heads=nemo_config.num_attention_heads, + num_hidden_layers=nemo_config.num_layers, + tie_word_embeddings=nemo_config.share_embeddings_and_output_weights, + vocab_size=vocab_size, + rms_norm_eps=nemo_config.layernorm_epsilon, + attention_bias=getattr(nemo_config, "attention_bias", False), + o_proj_bias=getattr( + nemo_config, "o_proj_bias", getattr(nemo_config, "attention_bias", False) + ), + rope_theta=nemo_config.rotary_base, + rope_scaling=rope_scaling, + position_embedding_type="rope", + architectures=["DeciLMForCausalLM"], + model_type="nemotron-nas", + eos_token_id=eos_token_id, + bos_token_id=bos_token_id, + pad_token_id=pad_token_id, + head_dim=nemo_config.kv_channels, + # Restore auto_map from preserved metadata (needed for trust_remote_code loading) + auto_map=source_hf_config_metadata.get( + "auto_map", + { + "AutoConfig": "configuration_decilm.DeciLMConfig", + "AutoModelForCausalLM": "modeling_decilm.DeciLMForCausalLM", + }, + ), + # Restore dtype field from preserved metadata + dtype=source_hf_config_metadata.get("dtype", "bfloat16"), + ) + + return hf_config + + +def _config_to_dict(config) -> Dict[str, Any]: + """Convert a config object to a dictionary. + + Args: + config: Either an object with attributes or already a dictionary + + Returns: + Dictionary representation of the config + """ + if isinstance(config, dict): + return config + return vars(config) + + +def _build_puzzletron_mappings_and_transforms( + source_config: PuzzletronHeterogeneousTransformerConfig, +) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str], Dict[str, str], List[Dict[str, Any]]]: + """Build mappings and transform specifications for Puzzletron heterogeneous models. + + Args: + source_config: The Puzzletron heterogeneous transformer configuration + + Returns: + Tuple containing: + - attn_mapping: Attention layer mappings + - ffn_mapping: FFN layer mappings + - mamba_mapping: Mamba layer mappings + - moe_mapping: MoE layer mappings + - transform_specs: List of transform specifications with source_key, target_key, transform_function + """ + attn_mapping = {} + ffn_mapping = {} + mamba_mapping = {} + moe_mapping = {} + transform_specs = [] + + # Determine config type and extract block configs + is_hf_config = hasattr(source_config, "block_configs") + is_nemo_config = ( + hasattr(source_config, "heterogeneous_layers_config_encoded_json") + and source_config.heterogeneous_layers_config_encoded_json + ) + assert not (is_hf_config and is_nemo_config), "Cannot have both HF and NeMo config" + + if is_hf_config: + # HF config case (importer) + block_configs = source_config.block_configs + elif is_nemo_config: + # NeMo config case (exporter) - parse JSON + try: + heterogeneous_config = json.loads( + source_config.heterogeneous_layers_config_encoded_json + ) + block_configs = heterogeneous_config.get("block_configs", []) + except (json.JSONDecodeError, KeyError): + block_configs = [] + else: + block_configs = [] + + # Check if we found any block configs + if not block_configs: + raise ValueError( + "No block configs found in source configuration. " + "Expected either 'block_configs' attribute (HF config) or " + "'heterogeneous_layers_config_encoded_json' attribute (NeMo config) with valid block configs." + ) + + # TODO it is better (more stable) to use target.config.block_configs + for idx, block_config in enumerate(block_configs): + # Convert block config to dictionary + block_dict = _config_to_dict(block_config) + + # Extract attention and FFN configs (handle both HF "ffn" and NeMo "mlp" keys) + attn = block_dict.get("attention") + ffn = block_dict.get("ffn") or block_dict.get("mlp") + + # Convert sub-configs to dictionaries + attn_dict = _config_to_dict(attn) if attn else {} + ffn_dict = _config_to_dict(ffn) if ffn else {} + + # Process attention config + # Handle both HF (mamba) and NeMo (is_mamba) keys + is_mamba = attn_dict.get("mamba") or attn_dict.get("is_mamba") + + if not attn or attn_dict.get("no_op"): + value = None + elif attn_dict.get("replace_with_linear"): + value = f"decoder.layers.{idx}.self_attention.layer_norm_weight" + elif is_mamba is not None: + value = f"decoder.layers.{idx}.self_attention.in_proj.layer_norm_weight" + for mamba_key in [ + "dt_bias", + "A_log", + "D", + "in_proj.weight", + "conv1d.weight", + "conv1d.bias", + "norm.weight", + "out_proj.weight", + ]: + mamba_mapping[f"model.layers.{idx}.self_attn.mamba_mixer.{mamba_key}"] = ( + f"decoder.layers.{idx}.self_attention.{mamba_key}" + ) + else: + value = f"decoder.layers.{idx}.self_attention.linear_qkv.layer_norm_weight" + # Store transform spec for QKV merging + transform_specs.append( + { + "source_key": ( + f"model.layers.{idx}.self_attn.q_proj.weight", + f"model.layers.{idx}.self_attn.k_proj.weight", + f"model.layers.{idx}.self_attn.v_proj.weight", + ), + "target_key": f"decoder.layers.{idx}.self_attention.linear_qkv.weight", + "transform_function": "merge_qkv_for_puzzletron", + "kwargs": {"idx": idx}, + } + ) + + if value is not None: + attn_mapping[f"model.layers.{idx}.input_layernorm.weight"] = value + + # Process FFN config + # Handle both HF (moe, moe.shared_expert_intermediate_dim) and NeMo (num_moe_experts, moe_shared_expert_intermediate_size) keys + moe_config = ffn_dict.get("moe") or ffn_dict.get("num_moe_experts") + shared_expert_size = None + if moe_config: + # Convert moe_config to dict if it's an object (HF case) + moe_dict = ( + _config_to_dict(moe_config) if not isinstance(moe_config, (int, type(None))) else {} + ) + shared_expert_size = moe_dict.get("shared_expert_intermediate_dim") or ffn_dict.get( + "moe_shared_expert_intermediate_size" + ) + + if not ffn or ffn_dict.get("no_op"): + value = None + elif ffn_dict.get("replace_with_linear"): + value = f"decoder.layers.{idx}.mlp.layer_norm_weight" + elif moe_config is not None: + value = f"decoder.layers.{idx}.pre_mlp_layernorm.weight" + moe_mapping[f"model.layers.{idx}.mlp.router.weight"] = ( + f"decoder.layers.{idx}.mlp.router.weight" + ) + # Store transform spec for MoE expert FC1 merging + transform_specs.append( + { + "source_key": ( + f"model.layers.{idx}.mlp.experts.*.gate_proj.weight", + f"model.layers.{idx}.mlp.experts.*.up_proj.weight", + ), + "target_key": f"decoder.layers.{idx}.mlp.experts.local_experts.*.linear_fc1.weight", + "transform_function": "merge_fc1_for_moe", + "kwargs": {}, + } + ) + moe_mapping[f"model.layers.{idx}.mlp.experts.*.down_proj.weight"] = ( + f"decoder.layers.{idx}.mlp.experts.local_experts.*.linear_fc2.weight" + ) + # Check for shared expert + if shared_expert_size not in [None, 0]: + # Store transform spec for MoE shared expert FC1 merging + transform_specs.append( + { + "source_key": ( + f"model.layers.{idx}.mlp.shared_expert.gate_proj.weight", + f"model.layers.{idx}.mlp.shared_expert.up_proj.weight", + ), + "target_key": f"decoder.layers.{idx}.mlp.shared_experts.linear_fc1.weight", + "transform_function": "merge_fc1_for_moe", + "kwargs": {}, + } + ) + moe_mapping[f"model.layers.{idx}.mlp.shared_expert.down_proj.weight"] = ( + f"decoder.layers.{idx}.mlp.shared_experts.linear_fc2.weight" + ) + else: + value = f"decoder.layers.{idx}.mlp.linear_fc1.layer_norm_weight" + + if value is not None: + ffn_mapping[f"model.layers.{idx}.post_attention_layernorm.weight"] = value + + return attn_mapping, ffn_mapping, mamba_mapping, moe_mapping, transform_specs + + +def merge_qkv_for_puzzletron( + ctx: io.state.TransformCTX, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + idx: Optional[int] = None, +): + """ + Merge q, k, v to interleave-concatenated qkv. + - Modified version of nemo.lightning.io.state.TransformFns.merge_qkv for Puzzletron + - idx can be provided to fetch megatron_config for a specific layer + - heads_per_group is derived from the shape of q and k, instead of calculating (head_num // num_query_groups) from config values + - num_query_groups is not fetched from a global config value, but calculated from head_num and heads_per_group + + Example: import HF {q|k|v}_proj to layer linear_qkv + """ + if idx is not None: + megatron_config = ctx.target.decoder.layers[idx].config + else: + megatron_config = ctx.target.config + head_num = megatron_config.num_attention_heads + heads_per_group = ( + q.shape[0] // k.shape[0] + ) # NOTE: This is important to support heterogeneous attention + num_query_groups = head_num // heads_per_group + hidden_size = megatron_config.hidden_size + head_size = megatron_config.kv_channels + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +def split_qkv_for_puzzletron( + ctx: io.state.TransformCTX, qkv: torch.Tensor, idx: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split interleave-concatenated qkv to separate q, k, v. + - Inverse operation of merge_qkv_for_puzzletron for Puzzletron + - idx can be provided to fetch megatron_config for a specific layer + - heads_per_group is derived from the shape of qkv, instead of calculating from config values + - num_query_groups is not fetched from a global config value, but calculated from head_num and heads_per_group + + Example: export NeMo layer linear_qkv to HF {q|k|v}_proj + """ + if idx is not None: + megatron_config = ctx.source.decoder.layers[idx].config + else: + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + head_size = megatron_config.kv_channels + # hidden_size = megatron_config.hidden_size + + # Calculate qkv_total_dim from the actual qkv tensor shape + # qkv shape is [head_size * (head_num + 2 * num_query_groups), hidden_size] + qkv_total_dim = qkv.shape[0] // head_size + num_query_groups = (qkv_total_dim - head_num) // 2 + heads_per_group = head_num // num_query_groups + + # Reshape qkv to 3D: [qkv_total_dim, head_size, hidden_size] + qkv = qkv.reshape([qkv_total_dim, head_size, -1]) + + # when converting base model (linear_qkv), hidden size = megatron_config.hidden_size + # when converting lora (linear_qkv.adapter.linear_out), hidden size = lora_r + actual_hidden_size = qkv.size(-1) + + # Create slice indices for q, k, v + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = qkv[q_slice].reshape(-1, actual_hidden_size).cpu() + k_proj = qkv[k_slice].reshape(-1, actual_hidden_size).cpu() + v_proj = qkv[v_slice].reshape(-1, actual_hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +def dtype_from_dict(config_dict): + """ + Extracts torch dtype from a HF config. + Handles both 'torch_dtype' (old format) and 'dtype' (new format). + """ + # Try torch_dtype first (old format), then dtype (new format) + if "torch_dtype" in config_dict: + torch_dtype = config_dict["torch_dtype"] + elif "dtype" in config_dict: + torch_dtype = config_dict["dtype"] + else: + raise ValueError("Expected config dict to have attr `torch_dtype` or `dtype`") + + if isinstance(torch_dtype, torch.dtype): + return torch_dtype + elif isinstance(torch_dtype, str): + return dtype_from_str(torch_dtype) + else: + raise ValueError(f"dtype is not of type str/torch.dtype, got {type(torch_dtype)}") + + +def copy_puzzletron_python_files_to_decilm_checkpoint(output_path: str) -> None: + """Copy custom Python files from puzzle_tools package to output directory. + + Puzzletron models require custom Python files (configuration_decilm.py, + modeling_decilm.py, etc.) to be present in the checkpoint directory for + loading with transformers.AutoModel. + + This function copies all Python files from puzzle_tools/deci_lm_hf_code/ + to ensure the exported checkpoint is fully functional. + + Args: + output_path: Directory where HF model is being saved + """ + import logging + import shutil + from pathlib import Path + + # Get the puzzle_tools/deci_lm_hf_code directory + # Navigate from this file: export/MCore/llama_nemotron_utils.py -> v1/puzzle_tools/deci_lm_hf_code/ + package_dir = Path(__file__).parent.parent.parent / "puzzle_tools" / "deci_lm_hf_code" + + if not package_dir.exists(): + logging.warning( + f"Custom files directory not found: {package_dir}. " + f"Exported checkpoint may not be loadable without these files." + ) + return + + # Copy all Python files from the package + output_dir = Path(output_path) + copied_files = [] + for src_file in package_dir.glob("*.py"): + dest_file = output_dir / src_file.name + shutil.copy2(src_file, dest_file) + copied_files.append(src_file.name) + + logging.info(f"Copied {len(copied_files)} custom Python files to {output_path}") + logging.debug(f"Custom files copied: {', '.join(sorted(copied_files)[:5])}...") # Show first 5 + + +def embed_chat_template_in_tokenizer_config(nemo_checkpoint_path: str, output_path: str) -> None: + """Embed chat_template from .jinja file into tokenizer_config.json. + + NeMo's HF → NeMo import extracts chat_template to a separate .jinja file + but doesn't preserve it in tokenizer_config.json. This causes accuracy drops + in evaluation. This function restores it by: + 1. Reading chat_template.jinja from the NeMo checkpoint + 2. Embedding it into the exported tokenizer_config.json + + Args: + nemo_checkpoint_path: Path to the NeMo checkpoint (.nemo file/directory) + output_path: Directory where HF model is being saved + """ + import logging + from pathlib import Path + + # Path to NeMo checkpoint tokenizer files + nemo_checkpoint = Path(nemo_checkpoint_path) + nemo_chat_template_jinja = ( + nemo_checkpoint / "context" / "nemo_tokenizer" / "chat_template.jinja" + ) + + # Path to exported tokenizer config + output_dir = Path(output_path) + output_tokenizer_config = output_dir / "tokenizer_config.json" + + # Check if both files exist + if not nemo_chat_template_jinja.exists(): + logging.debug( + f"No chat_template.jinja found in NeMo checkpoint at {nemo_chat_template_jinja}" + ) + return + + if not output_tokenizer_config.exists(): + logging.warning(f"tokenizer_config.json not found at {output_tokenizer_config}") + return + + # Read chat_template from .jinja file + chat_template_content = nemo_chat_template_jinja.read_text() + + # Load tokenizer_config.json + with open(output_tokenizer_config, "r") as f: + tokenizer_config = json.load(f) + + # Check if chat_template is already embedded (shouldn't be, but be safe) + if "chat_template" in tokenizer_config: + logging.debug("chat_template already embedded in tokenizer_config.json, skipping") + return + + # Embed the chat_template + tokenizer_config["chat_template"] = chat_template_content + + # Save updated tokenizer_config.json + with open(output_tokenizer_config, "w") as f: + json.dump(tokenizer_config, f, indent=2, ensure_ascii=False) + + logging.info(f"✓ Embedded chat_template from NeMo checkpoint into tokenizer_config.json") + logging.debug(f" Template length: {len(chat_template_content)} characters") diff --git a/modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py b/modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py new file mode 100644 index 000000000..11a8798ba --- /dev/null +++ b/modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import asdict + +import torch +from megatron.core.transformer.spec_utils import ModuleSpec +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import ( + AutoTokenizer as NemoAutoTokenizer, +) +from nemo.collections.llm.gpt.model.base import GPTModel +from nemo.collections.llm.gpt.model.llama_nemotron import HFLlamaNemotronImporter +from nemo.lightning import io, teardown +from nemo.lightning.io.state import TransformFns +from nemo.utils.import_utils import safe_import +from transformers import AutoModelForCausalLM, AutoTokenizer + +from modelopt.torch.puzzletron.export.MCore.puzzletron_layer_specs import ( + PuzzletronAttentionConfig, + PuzzletronHeterogeneousTransformerConfig, + PuzzletronMLPConfig, + get_gpt_heterogeneous_layer_spec_puzzletron, +) + + +def convert_attention_config_from_cfg_object(attention_config, num_attention_heads, head_dim): + for unsupported_key in [ + "llama4", + "num_sink_tokens", + "sparsify", + "unshifted_sink", + "use_prefill_window_in_sink_attention", + ]: + if hasattr(attention_config, unsupported_key) and getattr( + attention_config, unsupported_key + ) not in [ + None, + False, + ]: + # + # if attention_config.get(unsupported_key, None) not in [None, False]: + raise NotImplementedError(f"{unsupported_key} is not supported") + window_size = attention_config.window_size if hasattr(attention_config, "window_size") else None + if window_size is not None: + window_size = (window_size, 0) + is_mamba = attention_config.mamba if hasattr(attention_config, "mamba") else False + n_heads_in_group = ( + attention_config.n_heads_in_group if hasattr(attention_config, "n_heads_in_group") else 1 + ) + if n_heads_in_group is None: + n_heads_in_group = 1 + return asdict( + PuzzletronAttentionConfig( + no_op=attention_config.no_op if hasattr(attention_config, "no_op") else False, + replace_with_linear=( + attention_config.replace_with_linear + if hasattr(attention_config, "replace_with_linear") + else False + ), + num_attention_heads=num_attention_heads, + num_query_groups=num_attention_heads // n_heads_in_group, + kv_channels=head_dim, + window_size=window_size, + multi_latent_attention=False, + is_mamba=is_mamba, + mamba_state_dim=( + attention_config.mamba.state_dim + if is_mamba and hasattr(attention_config.mamba, "state_dim") + else 128 + ), + mamba_head_dim=( + attention_config.mamba.head_dim + if is_mamba and hasattr(attention_config.mamba, "head_dim") + else 64 + ), + mamba_num_groups=( + attention_config.mamba.num_groups + if is_mamba and hasattr(attention_config.mamba, "num_groups") + else 8 + ), + mamba_num_heads=( + attention_config.mamba.num_heads + if is_mamba and hasattr(attention_config.mamba, "num_heads") + else None + ), + ) + ) + + +def convert_mlp_config_from_cfg_object(mlp_config, parallel_blocks, default_hidden_act): + if parallel_blocks is not None: + raise NotImplementedError("parallel_blocks is not supported") + if not hasattr(mlp_config, "gated") or mlp_config.gated is False: + raise NotImplementedError("non-gated MLP is not supported") + if not hasattr(mlp_config, "hidden_act") or mlp_config.hidden_act not in [default_hidden_act]: + raise NotImplementedError(f"all mlps must have the same activation ({default_hidden_act})") + if hasattr(mlp_config, "sparsify") and mlp_config.sparsify is not None: + raise NotImplementedError("sparsify is not supported") + is_moe = hasattr(mlp_config, "moe") and mlp_config.moe is not None + return asdict( + PuzzletronMLPConfig( + no_op=mlp_config.no_op if hasattr(mlp_config, "no_op") else False, + replace_with_linear=mlp_config.replace_with_linear + if hasattr(mlp_config, "replace_with_linear") + else False, + ffn_hidden_size=mlp_config.intermediate_size + if hasattr(mlp_config, "intermediate_size") + else None, + num_moe_experts=( + mlp_config.moe.num_local_experts + if is_moe and hasattr(mlp_config.moe, "num_local_experts") + else None + ), + moe_shared_expert_intermediate_size=( + mlp_config.moe.shared_expert_intermediate_dim + if is_moe and hasattr(mlp_config.moe, "shared_expert_intermediate_dim") + else None + ), + moe_ffn_hidden_size=( + mlp_config.moe.expert_intermediate_dim + if is_moe and hasattr(mlp_config.moe, "expert_intermediate_dim") + else None + ), + moe_router_topk=( + mlp_config.moe.num_experts_per_tok + if is_moe and hasattr(mlp_config.moe, "num_experts_per_tok") + else 2 + ), + ) + ) diff --git a/modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py b/modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py new file mode 100644 index 000000000..ec011ff28 --- /dev/null +++ b/modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py @@ -0,0 +1,928 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import asdict, dataclass, field, fields +from pathlib import Path +from typing import Any, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.models.gpt.gpt_layer_specs import ( + LayerType, + LNImpl, + TransformerBlockSubmodules, + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, + get_num_layers_to_build, + get_transformer_layer_offset, +) +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.post_training.modelopt.layers import Linear +from megatron.core.process_groups_config import ModelCommProcessGroups +from megatron.core.quantization.utils import ( + kitchen_quantization_recipe_config, + load_quantization_recipe, +) +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules +from megatron.core.tensor_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + _initialize_affine_weight_cpu, +) +from megatron.core.tensor_parallel.random import get_cuda_rng_tracker +from megatron.core.transformer import MLATransformerConfig, TransformerConfig +from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version + +# from megatron.core.activations import squared_relu #for megatron 0.14 version in future NeMo containers +from megatron.training.activations import squared_relu +from nemo.collections.llm.gpt.model.llama import Llama31Config70B +from packaging.version import Version as PkgVersion +from torch import Tensor +from torch.nn.parameter import Parameter + +try: + import transformer_engine as te # pylint: disable=unused-import + from megatron.core.extensions.transformer_engine import ( + TELayerNormColumnParallelLinear, + TELinear, + TENorm, + TERowParallelLinear, + _get_extra_te_kwargs, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +# TODO: check sharded_state_dict_keys_map => only if TE is disabled +# TODO: parallel Blocks +# TODO: multimodal +# https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/vlm/neva/model/base.py +# https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/vlm/qwen2vl/model/base.py + + +# NOTE based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py#L144 +# TODO: what is the difference between this and the referenced one? +def _get_sharded_state_dict_keys_map( + block_config: "PuzzletronTransformerBlockConfig", use_transformer_engine: bool +): + """Generate a mapping of sharded state dictionary keys for Puzzletron transformer blocks. + + This function is a specialized version of the original Megatron-LM + `_get_sharded_state_dict_keys_map` function, adapted for Puzzletron's + heterogeneous transformer architecture with Mamba support. + + Key differences from the original: + - **Mamba Support**: Adds mapping for Mamba layers (`mixer.norm_`) + - **Enhanced Logic**: Uses if-elif-else structure instead of multiple if statements + - **No-op Handling**: Explicit handling of no-op attention and MLP cases + - **Simplified**: Removes `num_query_groups` check (handled in main logic) + - **Config Type**: Uses `PuzzletronTransformerBlockConfig` instead of `TransformerBlockConfig` + + Args: + block_config: Puzzletron transformer block configuration + use_transformer_engine: Whether to use Transformer Engine optimizations + + Returns: + dict: A dictionary mapping sharded state dictionary keys + """ + mapping = {} + if not use_transformer_engine: + if block_config.attention.replace_with_linear: + mapping.update({"input_layernorm.": "self_attention.layer_norm_"}) + elif block_config.attention.is_mamba: # Mamba, not sure about this + mapping.update({"input_layernorm.": "mixer.norm_"}) + elif not block_config.attention.no_op: # MHA and MLA + mapping.update({"input_layernorm.": "self_attention.linear_qkv.layer_norm_"}) + else: # No-op + pass + + if block_config.mlp.ffn_hidden_size is not None: # FFN + mapping.update({"pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_"}) + elif block_config.mlp.replace_with_linear: # Linear + mapping.update({"pre_mlp_layernorm.": "mlp.layer_norm_"}) + else: # No-op, MoE + pass + return mapping + + +# NOTE: new class +@dataclass +class PuzzletronSubblockConfig: + """Base configuration class for Puzzletron transformer subblocks. + + This is the base class for attention and MLP configurations in Puzzletron's + heterogeneous transformer architecture. It provides common functionality + for subblock configurations including no-op and linear replacement options. + + Key differences from the original Megatron-LM subblock configs: + - **Enhanced Building**: Uses `build_config_from_dict()` with main config fallback + - **Validation**: Includes `__post_init__()` validation for mutual exclusivity + - **Flexibility**: Supports both no-op and linear replacement modes + + Attributes: + no_op: Whether this subblock should be a no-op operation + replace_with_linear: Whether to replace the subblock with a single linear layer + """ + + no_op: bool = False + replace_with_linear: bool = False + + @classmethod + def build_config_from_dict( + cls, + subblock_config_dict: dict[str, Any], + main_config: "PuzzletronHeterogeneousTransformerConfig", + ): + field_names = {f.name for f in fields(cls)} + subblock_config_dict = {k: v for k, v in subblock_config_dict.items() if k in field_names} + # getting default values from the main config (if not overridden in the subblock config) + for field_name in field_names: + # note that MLA fields are also not in the main_config + if field_name not in subblock_config_dict and hasattr(main_config, field_name): + subblock_config_dict[field_name] = getattr(main_config, field_name) + return cls(**subblock_config_dict) + + def __post_init__(self) -> None: + assert not (self.no_op and self.replace_with_linear), ( + "at most one of no_op, replace_with_linear can be True" + ) + + +@dataclass +class PuzzletronAttentionConfig(PuzzletronSubblockConfig): + """Configuration parameters for the self-attention part of a Puzzletron transformer block. + + This class extends the original Megatron-LM AttentionConfig with support for + Mamba layers and enhanced Multi-Latent Attention (MLA) configurations. + + Key differences from the original AttentionConfig: + - **Mamba Support**: Adds `is_mamba` flag and Mamba-specific parameters + - **Enhanced MLA**: Extended MLA parameters with LoRA ranks and head dimensions + - **Context Parallelism**: Adds `cp_comm_type` for attention context parallelism + - **Validation**: Enhanced `__post_init__()` with Mamba-MLA mutual exclusivity check + - **Flexibility**: Supports MHA, MLA, and Mamba attention types in one config + + Attributes: + # MHA (Multi-Head Attention) parameters + num_attention_heads: Number of attention heads + num_query_groups: Number of query groups for grouped query attention + kv_channels: Key-value projection dimension + window_size: Sliding window size for local attention + + # MLA (Multi-Latent Attention) parameters + multi_latent_attention: Whether to use MLA instead of MHA + q_lora_rank: LoRA rank for query projections + kv_lora_rank: LoRA rank for key-value projections + qk_head_dim: Query-key head dimension + qk_pos_emb_head_dim: Query-key positional embedding head dimension + v_head_dim: Value head dimension + + # Context parallelism + cp_comm_type: Communication type for context parallelism + + # Mamba parameters + is_mamba: Whether to use Mamba instead of attention + mamba_state_dim: Mamba state dimension + mamba_head_dim: Mamba head dimension + mamba_num_groups: Number of groups in Mamba + mamba_num_heads: Number of heads in Mamba (auto-calculated if None) + """ + + # all attributes, except for is_mamba are part of TransformerConfig/MLATransformerConfig + # MHA + num_attention_heads: Optional[int] = None + num_query_groups: Optional[int] = None + kv_channels: Optional[int] = None + window_size: Optional[Tuple[int, int]] = None + # MLA (Note that for MLA we have to instantiate a MLATransformerConfig, since there is a isinstance in attention.py) + multi_latent_attention: bool = False + q_lora_rank: int = 512 + kv_lora_rank: int = 512 + qk_head_dim: int = 128 + qk_pos_emb_head_dim: int = 64 + v_head_dim: int = 128 + # for attention context parallelism (ignored for mamba) + cp_comm_type: str = "p2p" + # Mamba + is_mamba: bool = False # new + mamba_state_dim: int = 128 + mamba_head_dim: int = 64 + mamba_num_groups: int = 8 + mamba_num_heads: Optional[int] = None + + def __post_init__(self) -> None: + super().__post_init__() + if self.no_op or self.replace_with_linear: + self.is_mamba = False + self.num_attention_heads = 8 + self.multi_latent_attention = False + if self.is_mamba: + if self.num_attention_heads is None or self.num_attention_heads == 0: + self.num_attention_heads = 8 # to avoid division by zero + assert not (self.is_mamba and self.multi_latent_attention), ( + "Mamba and MLA cannot be used together" + ) + + +@dataclass +class PuzzletronMLPConfig(PuzzletronSubblockConfig): + """Configuration parameters for the MLP part of a Puzzletron transformer block. + + This class extends the original Megatron-LM MLPConfig with enhanced + Mixture of Experts (MoE) support and improved configuration building. + + Key differences from the original MLPConfig: + - **Enhanced MoE**: Extended MoE parameters with shared expert support + - **Validation**: Includes `__post_init__()` validation for no-op/linear modes + - **Building**: Uses `build_config_from_dict()` with main config fallback + - **Flexibility**: Supports standard MLP, MoE, no-op, and linear replacement modes + + Attributes: + # Standard MLP parameters + ffn_hidden_size: MLP intermediate size (hidden dimension) + + # MoE (Mixture of Experts) parameters + num_moe_experts: Number of expert networks in MoE + moe_shared_expert_intermediate_size: Size of shared expert intermediate layer + moe_ffn_hidden_size: Hidden size for MoE expert networks + moe_router_topk: Number of top-k experts to route tokens to + """ + + # all attributes are part of TransformerConfig + ffn_hidden_size: Optional[int] = None + # MoE + num_moe_experts: Optional[int] = None + moe_shared_expert_intermediate_size: Optional[int] = None + moe_ffn_hidden_size: Optional[int] = None + moe_router_topk: int = 2 + + def __post_init__(self) -> None: + super().__post_init__() + if self.no_op or self.replace_with_linear: + self.ffn_hidden_size = None + self.num_moe_experts = None + self.moe_ffn_hidden_size = None + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/transformer/heterogeneous/heterogeneous_config.py#L134 +@dataclass +class PuzzletronTransformerBlockConfig: + """Configuration for a single Puzzletron transformer block in a heterogeneous model. + + This class represents the configuration for one transformer block, containing + both attention and MLP subblock configurations. It's based on the original + Megatron-LM TransformerBlockConfig but uses Puzzletron-specific subblock configs. + + Key differences from the original TransformerBlockConfig: + - **Puzzletron Subblocks**: Uses `PuzzletronAttentionConfig` and `PuzzletronMLPConfig` + - **Enhanced Building**: Uses `build_from_dict()` with main config fallback + - **Mamba Support**: Supports Mamba layers through attention config + - **MoE Support**: Enhanced MoE support through MLP config + - **Flexibility**: Supports all Puzzletron attention and MLP variants + + Attributes: + attention: Configuration for the attention subblock (MHA, MLA, or Mamba) + mlp: Configuration for the MLP subblock (standard MLP or MoE) + """ + + attention: PuzzletronAttentionConfig + mlp: PuzzletronMLPConfig + + @classmethod + def build_from_dict( + cls, block: dict[str, Any], main_config: "PuzzletronHeterogeneousTransformerConfig" + ): + if "mlp" in block: + mlp = block["mlp"] + elif "ffn" in block: + mlp = block["ffn"] + else: + raise ValueError(f"mlp/ffn not found in block: {block}") + + return cls( + attention=PuzzletronAttentionConfig.build_config_from_dict( + subblock_config_dict=block["attention"], main_config=main_config + ), + mlp=PuzzletronMLPConfig.build_config_from_dict( + subblock_config_dict=mlp, main_config=main_config + ), + ) + + +@dataclass +class PuzzletronMambaTransformerConfig(TransformerConfig): + """Configuration for Puzzletron Mamba-only transformer models. + + This class extends the base TransformerConfig for models that use + Mamba layers exclusively instead of attention mechanisms. It inherits + all standard transformer configuration parameters from TransformerConfig. + + Key differences from standard TransformerConfig: + - **Mamba Focus**: Designed specifically for Mamba-based architectures + - **Inheritance**: Inherits all standard transformer parameters + - **Simplicity**: Currently a pass-through class for future Mamba-specific extensions + + Note: This class is currently minimal and inherits all functionality + from the base TransformerConfig. Future versions may add Mamba-specific + configuration parameters as needed. + """ + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/transformer/heterogeneous/heterogeneous_config.py#L147 +@dataclass +class PuzzletronHeterogeneousTransformerConfig(TransformerConfig): + """Configuration object for Puzzletron heterogeneous transformers. + + This class extends the original Megatron-LM HeterogeneousTransformerConfig with + enhanced support for Mamba layers and improved configuration management. + + Key differences from the original HeterogeneousTransformerConfig: + - **Mamba Support**: Adds Mamba-specific parameters for state-space models + - **Enhanced Block Configs**: Uses `PuzzletronTransformerBlockConfig` with Mamba support + - **Improved Building**: Enhanced `__post_init__()` with better config validation + - **Flexibility**: Supports all Puzzletron attention and MLP variants + + Heterogeneous models refer to transformer architectures where individual layers can differ + in configuration. Specifically: + - Attention layers can be MHA, MLA, Mamba, Linear, or No-op (all with their own config) + - MLP layers can be MoE, MLP, Linear, or No-op (all with their own config) + - Layers can have parallel blocks that run simultaneously and sum their outputs + + Mamba Parameters (shared across all Mamba layers): + d_conv: Convolution dimension for Mamba + expand: Expansion factor for Mamba hidden dimension + D_has_hdim: Whether D matrix has hidden dimension + rmsnorm: Whether to use RMS normalization + norm_before_gate: Whether to normalize before gating + dt_min/max/scale: Delta time parameters for Mamba + bias/conv_bias: Bias settings for Mamba layers + chunk_size: Chunk size for Mamba processing + """ + + heterogeneous_layers_config_path: str = "" + """Path to the json file containing the heterogeneous block specs.""" + + heterogeneous_layers_config_encoded_json: str = "" + """The contents of the json file containing the heterogeneous block specs. It will be read from + heterogeneous_layers_config_path at first, then saved forever inside the model checkpoint.""" + + per_block_parameters: list[PuzzletronTransformerBlockConfig] = field(init=False) + """Configuration parameters for each of the transformer blocks in a + heterogeneous transformer.""" + + # all of these can be used to instantiate a MambaMixer, they are shared for all Mamba layers + d_conv: int = 4 + expand: int = 2 + D_has_hdim: bool = False + rmsnorm: bool = True + norm_before_gate: bool = False + dt_min: float = 0.001 + dt_max: float = 0.1 + dt_scale: float = 1.0 + bias: bool = False + conv_bias: bool = True + chunk_size: int = 128 + + def __post_init__(self) -> None: + if self.kv_channels is None and self.num_attention_heads == 0: + self.num_attention_heads = 8 # to avoid division by zero + # Type assertion to help mypy understand the type after the check + assert isinstance(self.num_attention_heads, int), "num_attention_heads must be an integer" + if self.heterogeneous_layers_config_encoded_json in ("", None): + assert self.heterogeneous_layers_config_path not in ( + None, + "", + ), ( + "heterogeneous_layers_config_path is required, if heterogeneous_layers_config_encoded_json is not provided" + ) + self.heterogeneous_layers_config_encoded_json = Path( + self.heterogeneous_layers_config_path + ).read_text() + hf_config_dict: dict[str, Any] = json.loads(self.heterogeneous_layers_config_encoded_json) + block_list = hf_config_dict["block_configs"] + # TODO: should we change the definition of num_layers? it can be sum(mlp/attention) rather than uneven blocks + if self.num_layers is None or self.num_layers == 0: + self.num_layers = len(block_list) + # Type assertion to help mypy understand the type after the check + assert isinstance(self.num_layers, int), "num_layers must be an integer" + assert self.num_layers == len(block_list), ( + "num_layers must match the number of blocks in the json file" + ) + super().__post_init__() + self.heterogeneous_block_specs = True + self.heterogeneous_dist_checkpoint = True # TODO: check if this is correct/needed + self.per_block_parameters = [ + PuzzletronTransformerBlockConfig.build_from_dict(block=block, main_config=self) + for block in block_list + ] + + # TODO add parallel blocks support + def get_config_for_layer( + self, layer_number: int + ) -> TransformerConfig | MLATransformerConfig | PuzzletronMambaTransformerConfig: + """ + Get the config for the given layer number. + Based on the layer number, the corresponding block config is returned, + overriding the main config's value. + + Returns: + TransformerConfig: For standard transformer layers + MLATransformerConfig: For MLA layers + PuzzletronMambaTransformerConfig: For Mamba layers + """ + layer_idx = layer_number - 1 # layer number starts from 1 + if layer_idx < 0 or layer_idx >= len(self.per_block_parameters): + raise ValueError( + f"Invalid layer number: {layer_number}. Should be in " + f"range [1, {len(self.per_block_parameters)}]." + ) + block_config = self.per_block_parameters[layer_idx] + + # Determine which config class to use based on the block configuration + if block_config.attention.is_mamba: + config_class = PuzzletronMambaTransformerConfig + elif block_config.attention.multi_latent_attention: + config_class = MLATransformerConfig + else: + config_class = TransformerConfig + + # Get all available fields from the attention and MLP configs + attention_fields = {f.name for f in fields(block_config.attention)} + mlp_fields = {f.name for f in fields(block_config.mlp)} + + # Get all available fields from the target config class + target_config_fields = {f.name for f in fields(config_class)} + + # Start with the base config + transformer_config_dict = asdict(self) + + # Remove keys that are not in the target config class + transformer_config_dict = { + k: v for k, v in transformer_config_dict.items() if k in target_config_fields + } + + # Update with all available attention config values (if they exist in target config) + for field_name in attention_fields: + if field_name in target_config_fields: + transformer_config_dict[field_name] = getattr(block_config.attention, field_name) + + # Update with all available MLP config values (if they exist in target config) + for field_name in mlp_fields: + if field_name in target_config_fields: + transformer_config_dict[field_name] = getattr(block_config.mlp, field_name) + + if transformer_config_dict["num_moe_experts"] is None: + # to pass __post_init__ of config_class + transformer_config_dict["expert_model_parallel_size"] = 1 + config = config_class(**transformer_config_dict) + + return config + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/ba97a7e282a8478a02d012bc9b9e45f3a6be216e/megatron/core/extensions/transformer_engine.py#L449 +class WrappedTENormLinear(TELayerNormColumnParallelLinear): + """A wrapper around TELayerNormColumnParallelLinear with simplified interface and forced configurations. + + This wrapper simplifies the interface of TELayerNormColumnParallelLinear by: + 1. Taking only a config object instead of individual parameters + 2. Forcing specific configurations (tp_group=None, tp_size=1, etc.) for compatibility + 3. Adding version checks for Transformer Engine features + 4. Providing a cleaner interface for heterogeneous transformer models + + Key differences from TELayerNormColumnParallelLinear: + - Simplified constructor: only requires config and optional unused parameters + - Forces tensor parallel settings: tp_group=None, tp_size=1, tp_rank=0 + - Automatically sets input_size=output_size=config.hidden_size + - Adds version checks for TE features (delay_wgrad_compute, normalization, symmetric_ar_type) + - Forces bias=False, skip_bias_add=False for consistency + - Disables gather_output (raises error if True) + - Uses simplified init_method=lambda w: None + + This wrapper is designed for use in heterogeneous transformer architectures where + individual layers may have different configurations but need a consistent interface. + """ + + def __init__( + self, + config, + layer_number=None, # unused + model_comm_pgs=None, # unused + cp_comm_type=None, # unused + tp_group=None, # unused + tp_comm_buffer_name=None, + gather_output=False, # unused + ): + # unfortunately, TELayerNormColumnParallelLinear sets tp_group and forcing it to be None requires to copy/paste __init__ + if not HAVE_TE: + raise ImportError( + "Transformer Engine is not installed. " + "Please install it with `pip install transformer-engine`." + ) + + self.config = config + + if gather_output: + raise ValueError("Transformer Engine linear layers do not support gather_output = True") + + skip_bias_add = False + bias = False + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + extra_kwargs = _get_extra_te_kwargs(config) + self.tp_size = 1 + self.tp_rank = 0 + + if self.config.delay_wgrad_compute: + if is_te_min_version("2.3.0"): + extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute + else: + raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") + + # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` + if is_te_min_version("0.11.0"): + extra_kwargs["normalization"] = self.config.normalization + elif self.config.normalization != "LayerNorm": + te_version = get_te_version() + raise ValueError( + f"Transformer Engine v{te_version} does not support {self.config.normalization}." + ) + + if self.config.symmetric_ar_type is not None: + assert is_torch_min_version("2.7.0a0"), "Must have at least torch version 2.7 or higher" + assert is_te_min_version("2.3.0") or get_te_version() == PkgVersion( + "2.3.0.dev0+39c0e70" + ), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce" + extra_kwargs["symmetric_ar_type"] = self.config.symmetric_ar_type + + output_size = config.hidden_size + input_size = config.hidden_size + # calling te.pytorch.LayerNormLinear's __init__ + super(TELayerNormColumnParallelLinear, self).__init__( + in_features=input_size, + out_features=output_size, + eps=self.config.layernorm_epsilon, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=None, + tp_size=1, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=lambda w: None, + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=None, + return_layernorm_output=False, + zero_centered_gamma=self.config.layernorm_zero_centered_gamma, + **extra_kwargs, + ) + + if config.use_cpu_initialization: + output_size_per_partition = output_size + _ = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + output_size_per_partition, + 0, + init_method=lambda w: None, + stride=1, + return_master_weight=False, + rank=self.tp_rank, + world_size=self.tp_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter( + torch.empty(output_size_per_partition, dtype=config.params_dtype) + ) + with torch.no_grad(): + self.bias.zero_() + + def forward(self, x, *args, **kwargs): + return super().forward(x) + + +class WrappedLinear(Linear): + def __init__( + self, + config, + layer_number=None, + model_comm_pgs=None, + cp_comm_type=None, + tp_group=None, + tp_comm_buffer_name=None, + gather_output=False, + ): + super().__init__( + input_size=config.hidden_size, + output_size=config.hidden_size, + config=config, + init_method=config.init_method, + bias=False, + gather_output=gather_output, + skip_bias_add=False, + tp_comm_buffer_name=tp_comm_buffer_name, + tp_group=tp_group, + ) + + def forward(self, x, *args, **kwargs): + return super().forward(x) + + +class WrappedTELinear(TELinear): + # TODO: docstring + def __init__( + self, + config, + layer_number=None, # unused + model_comm_pgs=None, # unused + cp_comm_type=None, # unused + tp_group=None, # unused + tp_comm_buffer_name=None, + gather_output=False, # unused + ): + super().__init__( + input_size=config.hidden_size, + output_size=config.hidden_size, + parallel_mode="duplicated", + # parallel_mode=None, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + tp_comm_buffer_name=tp_comm_buffer_name, + is_expert=False, + ) + + def forward(self, x, *args, **kwargs): + return super().forward(x) + + +class WrappedMambaMixer(MambaMixer): + def __init__(self, *args, cp_comm_type: Optional[str] = None, **kwargs): + # ignoring cp_comm_type + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ) -> Tuple[Tensor, Tensor]: + result = super().forward(hidden_states, inference_context=inference_context) + # Ensure we return a tuple of two tensors + assert isinstance(result, tuple) and len(result) == 2 + return result + + +# NOTE: new method +def get_layer_spec_for_layer( + block_params: PuzzletronTransformerBlockConfig, + config: PuzzletronHeterogeneousTransformerConfig, + use_transformer_engine: bool, + normalization: Optional[str] = None, + qk_l2_norm: Optional[bool] = False, +) -> ModuleSpec: + # this part is copied from megatron.core.models.gpt.gpt_layer_specs.get_gpt_decoder_block_spec() + if use_transformer_engine: + layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=block_params.mlp.num_moe_experts, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=block_params.attention.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + qk_l2_norm=qk_l2_norm, + use_kitchen=config.use_kitchen, + # use_te_activation_func=config.use_te_activation_func, #TODO: part of megatron 0.14 version. check if this is needed now. + ) + else: + layer_spec = get_gpt_layer_local_spec( + num_experts=block_params.mlp.num_moe_experts, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=block_params.attention.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + normalization=normalization, + qk_l2_norm=qk_l2_norm, + use_kitchen=config.use_kitchen, + ) + if block_params.attention.no_op: + layer_spec.submodules.input_layernorm = IdentityOp + layer_spec.submodules.self_attn_bda = IdentityFuncOp + layer_spec.submodules.self_attention = ModuleSpec(module=IdentityOp) + elif block_params.attention.replace_with_linear: + layer_spec.submodules.self_attention = ModuleSpec( + module=WrappedTENormLinear if use_transformer_engine else WrappedLinear, + params={"tp_comm_buffer_name": "linear_attn"}, + ) + elif block_params.attention.is_mamba: + mamba_mixer_params = dict( + d_model=config.hidden_size, + d_conv=config.d_conv, + expand=config.expand, + D_has_hdim=config.D_has_hdim, + rmsnorm=config.rmsnorm, + norm_before_gate=config.norm_before_gate, + dt_min=config.dt_min, + dt_max=config.dt_max, + dt_scale=config.dt_scale, + bias=config.bias, + conv_bias=config.conv_bias, + chunk_size=config.chunk_size, + ) + layer_spec.submodules.self_attention = ModuleSpec( + module=WrappedMambaMixer, + params=mamba_mixer_params, + submodules=MambaMixerSubmodules( + in_proj=( + TELayerNormColumnParallelLinear + if use_transformer_engine + else ColumnParallelLinear + ), + out_proj=TERowParallelLinear if use_transformer_engine else RowParallelLinear, + ), + ) + + if block_params.mlp.no_op: + layer_spec.submodules.pre_mlp_layernorm = IdentityOp + layer_spec.submodules.mlp_bda = IdentityFuncOp + layer_spec.submodules.mlp = ModuleSpec(module=IdentityOp) + elif block_params.mlp.replace_with_linear: + layer_spec.submodules.mlp = ModuleSpec( + module=WrappedTENormLinear if use_transformer_engine else WrappedLinear, + params={"tp_comm_buffer_name": "linear_mlp"}, + ) + + layer_spec.submodules.sharded_state_dict_keys_map = _get_sharded_state_dict_keys_map( + block_params, use_transformer_engine + ) + return layer_spec + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py#L168 +def get_gpt_heterogeneous_layer_spec_puzzletron( + config: PuzzletronHeterogeneousTransformerConfig, + use_transformer_engine: bool, + normalization: Optional[str] = None, + qk_l2_norm: Optional[bool] = False, + vp_stage: Optional[int] = None, +) -> TransformerBlockSubmodules: + """Generate heterogeneous layer specifications for Puzzletron transformer models. + + This function is a specialized version of the original Megatron Core + `get_gpt_heterogeneous_layer_spec` function, adapted for Puzzletron's specific + heterogeneous transformer architecture requirements. + + Key differences from the original: + - **Signature**: Adds `normalization` and `qk_l2_norm` parameters, removes `pp_rank` + - **Architecture**: Uses `get_layer_spec_for_layer()` helper for modular layer creation + - **Pipeline Parallel**: Enhanced with `pipeline_model_parallel_layout` support + - **Configuration**: Uses `PuzzletronHeterogeneousTransformerConfig` with Mamba parameters + - **Layer Norm**: Simplified to `TENorm` vs `LNImpl` (removes `WrappedTorchNorm` complexity) + - **Features**: Supports Mamba layers, custom attention types, and advanced parallelization + + Args: + config: Puzzletron heterogeneous transformer configuration + use_transformer_engine: Whether to use Transformer Engine optimizations + normalization: Optional normalization type override + qk_l2_norm: Whether to apply L2 normalization to QK matrices + vp_stage: Virtual pipeline stage for advanced parallelization + + Returns: + TransformerBlockSubmodules: Complete layer specification for the heterogeneous model + """ + # Create the layer specs for the model. + layer_specs = [ + get_layer_spec_for_layer( + block_params, config, use_transformer_engine, normalization, qk_l2_norm + ) + for block_params in config.per_block_parameters + ] + + # Slice the layer specs to only include the layers that are built in this pipeline stage. + # Note: MCore layer_number starts at 1 + num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage) + + if config.pipeline_model_parallel_layout is not None: + local_layer_specs = [ + layer_specs[layer_id] + for layer_id in config.pipeline_model_parallel_layout.get_layer_id_list( + layer_type=LayerType.decoder, vp_stage=vp_stage + ) + ] + else: + offset = get_transformer_layer_offset(config, vp_stage=vp_stage) + local_layer_specs = layer_specs[offset : offset + num_layers_to_build] + + if use_transformer_engine: + layer_norm_impl = TENorm + else: + layer_norm_impl = LNImpl + + # Block spec. + block_spec = TransformerBlockSubmodules( + layer_specs=local_layer_specs, layer_norm=layer_norm_impl + ) + + return block_spec + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/gpt_builders.py#L23 +def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None): + """Build a GPT model with Puzzletron's heterogeneous transformer architecture. + + This function is a specialized version of the original Megatron-LM `gpt_builder` function, + adapted for Puzzletron's heterogeneous transformer architecture requirements. + + Key differences from the original: + - **Simplified**: Focuses exclusively on heterogeneous models (rejects legacy, spec-based, MoE, MTP) + - **Configuration**: Only supports args-based config (removes YAML complexity) + - **Layer Spec**: Uses single `get_gpt_heterogeneous_layer_spec_puzzletron` function + - **Error Handling**: Explicit error messages for unsupported features + - **Logging**: Removes debug logging for cleaner implementation + + Args: + args: Command-line arguments namespace containing model configuration parameters + pre_process: Whether to include pre-processing layers + post_process: Whether to include post-processing layers + vp_stage: Virtual pipeline stage for advanced parallelization + config: Optional pre-configured transformer config (if None, created from args) + + Returns: + GPTModel: Configured GPT model with heterogeneous transformer architecture + + Raises: + ValueError: If legacy models, spec-based models, or MTP are requested + """ + assert config is not None, "config is required" + if args.use_legacy_models: + raise ValueError("Legacy models are not supported") + if args.spec is not None: + raise ValueError("Spec is not supported") + use_te = args.transformer_impl == "transformer_engine" + transformer_layer_spec = get_gpt_heterogeneous_layer_spec_puzzletron( + config, + use_te, + normalization=args.normalization, + qk_l2_norm=args.qk_l2_norm, + vp_stage=vp_stage, + ) + mtp_block_spec = None + if args.mtp_num_layers is not None: + raise ValueError("MTP is not supported") + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling, + mtp_block_spec=mtp_block_spec, + vp_stage=vp_stage, + ) + + return model