Commit 607d72e

update mixtral readme and tests

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 6a60786 commit 607d72e

5 files changed
Lines changed: 320 additions & 3 deletions

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Mixtral Optimized with NVIDIA TransformerEngine

This folder contains source code and tests for Mixtral-style Mixture of Experts (MoE) models that inherit from the
transformers `PreTrainedModel` class and use TransformerEngine layers. The implementation replaces the standard
attention layers with TE `MultiheadAttention` and uses TE `GroupedLinear` for efficient parallel expert computation.

## Feature support

The Mixtral implementation natively supports the following TransformerEngine-provided optimizations:

| Feature                                  | Support                                                                            |
| ---------------------------------------- | ---------------------------------------------------------------------------------- |
| **FP8**                                  | ✅ Supported on compute capability 9.0 and above (Hopper+)                          |
| **MXFP8**                                | ✅ Supported on compute capability 10.0 and 10.3 (Blackwell), 12.0 support pending  |
| **Sequence Packing / THD input format**  | ✅ Supported                                                                        |
| **FP8 with THD input format**            | ✅ Supported where FP8 is supported                                                 |
| **Import from HuggingFace checkpoints**  | ✅ Supported                                                                        |
| **Export to HuggingFace checkpoints**    | ✅ Supported                                                                        |
| **KV-cache inference**                   | ✅ Supported                                                                        |
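FP8 execution, for instance, is typically wrapped in TransformerEngine's `fp8_autocast` context manager. The sketch
below is a minimal illustration, not a drop-in recipe: it assumes a converted `model_te` and tokenized GPU-resident
`inputs` as in the quick-start example in the next section, and the `DelayedScaling` recipe choice is illustrative.

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Illustrative recipe choice; HYBRID uses E4M3 for forward tensors and E5M2
# for backward (gradient) tensors.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)

# `model_te` and `inputs` are assumed to come from the quick-start example below.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    outputs = model_te(**inputs)
```
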
## Inference Examples

### Quick start: convert and run

> **Note:** The snippets below use bare imports (e.g., `from convert import ...`). Run them from the
> `bionemo-recipes/models/mixtral` directory, or install dependencies first with `pip install -r requirements.txt`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from convert import convert_mixtral_hf_to_te

# Load the original HuggingFace Mixtral model
model_hf = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.bfloat16
)

# Convert to TransformerEngine
model_te = convert_mixtral_hf_to_te(model_hf)
model_te.to("cuda")

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer("The quick brown fox", return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}

with torch.no_grad():
    output_ids = model_te.generate(**inputs, max_new_tokens=16)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

## Converting Between Model Formats

This section explains how to convert between Hugging Face Transformers and Transformer Engine (TE) Mixtral model
formats. Conversion is bidirectional: from Transformers to TE format for optimized training and inference, and back
to Hugging Face Transformers format for sharing and deployment.

### Converting from HF Transformers to TE

> **Note:** Run from the `bionemo-recipes/models/mixtral` directory, or install dependencies first with
> `pip install -r requirements.txt`.

```python
from transformers import AutoModelForCausalLM

from convert import convert_mixtral_hf_to_te

model_hf = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
model_te = convert_mixtral_hf_to_te(model_hf)
model_te.save_pretrained("/path/to/te_checkpoint")
```

### Converting from TE back to HF Transformers

> **Note:** Run from the `bionemo-recipes/models/mixtral` directory, or install dependencies first with
> `pip install -r requirements.txt`.

```python
from convert import convert_mixtral_te_to_hf
from modeling_mixtral_te import NVMixtralForCausalLM

model_te = NVMixtralForCausalLM.from_pretrained("/path/to/te_checkpoint")
model_hf = convert_mixtral_te_to_hf(model_te)
model_hf.save_pretrained("/path/to/hf_checkpoint")
```

### Validating Converted Models

The golden value tests in [test_modeling_mixtral.py](tests/test_modeling_mixtral.py) verify that the converted TE model
produces numerically equivalent outputs to the original HuggingFace model. Specifically:

- `test_golden_values_bshd` — loads both models, runs a forward pass on the same input, and asserts that logits and
  loss match within tolerance.
- `test_round_trip_conversion` — converts HF → TE → HF and verifies the round-tripped model produces identical outputs.

To run these tests locally:

```bash
./ci/scripts/recipes_local_test.py bionemo-recipes/models/mixtral/
```
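
Conceptually, the golden-value check reduces to the sketch below. This is a simplified stand-in for the real tests,
which use the shared `BaseModelTest` harness; the checkpoint (the small Mini-Mixtral used by the export test) and the
tolerance values here are illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from convert import convert_mixtral_hf_to_te

checkpoint = "NeuralNovel/Mini-Mixtral-v0.2"  # illustrative small checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")

# Reference logits from the original HuggingFace model
model_hf = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to("cuda")
with torch.no_grad():
    logits_hf = model_hf(**inputs).logits

# Logits from the converted TransformerEngine model on the same input
model_te = convert_mixtral_hf_to_te(model_hf).to("cuda")
with torch.no_grad():
    logits_te = model_te(**inputs).logits

# Assert numerical equivalence within (illustrative) tolerances.
torch.testing.assert_close(logits_te, logits_hf, atol=1e-2, rtol=1e-2)
```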

## Developer Guide

### Running tests

To run tests locally, run `recipes_local_test.py` from the repository root with the model directory as an argument.

```bash
./ci/scripts/recipes_local_test.py bionemo-recipes/models/mixtral/
```

### Exporting to Hugging Face Hub

The model directory includes an `export.py` script that bundles all files needed for Hugging Face Hub distribution. To
create the export bundle, run from the model directory:

```bash
python export.py
```
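
The exported bundle loads as a standard custom-code checkpoint. A minimal sketch, mirroring the export test added in
this commit (the local path is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative path; use wherever export.py wrote the bundle.
checkpoint = "./checkpoint_export"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# trust_remote_code=True is required so transformers loads the bundled
# NVMixtralForCausalLM implementation shipped alongside the weights.
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
```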

Before publishing, validate the export by running the local test suite via
[recipes_local_test.py](../../ci/scripts/recipes_local_test.py).

### Development container

To use the provided devcontainer, use "Dev Containers: Reopen in Container" from the VSCode menu, and choose the
"BioNeMo Recipes Dev Container" option. To run the tests inside the container, first install the dependencies with
`pip install -r requirements.txt`, then run `pytest -v .` in the model directory.

bionemo-recipes/models/mixtral/convert.py

Lines changed: 16 additions & 2 deletions

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""Conversion utilities between HuggingFace Mixtral and TransformerEngine formats."""
+
 import inspect
 
 import torch
@@ -62,8 +64,20 @@ def _split_experts_down(down_proj: torch.Tensor):
 def _make_merge_experts_fn(num_experts: int):
     """Create a merge function with the correct number of named parameters.
 
-    The state.py transform system maps function parameter names to source keys, so we need a function
-    with exactly `num_experts` named parameters (weight0, weight1, ...).
+    The state.py transform system maps function parameter names to source dict keys by inspecting
+    the function signature. When ``source_key`` is a tuple, it pairs each tuple element with the
+    corresponding named parameter via ``{param: source_key[i]}``. This means ``*args`` style
+    parameters do not work -- the system cannot map positional varargs to specific source keys.
+
+    Since the number of experts is dynamic (varies per model config), we use ``exec()`` to generate
+    a function with exactly ``num_experts`` named parameters (weight0, weight1, ..., weightN-1).
+
+    Args:
+        num_experts: The number of expert weight parameters the generated function will accept.
+
+    Returns:
+        A callable ``(weight0, weight1, ..., weight{N-1}) -> torch.Tensor`` that stacks the
+        per-expert weight tensors into a single tensor of shape ``[num_experts, ...]``.
     """
     param_names = [f"weight{i}" for i in range(num_experts)]
     code = f"def merge_experts({', '.join(param_names)}):\n return torch.stack([{', '.join(param_names)}])"
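
To make the `exec()` trick concrete, here is a standalone sketch of what the generated function looks like for two
experts (variable names are hypothetical; the real logic lives in `_make_merge_experts_fn`):

```python
import inspect

import torch

# Standalone sketch of the exec() trick described in the docstring above,
# specialized to num_experts=2.
num_experts = 2
param_names = [f"weight{i}" for i in range(num_experts)]
code = f"def merge_experts({', '.join(param_names)}):\n return torch.stack([{', '.join(param_names)}])"

namespace = {"torch": torch}
exec(code, namespace)
merge_experts = namespace["merge_experts"]

# The generated function has exactly the named parameters the transform system
# needs to pair with per-expert source keys...
assert list(inspect.signature(merge_experts).parameters) == ["weight0", "weight1"]

# ...and it stacks the per-expert tensors along a new leading dimension of
# size num_experts.
merged = merge_experts(torch.randn(4, 8), torch.randn(4, 8))
assert merged.shape == (2, 4, 8)
```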

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 5 additions & 0 deletions

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""TransformerEngine-optimized Mixtral model with Mixture of Experts."""
+
 from collections import OrderedDict
 from typing import ClassVar, Unpack
 
@@ -38,6 +40,9 @@
 class NVMixtralConfig(MixtralConfig):
     """NVMixtral configuration."""
 
+    # Attention input format:
+    #   "bshd" = Batch, Sequence, Head, Dimension (standard padded format)
+    #   "thd"  = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
     attn_input_format: str = "thd"
     self_attn_mask_type: str = "padding_causal"
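
For intuition, the two attention input formats named in this config can be contrasted with hypothetical token IDs
(the tensors below are purely illustrative and not part of the model code):

```python
import torch

# Two sequences of lengths 4 and 6 (token IDs are made up for illustration).

# "bshd": a padded batch of shape [batch, seq_len] for input_ids, with an
# attention mask distinguishing real tokens from padding.
input_ids_bshd = torch.tensor([
    [101, 7592, 2088, 102, 0, 0],
    [101, 2023, 2003, 1037, 3231, 102],
])
attention_mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1, 1],
])

# "thd": all real tokens concatenated into a single [total_tokens] stream with
# no padding; cumulative sequence lengths mark where each sequence begins.
input_ids_thd = torch.tensor([101, 7592, 2088, 102, 101, 2023, 2003, 1037, 3231, 102])
cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32)
```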

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest
+from transformer_engine.pytorch import MultiheadAttention
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from export import export_hf_checkpoint
+
+
+@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI, requires Mini-Mixtral download.")
+def test_export_mixtral_checkpoint(tmp_path):
+    export_hf_checkpoint("NeuralNovel/Mini-Mixtral-v0.2", tmp_path / "checkpoint_export")
+
+    _ = AutoTokenizer.from_pretrained(tmp_path / "checkpoint_export")
+    model = AutoModelForCausalLM.from_pretrained(tmp_path / "checkpoint_export", trust_remote_code=True)
+    assert "NVMixtralForCausalLM" in model.__class__.__name__
+    assert "NVMixtralConfig" in model.config.__class__.__name__
+    # Mixtral uses custom NVMixtralDecoderLayer with TE MultiheadAttention sub-modules
+    assert isinstance(model.model.layers[0].self_attention, MultiheadAttention)

bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py

Lines changed: 134 additions & 1 deletion

@@ -37,7 +37,7 @@
 
 from collator import DataCollatorWithFlattening
 from convert import convert_mixtral_hf_to_te, convert_mixtral_te_to_hf
-from modeling_mixtral_te import NVMixtralConfig, NVMixtralForCausalLM
+from modeling_mixtral_te import HFInferenceParams, NVMixtralConfig, NVMixtralForCausalLM
 from tests.common import BaseModelTest, TestTolerances
 
 
@@ -145,3 +145,136 @@ def get_tolerances(self) -> TestTolerances:
             cp_loss_atol=0.5,
             cp_loss_rtol=0.25,
         )
+
+    # ==================== Mixtral-Specific KV-Cache Tests ====================
+
+    def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+        """Create HFInferenceParams for the given config."""
+        past_key_values = HFInferenceParams(
+            max_batch_size=batch_size * num_beams,
+            max_sequence_length=max_seq_len,
+            num_heads_kv=config.num_key_value_heads,
+            head_dim_k=config.hidden_size // config.num_attention_heads,
+            dtype=torch.bfloat16,
+            qkv_format="thd",
+            max_ctx_len=max_seq_len,
+        )
+        for layer_number in range(1, config.num_hidden_layers + 1):
+            past_key_values.allocate_memory(layer_number)
+        return past_key_values
+
+    def test_generate_with_cache(self):
+        """Test single-prompt generation with KV-cache (THD format)."""
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self._create_inference_params(config, batch_size=1)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        # Verify generation produced new tokens
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_batched(self):
+        """Test batched generation with KV-cache (left-padded BSHD converted to THD)."""
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self._create_inference_params(config, batch_size=2)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        # Verify generation produced new tokens for both sequences
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_beam_search(self):
+        """Test batched generation with KV-cache and beam search."""
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        num_beams = 2
+        past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams)
+
+        with torch.no_grad():
+            output_ids = model.generate(
+                **inputs,
+                max_new_tokens=16,
+                use_cache=True,
+                past_key_values=past_key_values,
+                num_beams=num_beams,
+                do_sample=True,
+            )
+
+        # Verify generation produced new tokens for both sequences
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    # ==================== Standalone Mixtral Generation Tests ====================
+
+    def test_te_mixtral_model_generate_with_cache_beam_search(self):
+        """Test Mixtral generation with KV-cache and beam search using real model weights."""
+        import gc
+
+        model_hf = self.get_reference_model()
+        model_te = convert_mixtral_hf_to_te(model_hf, attn_input_format="thd", self_attn_mask_type="padding_causal")
+        del model_hf
+        gc.collect()
+
+        model_te.to("cuda")
+        model_te.eval()
+
+        tokenizer = self.get_tokenizer()
+
+        prompts = (
+            'Licensed under the Apache License, Version 2.0 (the "License");'
+            " you may not use this file except in compliance with the License."
+            " You may obtain a copy of the License at",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        num_beams = 2
+        config = model_te.config
+        past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams)
+
+        with torch.no_grad():
+            output_ids = model_te.generate(
+                **inputs,
+                max_new_tokens=16,
+                use_cache=True,
+                past_key_values=past_key_values,
+                num_beams=num_beams,
+                do_sample=False,
+            )
+
+        generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0]
+        assert "et dolore magna aliqua" in generated_text[1]
