Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions bionemo-recipes/models/llama3/modeling_llama_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,8 @@ def forward(
hidden_states = hidden_states.squeeze(0)

if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2:
# If we're using padded BSHD inputs, we need to convert the 2-dimensional mask to a 4-dimensional mask in
# the expected boolean format for TE.
attention_mask = attention_mask[:, None, None, :] < -1
# Convert HF mask (1=attend, 0=pad) to TE boolean mask (True=masked, False=attend)
attention_mask = ~attention_mask[:, None, None, :].bool()

if isinstance(past_key_values, InferenceParams): # InferenceParams is TE's way of managing kv-caching.
# In generation mode, we set the length to 1 for each batch index. Otherwise, we use the attention mask to
Expand Down
4 changes: 2 additions & 2 deletions bionemo-recipes/models/llama3/tests/test_cp_bshd.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_te_model_checkpoint(tmp_path):
"""
# Use the 1B model for practical testing (8B model requires too much memory)
model_hf = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B-Instruct", revision="c731040f", dtype=torch.bfloat16
"meta-llama/Llama-3.2-1B-Instruct", revision="9213176", dtype=torch.bfloat16
Comment thread
pstjohn marked this conversation as resolved.
)
model_te = convert_llama_hf_to_te(model_hf, attn_input_format="bshd", self_attn_mask_type="causal")
model_te.save_pretrained(tmp_path / "te_model_checkpoint")
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_context_parallel_equivalence_2process(recipe_path: Path, unused_tcp_por
model_ckpt = get_te_model_checkpoint(tmp_path)

# Create tokenizer for English text (use the 1B model tokenizer)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", revision="c731040f")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", revision="9213176")
tokenizer.pad_token = tokenizer.eos_token
input_data_bshd_dp0 = get_dummy_data_bshd_no_padding(tokenizer=tokenizer, seq_length=64)

Expand Down
1 change: 1 addition & 0 deletions bionemo-recipes/models/mixtral/.ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
extend = "../.ruff.toml"
Comment thread
pstjohn marked this conversation as resolved.
988 changes: 988 additions & 0 deletions bionemo-recipes/models/mixtral/collator.py

Large diffs are not rendered by default.

182 changes: 182 additions & 0 deletions bionemo-recipes/models/mixtral/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import torch
from transformers import MixtralConfig, MixtralForCausalLM

import state
from modeling_mixtral_te import NVMixtralConfig, NVMixtralForCausalLM


# One-to-one parameter-name mapping from Hugging Face Mixtral keys to
# Transformer Engine (TE) keys. The "*" wildcard matches the layer index.
# Weights that need merging or splitting (fused QKV, stacked expert tensors)
# are not listed here; they are handled by explicit state_transform entries
# inside the convert functions below.
mapping = {
    "model.embed_tokens.weight": "model.embed_tokens.weight",
    "model.layers.*.input_layernorm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
    "model.layers.*.self_attn.o_proj.weight": "model.layers.*.self_attention.proj.weight",
    "model.layers.*.post_attention_layernorm.weight": "model.layers.*.post_attention_layernorm.weight",
    "model.layers.*.mlp.gate.weight": "model.layers.*.mlp.gate.weight",
    "model.norm.weight": "model.norm.weight",
    "lm_head.weight": "lm_head.weight",
}

# Inverse mapping, used for the TE -> Hugging Face direction.
reverse_mapping = {v: k for k, v in mapping.items()}


def _split_experts_gate_up(gate_up_proj: torch.Tensor):
"""Split a stacked expert gate_up tensor into per-expert tensors.

Args:
gate_up_proj: Tensor of shape [num_experts, 2*ffn, hidden].

Returns:
Tuple of per-expert tensors, each of shape [2*ffn, hidden].
"""
return tuple(gate_up_proj[i] for i in range(gate_up_proj.shape[0]))


def _split_experts_down(down_proj: torch.Tensor):
"""Split a stacked expert down_proj tensor into per-expert tensors.

Args:
down_proj: Tensor of shape [num_experts, hidden, ffn].

Returns:
Tuple of per-expert tensors, each of shape [hidden, ffn].
"""
return tuple(down_proj[i] for i in range(down_proj.shape[0]))


def _make_merge_experts_fn(num_experts: int):
"""Create a merge function with the correct number of named parameters.

The state.py transform system maps function parameter names to source keys, so we need a function
with exactly `num_experts` named parameters (weight0, weight1, ...).
"""
param_names = [f"weight{i}" for i in range(num_experts)]
code = f"def merge_experts({', '.join(param_names)}):\n return torch.stack([{', '.join(param_names)}])"
local_ns = {"torch": torch}
exec(code, local_ns)
return local_ns["merge_experts"]


def convert_mixtral_hf_to_te(model_hf: MixtralForCausalLM, **config_kwargs) -> NVMixtralForCausalLM:
    """Convert a Hugging Face Mixtral model to a Transformer Engine model.

    Args:
        model_hf: The Hugging Face Mixtral model.
        **config_kwargs: Additional configuration kwargs to be passed to NVMixtralConfig.

    Returns:
        The Transformer Engine Mixtral model.
    """
    config_te = NVMixtralConfig(**model_hf.config.to_dict(), **config_kwargs)
    # Instantiate on the meta device; apply_transforms fills in the real weights.
    with torch.device("meta"):
        model_te = NVMixtralForCausalLM(config_te)

    n_experts = model_hf.config.num_local_experts

    # Per-expert target keys for the stacked gate_up / down projection tensors.
    experts_gate_up_keys = tuple(f"model.layers.*.mlp.experts_gate_up.weight{i}" for i in range(n_experts))
    experts_down_keys = tuple(f"model.layers.*.mlp.experts_down.weight{i}" for i in range(n_experts))

    transforms = [
        # Fuse the separate Q/K/V projections into TE's single layernorm_qkv weight.
        state.state_transform(
            source_key=(
                "model.layers.*.self_attn.q_proj.weight",
                "model.layers.*.self_attn.k_proj.weight",
                "model.layers.*.self_attn.v_proj.weight",
            ),
            target_key="model.layers.*.self_attention.layernorm_qkv.weight",
            fn=state.TransformFns.merge_qkv,
        ),
        # Unstack the HF expert tensors into one named weight per expert.
        state.state_transform(
            source_key="model.layers.*.mlp.experts.gate_up_proj",
            target_key=experts_gate_up_keys,
            fn=_split_experts_gate_up,
        ),
        state.state_transform(
            source_key="model.layers.*.mlp.experts.down_proj",
            target_key=experts_down_keys,
            fn=_split_experts_down,
        ),
    ]

    converted = state.apply_transforms(model_hf, model_te, mapping, transforms)

    # Rotary frequencies are buffers, not parameters, so copy them over explicitly.
    converted.model.rotary_emb.inv_freq = model_hf.model.rotary_emb.inv_freq.clone()

    return converted


def convert_mixtral_te_to_hf(model_te: NVMixtralForCausalLM, **config_kwargs) -> MixtralForCausalLM:
    """Convert a Transformer Engine Mixtral model to a Hugging Face model.

    Args:
        model_te: The Transformer Engine Mixtral model.
        **config_kwargs: Additional configuration kwargs to be passed to MixtralConfig.

    Returns:
        The Hugging Face Mixtral model.
    """
    # The TE config may carry extra fields; keep only keys MixtralConfig accepts.
    accepted = set(inspect.signature(MixtralConfig.__init__).parameters)
    config_dict = {k: v for k, v in model_te.config.to_dict().items() if k in accepted}
    hf_config = MixtralConfig(**config_dict, **config_kwargs)

    # Instantiate on the meta device; apply_transforms fills in the real weights.
    with torch.device("meta"):
        model_hf = MixtralForCausalLM(hf_config)

    n_experts = hf_config.num_local_experts

    # Per-expert source keys that get restacked into the HF expert tensors.
    experts_gate_up_keys = tuple(f"model.layers.*.mlp.experts_gate_up.weight{i}" for i in range(n_experts))
    experts_down_keys = tuple(f"model.layers.*.mlp.experts_down.weight{i}" for i in range(n_experts))

    merge_experts = _make_merge_experts_fn(n_experts)

    transforms = [
        # Split TE's fused layernorm_qkv weight back into separate Q/K/V projections.
        state.state_transform(
            source_key="model.layers.*.self_attention.layernorm_qkv.weight",
            target_key=(
                "model.layers.*.self_attn.q_proj.weight",
                "model.layers.*.self_attn.k_proj.weight",
                "model.layers.*.self_attn.v_proj.weight",
            ),
            fn=state.TransformFns.split_qkv,
        ),
        state.state_transform(
            source_key=experts_gate_up_keys,
            target_key="model.layers.*.mlp.experts.gate_up_proj",
            fn=merge_experts,
        ),
        state.state_transform(
            source_key=experts_down_keys,
            target_key="model.layers.*.mlp.experts.down_proj",
            fn=merge_experts,
        ),
    ]

    converted = state.apply_transforms(model_te, model_hf, reverse_mapping, transforms)

    # Rotary frequencies are buffers, not parameters, so copy them over explicitly.
    converted.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone()
    converted.tie_weights()

    return converted
56 changes: 56 additions & 0 deletions bionemo-recipes/models/mixtral/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Create a Mixtral TE checkpoint from a HuggingFace Mixtral model."""

import json
import shutil
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer

import convert
from modeling_mixtral_te import AUTO_MAP


def export_hf_checkpoint(tag: str, export_path: Path):
    """Export a Hugging Face checkpoint to a Transformer Engine checkpoint.

    Args:
        tag: The tag of the checkpoint to export.
        export_path: The parent path to export the checkpoint to.
    """
    hf_model = AutoModelForCausalLM.from_pretrained(tag)
    te_model = convert.convert_mixtral_hf_to_te(hf_model)
    te_model.save_pretrained(export_path)

    # Ship the tokenizer alongside the converted weights.
    AutoTokenizer.from_pretrained(tag).save_pretrained(export_path)

    # Patch the config so AutoModel can locate the custom modeling code.
    config_path = export_path / "config.json"
    config = json.loads(config_path.read_text())
    config["auto_map"] = AUTO_MAP
    config_path.write_text(json.dumps(config, indent=2, sort_keys=True))

    # Bundle the modeling source file so the checkpoint is self-contained.
    shutil.copy(Path(__file__).parent / "modeling_mixtral_te.py", export_path / "modeling_mixtral_te.py")


if __name__ == "__main__":
    # Default export target when run as a script; presumably a small Mixtral
    # checkpoint chosen to keep conversion cheap — confirm with the recipe docs.
    export_hf_checkpoint("NeuralNovel/Mini-Mixtral-v0.2", Path("checkpoint_export"))
Loading