Add Mixtral model for MoE demo #1458

Merged
pstjohn merged 4 commits into NVIDIA:main from pstjohn:pstjohn/bio-216-moe-in-bionemo-recipes on Feb 23, 2026

Conversation

@pstjohn
Collaborator

@pstjohn commented Feb 9, 2026

Summary

Adds a Mixtral Mixture-of-Experts model to bionemo-recipes/models/mixtral/ using TransformerEngine, following the same pattern as the existing llama3
model.

MoE Implementation

The core MoE block (NVMixtralSparseMoeBlock) uses:

  • te.GroupedLinear for efficient parallel expert FFN computation (gate_up + down projections)
  • te.moe_permute / te.moe_unpermute with map_type="index" for token-to-expert routing
  • Standard nn.Linear router (kept in bf16 during FP8 training via te.autocast(enabled=False))
  • SwiGLU activation with fused gate/up projection split
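The permute → grouped expert compute → unpermute flow above can be sketched in plain Python. This is a toy stand-in, not the PR's code: lists replace tensors, and a per-expert scale replaces the GroupedLinear FFN; all names here are illustrative.

```python
def route_tokens(tokens, expert_ids):
    """Toy permute -> per-expert compute -> unpermute round trip."""
    # permute: sort token indices so tokens routed to the same expert are contiguous
    order = sorted(range(len(tokens)), key=lambda i: expert_ids[i])
    permuted = [tokens[i] for i in order]
    # grouped "expert FFN": a per-expert scale stands in for te.GroupedLinear
    processed = [permuted[j] * (expert_ids[order[j]] + 1) for j in range(len(permuted))]
    # unpermute: scatter results back to the original token order
    out = [None] * len(tokens)
    for j, i in enumerate(order):
        out[i] = processed[j]
    return out
```

In the real block, the index map produced by te.moe_permute is what te.moe_unpermute consumes to restore token order after the grouped FFN.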

Weight conversion handles the structural difference between HF's stacked 3D expert tensors ([num_experts, out, in]) and TE's per-expert GroupedLinear
weights (weight0, weight1, ...).
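A minimal sketch of that structural conversion, with nested lists standing in for tensors (helper names are illustrative, not the PR's API):

```python
def split_stacked_experts(stacked, prefix="weight"):
    """HF -> TE: one stacked [num_experts, out, in] tensor becomes
    per-expert entries weight0, weight1, ... (lists stand in for tensors)."""
    return {f"{prefix}{i}": w for i, w in enumerate(stacked)}


def merge_expert_weights(per_expert, prefix="weight"):
    """TE -> HF: re-stack weight0..weightN-1 in index order."""
    return [per_expert[f"{prefix}{i}"] for i in range(len(per_expert))]
```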

Base test class improvements

  • Added clear_gpu_memory fixture (gc + CUDA cache clear before/after each test) to prevent OOMs from cascading between tests
  • get_converted_te_model_checkpoint now frees the HF model and moves TE to CPU before saving (save_pretrained clones state dict internally)
  • test_golden_values and test_golden_values_thd now run models sequentially to support large models
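A minimal sketch of such a cleanup fixture, assuming pytest and an optional torch install (the actual fixture lives in the PR's test base class):

```python
import gc


def clear_gpu_memory():
    """Generator usable as a pytest fixture (decorate with
    @pytest.fixture(autouse=True)): frees Python garbage and the CUDA
    cache before and after each test so one test's leak cannot OOM the next."""
    def _cleanup():
        gc.collect()
        try:
            import torch  # optional dependency; skip CUDA cleanup if absent
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass

    _cleanup()
    yield
    _cleanup()
```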

Summary by CodeRabbit

  • New Features

    • Added support for Mixtral mixture-of-experts model with Transformer Engine backend.
    • Improved token packing and sequence batching with overflow handling and sample splitting.
  • Bug Fixes

    • Fixed typo in error logging messages.
    • Enhanced validation for tensor format handling in data collators.
  • Tests

    • Added comprehensive test suite for Mixtral model conversions and FP8 quantization.
    • Improved GPU capability detection and pattern-based model configuration handling.

@coderabbitai
Contributor

coderabbitai bot commented Feb 9, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This pull request introduces a Mixtral Mixture-of-Experts model implementation with Transformer Engine support, refactors mutable default arguments across existing state modules, enhances collator batch packing logic with validation and splitting behavior, and establishes a comprehensive test framework for model validation and conversion testing.

Changes

Cohort / File(s) Summary
State Module Refactoring (Mutable Defaults)
bionemo-recipes/models/amplify/src/amplify/state.py, bionemo-recipes/models/esm2/src/esm/state.py, bionemo-recipes/models/llama3/state.py
Changed default parameters for apply_transforms from mutable lists [] to None, with initialization guards inside the function. Minor logging fix: corrected "Enountered" typo to "Encountered" in error handling.
Collator Batch Packing Enhancement
bionemo-recipes/models/esm2/src/esm/collator.py, bionemo-recipes/models/llama3/collator.py, bionemo-recipes/recipes/esm2_native_te/collator.py, bionemo-recipes/recipes/esm2_peft_te/collator.py, bionemo-recipes/recipes/llama3_native_te/collator.py
Enforced return_tensors policy with NotImplementedError if not "pt". Enhanced TokenPackingDataset.__iter__ with per-sample length validation, refined batch overflow handling for split and non-split modes, and improved length accounting using actual sample lengths.
New Mixtral Model Implementation
bionemo-recipes/models/mixtral/collator.py
Added comprehensive data collation toolkit including DataCollatorWithFlattening for sequence packing, TokenPackingDataset for efficient batching, DataCollatorForContextParallel for CP/TP support, and ContextParallelDataLoaderWrapper for distributed data loading with sequence splitting and padding utilities.
Mixtral Transformer Engine Integration
bionemo-recipes/models/mixtral/convert.py, bionemo-recipes/models/mixtral/modeling_mixtral_te.py, bionemo-recipes/models/mixtral/state.py
Introduced bidirectional HF-to-TE conversion with expert weight splitting/merging, complete TE-based Mixtral model with sparse MoE blocks and decoder layers, and comprehensive state dict transformation framework with wildcard key matching and pluggable transform functions.
Mixtral Model Export and Support
bionemo-recipes/models/mixtral/export.py, bionemo-recipes/models/mixtral/.ruff.toml, bionemo-recipes/models/mixtral/requirements.txt
Added checkpoint export functionality to convert HF Mixtral to TE format, Ruff configuration inheritance from parent config, and required dependencies for Mixtral (transformer_engine, torch, torchao).
Mixtral Test Infrastructure
bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py, bionemo-recipes/models/mixtral/tests/common/fixtures.py, bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py, bionemo-recipes/models/mixtral/tests/conftest.py, bionemo-recipes/models/mixtral/tests/common/__init__.py, bionemo-recipes/models/mixtral/tests/common/README.md
Introduced comprehensive test framework with BaseModelTest abstract class, TestTolerances dataclass, extensive test fixtures (FP8 recipes, attention backends, input formats), and concrete Mixtral tests covering smoke tests, conversions, golden values, and FP8 scenarios.
Common Test Improvements
bionemo-recipes/models/esm2/tests/common/test_modeling_common.py, bionemo-recipes/models/llama3/tests/common/test_modeling_common.py
Enhanced GPU detection with try/except fallback, added fnmatch pattern matching for FP8 quantization skip logic, and updated get_reference_model_no_weights to accept arbitrary kwargs for flexible config creation.
Test Fixtures and Documentation
bionemo-recipes/models/esm2/tests/common/fixtures.py, bionemo-recipes/models/esm2/tests/common/__init__.py, bionemo-recipes/models/llama3/tests/common/fixtures.py, bionemo-recipes/models/llama3/tests/common/__init__.py, bionemo-recipes/models/esm2/tests/common/README.md, bionemo-recipes/models/llama3/tests/common/README.md
Removed duplicate license headers, replaced del os.environ with safer os.environ.pop() calls, updated code fence formatting in documentation, and standardized test infrastructure documentation.
Collator Test Coverage
bionemo-recipes/models/esm2/tests/test_collator.py
Added three new test cases for TokenPackingDataset covering padding with split mode, non-split mode overflow handling, and oversized sample validation.
CI Infrastructure
ci/scripts/check_copied_files.py
Extended file copy mappings to include Mixtral model destinations for collator, state, and common tests synchronization.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant DataCollator as DataCollator<br/>WithFlattening
    participant TP as TokenPacking<br/>Dataset
    participant CP as Context<br/>Parallel
    
    Client->>TP: __iter__()
    loop For each sample
        TP->>TP: Validate length vs<br/>max_tokens_per_batch
        alt Split disabled
            TP->>TP: Accumulate or yield<br/>existing batch
            TP->>DataCollator: samples
            DataCollator->>DataCollator: Flatten + pack
            DataCollator-->>TP: packed batch
        else Split enabled
            TP->>TP: Compute tokens_available
            alt Room available
                TP->>TP: Split sample by<br/>tokens_available
                TP->>DataCollator: first_part + batch
                DataCollator-->>TP: packed batch
            else No room
                TP->>DataCollator: existing batch
                DataCollator-->>TP: packed batch
                TP->>TP: Start new batch<br/>with current sample
            end
        end
        TP->>CP: packed batch
        CP->>CP: Shard by CP rank
        CP->>CP: Compute cu_seq_lens
        CP-->>Client: CP-formatted batch
    end
sequenceDiagram
    participant HFModel as HF Model<br/>(Mixtral)
    participant Convert as convert_mixtral<br/>_hf_to_te
    participant StateDictTransform as apply_transforms<br/>+ TransformFns
    participant TEModel as TE Model<br/>(NVMixtral)
    
    HFModel->>Convert: load state_dict
    Convert->>Convert: Build NVMixtralConfig
    Convert->>TEModel: Allocate on meta device
    Convert->>StateDictTransform: Apply transform chain
    loop For each transform
        StateDictTransform->>StateDictTransform: Match source→target<br/>keys via wildcards
        StateDictTransform->>StateDictTransform: Split/merge expert<br/>weights as needed
        StateDictTransform->>StateDictTransform: QKV projection<br/>handling
    end
    StateDictTransform->>TEModel: Load transformed state
    TEModel-->>Convert: Converted model

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 Hops excitedly
A Mixtral feast, with experts bright,
State dicts mended, defaults made right,
Batches packed with Tesla's care,
From HF gardens to TE's lair!
Test frameworks blooming—oh what delight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Description check ⚠️ Warning: The PR description is missing critical template sections: 'Usage' code snippet, 'Type of changes' checklist, 'CI Pipeline Configuration' labels, and 'Pre-submit Checklist' completion. Resolution: fill in a usage code snippet showing how users interact with the new Mixtral model, mark the type of change (New feature), select appropriate CI labels (likely ciflow:all-recipes), and confirm the pre-submit checklist items.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The title 'Add Mixtral model for MoE demo' clearly and concisely summarizes the main change: adding a Mixtral model implementation to demonstrate Mixture-of-Experts functionality.
  • Docstring Coverage ✅ Passed: Docstring coverage is 88.44%, above the required 80.00% threshold.



@coderabbitai bot left a comment
Actionable comments posted: 17

🤖 Fix all issues with AI agents
In `@bionemo-recipes/models/mixtral/collator.py`:
- Around line 98-180: The __call__ method ignores the return_tensors parameter;
validate that return_tensors is either None or "pt" at the start of __call__ and
raise NotImplementedError for other values, and forward return_tensors into
downstream calls so tensor backend is consistent: call self.collator(features,
return_tensors=return_tensors) instead of self.collator(features) and call
_pt_flatten_collate(features, return_position_ids=self.return_position_ids,
return_tensors=return_tensors) (or the appropriate flatten helper that accepts
return_tensors); keep the rest of the logic (masked_input_ids/labels,
separator_id handling, and padding via
_pad_batch_to_multiple_of/_pad_sequences_to_be_divisible_by) unchanged.
- Around line 305-351: In the __call__ method of the collator, the computation
batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length *
round(max_length / 64) is incorrect; replace it with a ceil-to-multiple-of-64
calculation (e.g. compute padded_max = ((max_length + 63) // 64) * 64) and
assign padded_max to both batch_shard["max_length_k"] and
batch_shard["max_length_q"] so the result is the next multiple of 64 without
floating point rounding or inflation.
- Around line 238-277: The iterator currently calls
_split_sample_by_num_tokens(sample, tokens_available) when tokens_available can
be 0 (causing an error) and can yield an empty batch when a single sample
exceeds max_tokens_per_batch with split_samples=False; fix __iter__ to first
check tokens_available <= 0 and in that case yield the current samples and reset
samples before handling the incoming sample, and then reprocess the incoming
sample: if split_samples is True, split the sample into chunks of size up to
max_tokens_per_batch (call _split_sample_by_num_tokens in a loop using chunk
size = max_tokens_per_batch) and yield/fill batches accordingly; if
split_samples is False, ensure you never yield an empty batch by appending the
oversized sample as its own batch (or yield it immediately) instead of yielding
samples=[]; update references in __iter__ to use tokens_available guard,
_split_sample_by_num_tokens, split_samples, max_tokens_per_batch, samples and
current_length.
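The max_length fix above amounts to integer ceiling division; a standalone illustration (helper name is ours, not the repo's):

```python
def ceil_to_multiple(n, multiple=64):
    """Round n up to the next multiple, using integer arithmetic only.

    Contrast with the flagged expression max_length * round(max_length / 64),
    which inflates rather than pads: for n=100 it yields 100 * 2 = 200,
    while the correct padded length is 128.
    """
    return ((n + multiple - 1) // multiple) * multiple
```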
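The iterator behavior described above (never split into a zero-token chunk, never yield an empty batch for an oversized sample) can be modeled on sample lengths alone. This is a simplified sketch of the packing logic, not the TokenPackingDataset implementation:

```python
def pack_samples(lengths, max_tokens_per_batch, split_samples=False):
    """Yield lists of sample lengths whose sum stays within the token budget."""
    batch, current = [], 0
    for n in lengths:
        if not split_samples and n > max_tokens_per_batch:
            # oversized sample: flush, then emit it alone; never yield []
            if batch:
                yield batch
            yield [n]
            batch, current = [], 0
            continue
        if split_samples:
            while n > 0:
                room = max_tokens_per_batch - current
                if room <= 0:  # guard: never split by a 0-token budget
                    yield batch
                    batch, current = [], 0
                    room = max_tokens_per_batch
                take = min(n, room)
                batch.append(take)
                current += take
                n -= take
        else:
            if current + n > max_tokens_per_batch:
                yield batch
                batch, current = [], 0
            batch.append(n)
            current += n
    if batch:
        yield batch
```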

In `@bionemo-recipes/models/mixtral/convert.py`:
- Around line 136-137: The current filtering uses
MixtralConfig.__init__.__code__.co_varnames which is fragile because co_varnames
contains locals too; update the logic that builds valid_keys (used to create
filtered_config from te_config_dict) to derive parameter names from
MixtralConfig.__init__ via inspect.signature (e.g.,
inspect.signature(MixtralConfig.__init__).parameters) and then filter
te_config_dict by those parameter names so only actual constructor args are
preserved.
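The co_varnames pitfall is easy to demonstrate with a stub config class (ConfigStub and its fields are illustrative, standing in for MixtralConfig):

```python
import inspect


class ConfigStub:  # stand-in for transformers.MixtralConfig
    def __init__(self, hidden_size=32, num_local_experts=8):
        unused_local = hidden_size  # locals also appear in __code__.co_varnames
        self.hidden_size = hidden_size
        self.num_local_experts = num_local_experts


# inspect.signature sees only real parameters, never function locals
valid_keys = set(inspect.signature(ConfigStub.__init__).parameters) - {"self"}
te_config_dict = {"hidden_size": 64, "num_local_experts": 4, "moe_router_bias": True}
filtered = {k: v for k, v in te_config_dict.items() if k in valid_keys}
```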

In `@bionemo-recipes/models/mixtral/export.py`:
- Line 53: The copy call uses a relative path so it breaks when the working
directory isn't the file's folder; change the source to be anchored to this
module by resolving Path(__file__).parent / "modeling_mixtral_te.py" and pass
that resolved path as the first argument to shutil.copy when copying to
export_path (the existing export_path / "modeling_mixtral_te.py" destination can
remain); update the code that calls shutil.copy accordingly (referencing
shutil.copy and the filename "modeling_mixtral_te.py").
- Around line 22-37: The export_hf_checkpoint function is creating a randomly
initialized model by calling AutoModelForCausalLM.from_config; replace that with
loading the actual pretrained weights by calling
AutoModelForCausalLM.from_pretrained using the same tag (keep or remove the
separate AutoConfig.from_pretrained as needed), so model_hf holds the real
checkpoint weights before export; update any related uses of model_hf and ensure
tokenizer/config are loaded from_pretrained as well if required.

In `@bionemo-recipes/models/mixtral/modeling_mixtral_te.py`:
- Line 453: The module sets torch._dynamo.config.capture_scalar_outputs = True
at import time, which mutates global TorchDynamo state and can affect other
code; change this to a local, temporary setting or document it: in the functions
that require this behavior (e.g., the compile/optimization entry points in
modeling_mixtral_te), save the current
torch._dynamo.config.capture_scalar_outputs value, set it to True only for the
scope where you call torch.compile/torch._dynamo operations, then restore the
original value in a finally block (or use a small context manager) so the global
config isn't mutated at module load; alternatively, add a clear comment in the
module/top-level docstring explaining the global requirement if a global change
is unavoidable.
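The save/set/restore pattern suggested above generalizes to a small context manager (a generic sketch; torch._dynamo is not required to express it):

```python
from contextlib import contextmanager


@contextmanager
def temporary_setting(obj, name, value):
    """Set obj.<name> = value for the duration of the block, restoring the
    previous value even if the block raises (the finally clause runs always)."""
    previous = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, previous)
```

Applied to the flagged line, the compile entry point would wrap its torch.compile call in temporary_setting(torch._dynamo.config, "capture_scalar_outputs", True) instead of mutating the global at import time.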

In `@bionemo-recipes/models/mixtral/requirements.txt`:
- Around line 1-5: The requirements file currently leaves torch, transformers,
transformer_engine[pytorch], torchao, and lm-eval unpinned which causes
reproducibility and compatibility issues; update the requirements (the entries
referencing torch, transformers, transformer_engine[pytorch], torchao!=0.14.0,
and lm-eval) to pin explicit compatible versions (or reference a constraints
file) that match the project's tested torch version matrix and known working
transformer_engine and torchao releases, ensuring transformer_engine is the
pytorch build and avoiding the excluded torchao 0.14.0 conflict; provide exact
version specs consistent with the root project strategy so installs are
deterministic.

In `@bionemo-recipes/models/mixtral/state.py`:
- Line 321: Fix the typo in the logger.error message: change "Enountered" to
"Encountered" in the error string used in the code that logs IndexError during
transform (the logger.error call that includes source_matches and
target_matches); update the f-string to read logger.error(f"Encountered
IndexError during transform.\n{source_matches=}\n{target_matches=}") so the
message is spelled correctly while preserving the variables and formatting.
- Around line 70-72: The dataclass/module-level defaults use mutable lists
(transforms, state_dict_ignored_entries); change their default values from [] to
None and initialize them to empty lists at construction (e.g., in __post_init__
of the dataclass or where the object is created) to avoid shared mutable state;
keep cast_dtype as Optional[torch.dtype] = None unchanged and ensure any code
referencing transforms or state_dict_ignored_entries handles the None-to-list
initialization (refer to the symbols transforms, state_dict_ignored_entries,
TransformCTX, cast_dtype).
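The None-default plus __post_init__ guard looks like this in a minimal dataclass (TransformSpec is an illustrative stand-in for the PR's context object):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TransformSpec:  # illustrative stand-in, not the PR's TransformCTX
    transforms: Optional[list] = None
    state_dict_ignored_entries: Optional[list] = None

    def __post_init__(self):
        # fresh list per instance; a [] default would be shared across instances
        if self.transforms is None:
            self.transforms = []
        if self.state_dict_ignored_entries is None:
            self.state_dict_ignored_entries = []
```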

In `@bionemo-recipes/models/mixtral/tests/common/__init__.py`:
- Around line 1-29: Remove the duplicated license header blocks and leave a
single, consistent header: keep one SPDX copyright line with the correct year
(use 2026), set SPDX-License-Identifier to "Apache-2.0", and retain the standard
Apache-2.0 license text that follows; delete the other entire license block so
only one header remains at the top of the file.

In `@bionemo-recipes/models/mixtral/tests/common/fixtures.py`:
- Around line 141-143: Replace the direct environment deletions with safe
removals using os.environ.pop for keys "NVTE_FUSED_ATTN" and "NVTE_FLASH_ATTN"
to avoid KeyError if they were never set; keep the existing update to
_attention_backends["backend_selection_requires_update"] = True unchanged so the
backend refresh still occurs. Use os.environ.pop("NVTE_FUSED_ATTN", None) and
os.environ.pop("NVTE_FLASH_ATTN", None) in the teardown (where the current del
os.environ[...] calls are) to safely remove the variables.
- Line 66: Replace the unsafe deletion del os.environ["NVTE_DEBUG"] in the
fixtures cleanup with a safe pop call so removing the NVTE_DEBUG env var cannot
raise KeyError; locate the occurrence of del os.environ["NVTE_DEBUG"] in
bionemo-recipes/models/mixtral/tests/common/fixtures.py (the teardown/cleanup
for the test fixture) and use os.environ.pop("NVTE_DEBUG", None) to silently
handle the case where the variable is already absent.
- Around line 1-30: The file contains two SPDX/Apache-2.0 license header blocks
with conflicting years; remove the second duplicate header (the block that
begins with "SPDX-FileCopyrightText: Copyright (c) 2025" and its following
Apache-2.0 license text) so only the first header (the 2026 SPDX header)
remains; ensure no extra blank lines are left at the top after removal.
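The pop-with-default teardown the fixture comments call for is a one-liner per variable; a sketch of the pattern (function name is ours):

```python
import os


def teardown_attention_env():
    """Safe teardown: pop with a default never raises KeyError,
    unlike `del os.environ[...]` when the key was never set."""
    os.environ.pop("NVTE_FUSED_ATTN", None)
    os.environ.pop("NVTE_FLASH_ATTN", None)
```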

In `@bionemo-recipes/models/mixtral/tests/common/README.md`:
- Around line 7-13: The fenced code block in README.md under tests/common is
missing a language marker (MD040); update the snippet in that file (the
three-line tree block within the fenced code) to include a language tag such as
"text" or "plaintext" after the opening backticks so linting passes (e.g.,
change ``` to ```text).

In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py`:
- Around line 316-354: The HF API calls use the wrong kwarg name: replace the
"dtype" param with "torch_dtype" in both upstream_class.from_pretrained(...)
inside get_reference_model and AutoConfig.from_pretrained(...) inside
get_reference_model_no_weights so the requested precision (e.g., torch.bfloat16
or torch.float32) is honored; keep the existing attn_implementation and revision
logic and leave model.to("cuda") as-is.
- Around line 34-36: The HAS_DATA_CENTER_GPU probe calls
torch.cuda.get_device_name(0) unguarded which raises in CPU-only environments;
update the definition of HAS_DATA_CENTER_GPU to first check
torch.cuda.is_available() (and optionally wrap the probe in a try/except for
RuntimeError/AssertionError) and only call torch.cuda.get_device_name(0) when
CUDA is available, otherwise set HAS_DATA_CENTER_GPU to False; modify the
variable in the test module (the HAS_DATA_CENTER_GPU assignment) to implement
this guard so imports do not fail on CPU-only systems.
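A guarded probe along those lines might look like this (the GPU-name tags and the function name are assumptions for illustration, not the test suite's exact list):

```python
def has_data_center_gpu():
    """Probe the device name only when torch imports and CUDA is available,
    so module import never fails on CPU-only machines."""
    try:
        import torch
        if not torch.cuda.is_available():
            return False
        name = torch.cuda.get_device_name(0)
    except (ImportError, RuntimeError, AssertionError):
        return False
    return any(tag in name for tag in ("A100", "H100", "B100", "B200"))
```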
🧹 Nitpick comments (6)
bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py (1)

104-105: Remove redundant pad_token check.

This check duplicates lines 69-70 in get_tokenizer(). Since get_test_input_data calls self.get_tokenizer() at line 97, the pad_token is already set.

♻️ Proposed fix
         data_collator = DataCollatorForLanguageModeling(
             tokenizer=tokenizer,
             pad_to_multiple_of=pad_to_multiple_of,
             mlm=False,
         )

-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
         if format == "thd":
bionemo-recipes/models/mixtral/modeling_mixtral_te.py (1)

380-401: Consider standard import for GenerationMixin.

The dynamic import __import__("transformers").GenerationMixin works but is unconventional. A standard import would be clearer.

♻️ Proposed fix
+from transformers import GenerationMixin
+
 ...
 
-class NVMixtralForCausalLM(NVMixtralPreTrainedModel, __import__("transformers").GenerationMixin):
+class NVMixtralForCausalLM(NVMixtralPreTrainedModel, GenerationMixin):
bionemo-recipes/models/mixtral/convert.py (1)

60-71: Consider alternative to exec() for dynamic function creation.

Using exec() works here but is harder to debug and understand. A simpler approach using *args:

♻️ Proposed alternative
 def _make_merge_experts_fn(num_experts: int):
-    """Create a merge function with the correct number of named parameters.
-
-    The state.py transform system maps function parameter names to source keys, so we need a function
-    with exactly `num_experts` named parameters (weight0, weight1, ...).
-    """
-    param_names = [f"weight{i}" for i in range(num_experts)]
-    code = f"def merge_experts({', '.join(param_names)}):\n    return torch.stack([{', '.join(param_names)}])"
-    local_ns = {"torch": torch}
-    exec(code, local_ns)
-    return local_ns["merge_experts"]
+    """Create a merge function that stacks expert weights."""
+    def merge_experts(*weights):
+        if len(weights) != num_experts:
+            raise ValueError(f"Expected {num_experts} weights, got {len(weights)}")
+        return torch.stack(weights)
+    return merge_experts

Note: This assumes the state transform system supports *args. If named parameters are strictly required by the transform system, the current exec() approach is acceptable but should be documented.

bionemo-recipes/models/mixtral/tests/common/fixtures.py (1)

62-63: Remove redundant import os.

os is already imported at line 33.

♻️ Proposed fix
 def use_te_debug():
     """Auto-use fixture to enable TransformerEngine debugging.
 
     This fixture automatically enables debug mode for TransformerEngine
     in all tests for better error messages.
     """
-    import os
-
     os.environ["NVTE_DEBUG"] = "1"
bionemo-recipes/models/mixtral/state.py (2)

161-161: Use logger instead of print for unexpected keys.

Consistency with the rest of the module which uses logger.debug, logger.warning, etc.

♻️ Proposed fix
-            print(f"Unexpected key: {name} not in target model but is in source model.")
+            logger.warning(f"Unexpected key: {name} not in target model but is in source model.")

265-265: Avoid global side effect from np.set_printoptions.

This modifies global numpy print settings on every transform call. If needed for debugging, consider using a context manager or removing it.

♻️ Proposed fix
-        np.set_printoptions(threshold=10)

If this is needed for debugging specific issues, use np.printoptions(threshold=10) as a context manager around the specific debug output instead.

@pstjohn
Collaborator Author

pstjohn commented Feb 11, 2026

@coderabbitai resolve

@coderabbitai
Contributor

coderabbitai bot commented Feb 11, 2026

✅ Actions performed

Comments resolved.

@pstjohn force-pushed the pstjohn/bio-216-moe-in-bionemo-recipes branch from 6574f52 to 58f9af5 on February 18, 2026 at 17:18
@copy-pr-bot

copy-pr-bot bot commented Feb 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@pstjohn force-pushed the pstjohn/bio-216-moe-in-bionemo-recipes branch 4 times, most recently from 1dba9f4 to 24d1e2c on February 18, 2026 at 20:44
@jomitchellnv
Collaborator

@coderabbitai run

@coderabbitai
Contributor

coderabbitai bot commented Feb 19, 2026

@jomitchellnv: I'll review the changes in this PR.


🧠 Learnings used
Learnt from: CR
Repo: NVIDIA/bionemo-framework PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-01-22T19:29:40.552Z
Learning: Applies to bionemo-recipes/models/**/*.py : Models in bionemo-recipes/models/ should be HuggingFace-compatible PreTrainedModel classes with TransformerEngine layers, distributable via Hugging Face Hub, and drop-in replacements for standard transformers compatible with AutoModel.from_pretrained()

Learnt from: CR
Repo: NVIDIA/bionemo-framework PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-01-22T19:29:40.552Z
Learning: Applies to bionemo-recipes/models/**/*.py : Models in bionemo-recipes/models/ must include: golden value tests proving TE model matches reference model, bidirectional conversion functions (convert_hf_to_te() and convert_te_to_hf()), export script (export.py) for Hugging Face Hub bundling, and open-source license
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai bot left a comment

Actionable comments posted: 8

🧹 Nitpick comments (6)
bionemo-recipes/models/mixtral/requirements.txt (1)

1-1: Move lm-eval to a separate test requirements file.

lm-eval is annotated # For testing but lives in the primary requirements.txt, inflating the runtime install surface for all downstream consumers of this package. The repository already follows this pattern at the root level with separate requirements-test.txt and requirements-dev.txt files. Consider creating requirements-test.txt in this directory to align with established project conventions and isolate test-only dependencies.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/requirements.txt` at line 1, Remove the
test-only dependency "lm-eval" from the mixtral requirements.txt and create a
new test requirements file named requirements-test.txt that contains "lm-eval  #
For testing"; update any local references in the mixtral package (e.g., any
setup or CI job that installed mixtral/requirements.txt) to install
mixtral/requirements-test.txt for test runs instead of the main requirements.txt
so runtime installs remain minimal.
bionemo-recipes/models/mixtral/tests/common/fixtures.py (1)

40-51: Redundant import os inside fixture.

os is already imported at module level (line 18). The local re-import on line 47 is unnecessary.

♻️ Proposed fix
 @pytest.fixture(autouse=True)
 def use_te_debug():
     """Auto-use fixture to enable TransformerEngine debugging.
 
     This fixture automatically enables debug mode for TransformerEngine
     in all tests for better error messages.
     """
-    import os
-
     os.environ["NVTE_DEBUG"] = "1"
     yield
     os.environ.pop("NVTE_DEBUG", None)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/tests/common/fixtures.py` around lines 40 -
51, The fixture use_te_debug unnecessarily re-imports os locally; remove the
local "import os" statement inside the use_te_debug function so it uses the
module-level os import, leaving the rest of the fixture
(os.environ["NVTE_DEBUG"] = "1", yield, os.environ.pop("NVTE_DEBUG", None))
unchanged.
bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py (4)

600-622: Redundant local import of AutoConfig.

Line 606 imports AutoConfig from transformers, but it's already imported at module level on line 31.

♻️ Proposed fix
     def test_convert_config(self):
         """Test that config can be converted between HF and TE formats."""
         upstream_id = self.get_upstream_model_id()
         revision = self.get_upstream_model_revision()
 
-        # Load HF config
-        from transformers import AutoConfig
-
         kwargs = {}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py` around
lines 600 - 622, The test_convert_config function has a redundant local import
of AutoConfig; remove the line "from transformers import AutoConfig" inside
test_convert_config and rely on the module-level AutoConfig import instead so
the test uses the already-imported AutoConfig symbol; update test_convert_config
to instantiate hf_config via AutoConfig.from_pretrained(upstream_id, **kwargs)
without a local import.

85-99: Docstring example references non-existent BioNeMoModelTester class.

The example on line 87 uses ESM2ModelTester(BioNeMoModelTester) but the actual class name is BaseModelTest. This mirrors the same issue in __init__.py's docstring.

♻️ Proposed fix
     Example:
         ```python
-        class ESM2ModelTester(BioNeMoModelTester):
+        class ESM2ModelTester(BaseModelTest):
             def get_model_class(self):
                 return NVEsmForMaskedLM
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py` around
lines 85 - 99, Update the docstring example that currently shows "class
ESM2ModelTester(BioNeMoModelTester):" to use the correct base class name
"BaseModelTest" (replace BioNeMoModelTester with BaseModelTest) so the sample
subclass matches the actual abstract base class; adjust any mirrored occurrences
in module docstrings (e.g., __init__.py) to use BaseModelTest as well to keep
examples consistent with the implemented class.

630-667: get_converted_te_model_checkpoint re-runs full conversion on every call.

get_converted_te_model calls get_converted_te_model_checkpoint() each time, which downloads the HF model, converts it, and saves to disk — even though the checkpoint path is deterministic within a class's tmp dir. Tests like test_golden_values and test_golden_values_thd both call get_converted_te_model, causing redundant work.

Consider caching the checkpoint path (e.g., checking if the directory already exists before re-converting), or using functools.lru_cache / a class-level sentinel.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py` around
lines 630 - 667, The test helper currently reconverts and re-saves the TE model
every time get_converted_te_model_checkpoint() is called; update
get_converted_te_model_checkpoint to check for an existing checkpoint under the
deterministic path (self._tmp_dir / "converted_te_model") and return it
immediately if present, otherwise perform the convert/save flow; alternatively,
add a simple cache (e.g., a class or instance attribute sentinel like
self._converted_te_checkpoint) or wrap the checkpoint getter with
functools.lru_cache to avoid re-running get_reference_model(),
get_hf_to_te_converter(), convert_fn, and model_te.save_pretrained on subsequent
calls invoked by get_converted_te_model() and tests.
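A minimal sketch of the on-disk caching approach suggested above. The class and method bodies are illustrative stand-ins (the real flow downloads the HF model, converts it, and calls save_pretrained); only the existence-check pattern is the point.

```python
import tempfile
from pathlib import Path


class ModelTestHelper:
    """Hypothetical slice of the test base class; names are illustrative."""

    def __init__(self, tmp_dir):
        self._tmp_dir = Path(tmp_dir)
        self.convert_calls = 0  # instrumentation for this sketch only

    def _convert_and_save(self, path):
        # Stand-in for the real download / convert / save_pretrained flow.
        self.convert_calls += 1
        path.mkdir(parents=True, exist_ok=True)
        (path / "config.json").write_text("{}")

    def get_converted_te_model_checkpoint(self):
        checkpoint = self._tmp_dir / "converted_te_model"
        # The checkpoint path is deterministic within the class tmp dir, so an
        # on-disk existence check is enough to skip reconversion on repeat calls.
        if not checkpoint.exists():
            self._convert_and_save(checkpoint)
        return checkpoint
```

Repeated calls within the same tmp dir then reuse the saved checkpoint instead of reconverting.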

259-294: Add GroupedLinear initialization and FP8 quantization verification to the verify_model_parameters_initialized_correctly function.

The Mixtral MoE block uses transformer_engine.pytorch.GroupedLinear for experts_gate_up and experts_down modules, but these are not validated in the initialization checks. Since GroupedLinear modules are initialized with the same init_method as regular Linear modules and support FP8 quantization, they should be included in the verification logic to ensure MoE expert weights are correctly initialized and quantized.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py` around
lines 259 - 294, The tests currently only check torch.nn.Embedding and
transformer_engine.pytorch.Linear in
verify_model_parameters_initialized_correctly; add support for
transformer_engine.pytorch.GroupedLinear (used by experts_gate_up and
experts_down) by treating it like transformer_engine.pytorch.Linear: verify
weight mean is ~0, std equals config.initializer_range with the same tolerances,
check bias zeros if present, and perform the FP8 quantization assertion (skip if
name is in model._tied_weights_keys or matches model._do_not_quantize patterns).
Update the logic that asserts isinstance(module.weight, QuantizedTensor) so it
also accepts GroupedLinear modules.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@bionemo-recipes/models/esm2/src/esm/collator.py`:
- Around line 286-291: The ValueError message in the TokenPackingDataset code
path inside esm.collator.py concatenates adjacent string literals so "to" and
"ensure" run together; update the error string in the raise ValueError in the
block referencing self.max_tokens_per_batch (inside the token packing / collator
logic) to include a space between "to" and "ensure" (either by adding a trailing
space in the previous literal or combining into one f-string) so the message
reads "...dataset to ensure all samples fit within max_tokens_per_batch."
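The bug described above is easy to reproduce in isolation: Python concatenates adjacent string literals with no separator, so a missing trailing space silently glues two words together. The exact message wording here is illustrative, not the collator's actual text.

```python
# Adjacent string literals concatenate with no separator, so a missing
# trailing space runs "to" and "ensure" together.
buggy = (
    "Sample is longer than max_tokens_per_batch; pre-truncate the dataset to"
    "ensure all samples fit within max_tokens_per_batch."
)

# Fixed: trailing space in the first literal.
fixed = (
    "Sample is longer than max_tokens_per_batch; pre-truncate the dataset to "
    "ensure all samples fit within max_tokens_per_batch."
)
```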

In `@bionemo-recipes/models/esm2/tests/common/test_modeling_common.py`:
- Around line 349-358: get_reference_model_no_weights currently passes explicit
dtype and revision plus **kwargs into AutoConfig.from_pretrained which can raise
TypeError if callers supply those keys; fix by copying kwargs and removing any
'dtype' and 'revision' keys (e.g., kwargs_copy = dict(kwargs);
kwargs_copy.pop("dtype", None); kwargs_copy.pop("revision", None)) then call
AutoConfig.from_pretrained(..., dtype=torch.float32,
revision=self.get_upstream_model_revision(), **kwargs_copy); also update the
ESM2 test subclass override of get_reference_model_no_weights to accept **kwargs
and forward them to super() (e.g., def get_reference_model_no_weights(self,
**kwargs): return super().get_reference_model_no_weights(**kwargs)) so
signatures match.
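The duplicate-keyword failure mode and its fix can be sketched without transformers; from_pretrained below is a stand-in with an invented signature, not the real AutoConfig API:

```python
def from_pretrained(model_id, *, dtype=None, revision=None, **extra):
    """Stand-in for AutoConfig.from_pretrained; the real signature differs."""
    return {"model_id": model_id, "dtype": dtype, "revision": revision, **extra}


def load_config_buggy(model_id, **kwargs):
    # If the caller also passes dtype= or revision= via kwargs, the duplicate
    # keyword argument raises TypeError.
    return from_pretrained(model_id, dtype="float32", revision="main", **kwargs)


def load_config_fixed(model_id, **kwargs):
    kwargs = dict(kwargs)          # don't mutate the caller's dict
    kwargs.pop("dtype", None)      # these keys would collide with the
    kwargs.pop("revision", None)   # explicit arguments below
    return from_pretrained(model_id, dtype="float32", revision="main", **kwargs)
```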

In `@bionemo-recipes/models/llama3/collator.py`:
- Around line 285-327: The validation currently uses raw sample_length; update
TokenPackingDataset logic to compute padded_len =
self._padded_len(sample_length) (taking pad_sequences_to_be_divisible_by into
account) and validate that padded_len <= self.max_tokens_per_batch before
proceeding, raising the same ValueError if it exceeds; then use padded_len (not
raw sample_length) for current_length accumulation, comparisons (==, >) and when
computing tokens_in_batch/tokens_available and when calling
_split_sample_by_num_tokens to ensure split logic and non-split branches always
respect max_tokens_per_batch and reuse the padded length consistently.
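The padded-length validation above can be sketched as follows; function names and the default multiple are illustrative, not the collator's actual API:

```python
import math


def padded_len(sample_length, multiple):
    """Round up to the nearest multiple (pad_sequences_to_be_divisible_by)."""
    return math.ceil(sample_length / multiple) * multiple


def sample_fits(sample_length, current_length, max_tokens_per_batch, multiple=8):
    padded = padded_len(sample_length, multiple)
    # Validate the padded length up front, not the raw sample length.
    if padded > max_tokens_per_batch:
        raise ValueError(
            f"Padded sample length {padded} exceeds "
            f"max_tokens_per_batch={max_tokens_per_batch}."
        )
    # Accumulate and compare the padded length against remaining capacity.
    return current_length + padded <= max_tokens_per_batch
```

A raw length of 10 pads to 16 with multiple=8, so a batch near capacity can overflow if the raw length is used for the comparison.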

In `@bionemo-recipes/models/mixtral/.ruff.toml`:
- Line 1: The child ruff config (the extend = "../.ruff.toml" in
mixtral/.ruff.toml) points at a missing parent config; create the parent
.ruff.toml in the models directory that the child extends and populate it with
the shared linting rules/settings (e.g., select/ignore codes, line-length,
target-python, etc.) so model subdirectories (mixtral and llama3) inherit
consistent rules; ensure the filename matches ".ruff.toml" and the child extend
path remains "../.ruff.toml".

In `@bionemo-recipes/models/mixtral/modeling_mixtral_te.py`:
- Around line 318-342: The code assumes input_ids exists when packing and when
computing cache lengths; change uses of input_ids to prefer inputs_embeds or
hidden_states so None input_ids won't crash and lengths are correct: when
computing padded_seq_len replace input_ids.size(1) with hidden_states.size(1)
(or inputs_embeds.size(1) if you prefer explicit check), and when computing
lengths for past_key_values use attention_mask.shape == hidden_states.shape (or
compare to inputs_embeds.shape when input_ids is None) and use input batch size
from hidden_states.shape[0] instead of input_ids.shape[0]; update the branches
around _unpad_input, padded_seq_len, and the lengths calculation that calls
past_key_values.pre_step to reference hidden_states/inputs_embeds accordingly.

In `@bionemo-recipes/models/mixtral/state.py`:
- Around line 458-481: The state_transform function's docstring uses NumPy-style
section separators (e.g., "Returns: -------", "Examples: --------") instead of
Google-style; update the docstring for state_transform to Google style by
removing the underline separators, using plain "Returns:" and "Examples:"
headings followed by an indented descriptive paragraph or block, ensure
parameter descriptions remain under "Args:" in Google format and example code
stays in an indented literal block under "Examples:", and keep the existing
content the same while conforming to pydocstyle Google conventions.
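For reference, a Google-style docstring uses plain "Args:", "Returns:", and "Examples:" headings with indented bodies and no NumPy-style underlines. The function body here is an illustrative placeholder, not the real state_transform:

```python
def state_transform(state_dict, key):
    """Look up and return one checkpoint entry (illustrative body).

    Args:
        state_dict: Mapping of parameter names to tensors.
        key: Name of the entry to transform.

    Returns:
        The value stored under ``key``.

    Examples:
        >>> state_transform({"w": 1.0}, "w")
        1.0
    """
    return state_dict[key]
```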

In `@bionemo-recipes/models/mixtral/tests/common/__init__.py`:
- Around line 16-37: The module docstring references non-existent classes
BioNeMoModelTester and BioNeMoModelTest; update the docstring and example to use
the actual exported class BaseModelTest and keep TestTolerances, so change
occurrences of BioNeMoModelTester/BioNeMoModelTest in the descriptive list and
the example import/usage to BaseModelTest and TestTolerances (e.g., "from
tests.common import BaseModelTest, TestTolerances" and "class
ESM2ModelTester(BaseModelTest):") to ensure the example matches the real API.

In `@bionemo-recipes/recipes/esm2_peft_te/collator.py`:
- Around line 285-327: The initial sample length check uses raw sample_length
rather than its padded size, so replace the validation to compute padded =
self._padded_len(sample_length) (respecting pad_sequences_to_be_divisible_by)
and compare padded against self.max_tokens_per_batch; raise the ValueError if
padded > max_tokens_per_batch. Also ensure code paths that start a new batch or
set current_length (the branches using self._padded_len(sample_length), the
split_samples branch, and where you assign current_length =
self._padded_len(len(samples[0]["input_ids"]))) consistently use the same padded
value for this incoming sample to avoid creating oversized batches (adjust
tokens_available logic to use tokens capacity versus padded size).

---

Duplicate comments:
In `@bionemo-recipes/models/mixtral/modeling_mixtral_te.py`:
- Around line 453-458: Remove the module-level mutation of
torch._dynamo.config.capture_scalar_outputs and instead toggle it only around
the actual compile sites: when calling torch.compile for the functions like
_pad_input and _unpad_input, save the current
torch._dynamo.config.capture_scalar_outputs value, set it to True, call
torch.compile(...), then restore the original value; ensure no import-time side
effects remain and all references to torch.compile in this module use this
save/set/restore pattern.
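The save/set/restore pattern reads naturally as a context manager. The sketch below substitutes a SimpleNamespace for torch._dynamo.config so it runs without torch; only the scoping pattern is the point:

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for torch._dynamo.config; only the attribute pattern matters here.
dynamo_config = SimpleNamespace(capture_scalar_outputs=False)


@contextmanager
def scoped_flag(config, name, value):
    """Temporarily set ``config.<name>``, restoring the original on exit."""
    original = getattr(config, name)
    setattr(config, name, value)
    try:
        yield
    finally:
        setattr(config, name, original)


def compile_helpers():
    # The flag is flipped only around the compile call, never at import time.
    with scoped_flag(dynamo_config, "capture_scalar_outputs", True):
        # torch.compile(_pad_input) / torch.compile(_unpad_input) would go here.
        return dynamo_config.capture_scalar_outputs
```

After compile_helpers() returns, the flag is back to its original value, so importing the module has no lasting side effects.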

In `@bionemo-recipes/models/mixtral/requirements.txt`:
- Around line 1-5: The requirements.txt entries (lm-eval, torch,
torchao!=0.14.0, transformer_engine[pytorch], transformers) are unpinned and
harm reproducibility; update each package line to a specific, tested version or
a strict version range (e.g., >= and < upper bound) instead of bare names or
only an exclusion—replace the current entries with pinned versions or ranges
based on the project’s tested environment (for example pin torch and
transformers to the exact tested versions and change torchao!=0.14.0 to a
concrete range like torchao>=X.Y.Z,<0.14.0 or >=0.14.1,<X.Y.Z) so CI/builds
reproduce reliably and avoid allowing incompatible older releases.

In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py`:
- Around line 349-358: The AutoConfig.from_pretrained call in
get_reference_model_no_weights is passing a non-existent dtype kwarg which just
sets an arbitrary config attribute; change the parameter name to torch_dtype so
the config correctly reflects the intended precision. Locate
get_reference_model_no_weights and replace dtype=torch.float32 with
torch_dtype=torch.float32 in the AutoConfig.from_pretrained(...) invocation to
match other usages and ensure consistency with model precision handling.
- Around line 338-347: The kwargs passed to upstream_class.from_pretrained use
the wrong key "dtype" instead of HuggingFace's expected "torch_dtype", so models
are loaded in default precision; update the kwargs construction before the
upstream_class.from_pretrained(upstream_id, **kwargs) call to set
kwargs["torch_dtype"] = dtype (and remove or stop setting kwargs["dtype"]) so
the model is instantiated with the requested precision; ensure this change
affects the same variables shown (kwargs, dtype, upstream_class.from_pretrained,
upstream_id).

---

Nitpick comments:
In `@bionemo-recipes/models/mixtral/requirements.txt`:
- Line 1: Remove the test-only dependency "lm-eval" from the mixtral
requirements.txt and create a new test requirements file named
requirements-test.txt that contains "lm-eval  # For testing"; update any local
references in the mixtral package (e.g., any setup or CI job that installed
mixtral/requirements.txt) to install mixtral/requirements-test.txt for test runs
instead of the main requirements.txt so runtime installs remain minimal.

In `@bionemo-recipes/models/mixtral/tests/common/fixtures.py`:
- Around line 40-51: The fixture use_te_debug unnecessarily re-imports os
locally; remove the local "import os" statement inside the use_te_debug function
so it uses the module-level os import, leaving the rest of the fixture
(os.environ["NVTE_DEBUG"] = "1", yield, os.environ.pop("NVTE_DEBUG", None))
unchanged.

In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py`:
- Around line 600-622: The test_convert_config function has a redundant local
import of AutoConfig; remove the line "from transformers import AutoConfig"
inside test_convert_config and rely on the module-level AutoConfig import
instead so the test uses the already-imported AutoConfig symbol; update
test_convert_config to instantiate hf_config via
AutoConfig.from_pretrained(upstream_id, **kwargs) without a local import.
- Around line 85-99: Update the docstring example that currently shows "class
ESM2ModelTester(BioNeMoModelTester):" to use the correct base class name
"BaseModelTest" (replace BioNeMoModelTester with BaseModelTest) so the sample
subclass matches the actual abstract base class; adjust any mirrored occurrences
in module docstrings (e.g., __init__.py) to use BaseModelTest as well to keep
examples consistent with the implemented class.
- Around line 630-667: The test helper currently reconverts and re-saves the TE
model every time get_converted_te_model_checkpoint() is called; update
get_converted_te_model_checkpoint to check for an existing checkpoint under the
deterministic path (self._tmp_dir / "converted_te_model") and return it
immediately if present, otherwise perform the convert/save flow; alternatively,
add a simple cache (e.g., a class or instance attribute sentinel like
self._converted_te_checkpoint) or wrap the checkpoint getter with
functools.lru_cache to avoid re-running get_reference_model(),
get_hf_to_te_converter(), convert_fn, and model_te.save_pretrained on subsequent
calls invoked by get_converted_te_model() and tests.
- Around line 259-294: The tests currently only check torch.nn.Embedding and
transformer_engine.pytorch.Linear in
verify_model_parameters_initialized_correctly; add support for
transformer_engine.pytorch.GroupedLinear (used by experts_gate_up and
experts_down) by treating it like transformer_engine.pytorch.Linear: verify
weight mean is ~0, std equals config.initializer_range with the same tolerances,
check bias zeros if present, and perform the FP8 quantization assertion (skip if
name is in model._tied_weights_keys or matches model._do_not_quantize patterns).
Update the logic that asserts isinstance(module.weight, QuantizedTensor) so it
also accepts GroupedLinear modules.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

rebase on moe fix pr

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/bio-216-moe-in-bionemo-recipes branch 2 times, most recently from b8c999e to bd5543c on February 21, 2026 00:39
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/bio-216-moe-in-bionemo-recipes branch from bd5543c to eda3894 on February 21, 2026 00:40
@pstjohn pstjohn enabled auto-merge February 23, 2026 20:47
@pstjohn pstjohn added this pull request to the merge queue Feb 23, 2026
Merged via the queue into NVIDIA:main with commit b5a98d2 Feb 23, 2026
17 checks passed
@pstjohn pstjohn deleted the pstjohn/bio-216-moe-in-bionemo-recipes branch February 23, 2026 21:14