Conversation
Important: Review skipped. Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI.
📝 Walkthrough

This pull request introduces a Mixtral Mixture-of-Experts model implementation with Transformer Engine support, refactors mutable default arguments across existing state modules, enhances collator batch packing logic with validation and splitting behavior, and establishes a comprehensive test framework for model validation and conversion testing.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant DataCollator as DataCollator<br/>WithFlattening
    participant TP as TokenPacking<br/>Dataset
    participant CP as Context<br/>Parallel

    Client->>TP: __iter__()
    loop For each sample
        TP->>TP: Validate length vs<br/>max_tokens_per_batch
        alt Split disabled
            TP->>TP: Accumulate or yield<br/>existing batch
            TP->>DataCollator: samples
            DataCollator->>DataCollator: Flatten + pack
            DataCollator-->>TP: packed batch
        else Split enabled
            TP->>TP: Compute tokens_available
            alt Room available
                TP->>TP: Split sample by<br/>tokens_available
                TP->>DataCollator: first_part + batch
                DataCollator-->>TP: packed batch
            else No room
                TP->>DataCollator: existing batch
                DataCollator-->>TP: packed batch
                TP->>TP: Start new batch<br/>with current sample
            end
        end
        TP->>CP: packed batch
        CP->>CP: Shard by CP rank
        CP->>CP: Compute cu_seq_lens
        CP-->>Client: CP-formatted batch
    end
```
```mermaid
sequenceDiagram
    participant HFModel as HF Model<br/>(Mixtral)
    participant Convert as convert_mixtral<br/>_hf_to_te
    participant StateDictTransform as apply_transforms<br/>+ TransformFns
    participant TEModel as TE Model<br/>(NVMixtral)

    HFModel->>Convert: load state_dict
    Convert->>Convert: Build NVMixtralConfig
    Convert->>TEModel: Allocate on meta device
    Convert->>StateDictTransform: Apply transform chain
    loop For each transform
        StateDictTransform->>StateDictTransform: Match source→target<br/>keys via wildcards
        StateDictTransform->>StateDictTransform: Split/merge expert<br/>weights as needed
        StateDictTransform->>StateDictTransform: QKV projection<br/>handling
    end
    StateDictTransform->>TEModel: Load transformed state
    TEModel-->>Convert: Converted model
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 17
🤖 Fix all issues with AI agents
In `@bionemo-recipes/models/mixtral/collator.py`:
- Around line 98-180: The __call__ method ignores the return_tensors parameter;
validate that return_tensors is either None or "pt" at the start of __call__ and
raise NotImplementedError for other values, and forward return_tensors into
downstream calls so tensor backend is consistent: call self.collator(features,
return_tensors=return_tensors) instead of self.collator(features) and call
_pt_flatten_collate(features, return_position_ids=self.return_position_ids,
return_tensors=return_tensors) (or the appropriate flatten helper that accepts
return_tensors); keep the rest of the logic (masked_input_ids/labels,
separator_id handling, and padding via
_pad_batch_to_multiple_of/_pad_sequences_to_be_divisible_by) unchanged.
- Around line 305-351: In the __call__ method of the collator, the computation
batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length *
round(max_length / 64) is incorrect; replace it with a ceil-to-multiple-of-64
calculation (e.g. compute padded_max = ((max_length + 63) // 64) * 64) and
assign padded_max to both batch_shard["max_length_k"] and
batch_shard["max_length_q"] so the result is the next multiple of 64 without
floating point rounding or inflation; a sketch of this fix follows the collator.py items below.
- Around line 238-277: The iterator currently calls
_split_sample_by_num_tokens(sample, tokens_available) when tokens_available can
be 0 (causing an error) and can yield an empty batch when a single sample
exceeds max_tokens_per_batch with split_samples=False; fix __iter__ to first
check tokens_available <= 0 and in that case yield the current samples and reset
samples before handling the incoming sample, and then reprocess the incoming
sample: if split_samples is True, split the sample into chunks of size up to
max_tokens_per_batch (call _split_sample_by_num_tokens in a loop using chunk
size = max_tokens_per_batch) and yield/fill batches accordingly; if
split_samples is False, ensure you never yield an empty batch by appending the
oversized sample as its own batch (or yield it immediately) instead of yielding
samples=[]; update references in __iter__ to use tokens_available guard,
_split_sample_by_num_tokens, split_samples, max_tokens_per_batch, samples and
current_length.
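To make the two collator fixes above concrete, here is a minimal sketch in the spirit of these suggestions. It assumes samples are dicts with an `input_ids` list and that `_split_sample_by_num_tokens` returns a `(head, tail)` pair; both that helper and `_ceil_to_multiple` are illustrative stand-ins, not the repository's actual code.

```python
def _ceil_to_multiple(n: int, multiple: int = 64) -> int:
    """Next multiple of `multiple` that is >= n, with no float rounding."""
    return ((n + multiple - 1) // multiple) * multiple


# e.g. batch_shard["max_length_k"] = batch_shard["max_length_q"] = _ceil_to_multiple(max_length)


def _split_sample_by_num_tokens(sample, num_tokens):
    """Hypothetical splitter: head takes num_tokens tokens, tail keeps the rest."""
    return (
        {k: v[:num_tokens] for k, v in sample.items()},
        {k: v[num_tokens:] for k, v in sample.items()},
    )


class TokenPackingDataset:
    def __init__(self, dataset, max_tokens_per_batch, split_samples=False):
        self.dataset = dataset
        self.max_tokens_per_batch = max_tokens_per_batch
        self.split_samples = split_samples

    def __iter__(self):
        samples, current_length = [], 0
        for sample in self.dataset:
            sample_length = len(sample["input_ids"])
            tokens_available = self.max_tokens_per_batch - current_length

            if tokens_available <= 0:
                # Flush a full batch before touching the incoming sample, so the
                # split helper is never called with a zero-token budget.
                yield samples
                samples, current_length = [], 0
                tokens_available = self.max_tokens_per_batch

            if sample_length <= tokens_available:
                samples.append(sample)
                current_length += sample_length
            elif self.split_samples:
                head, tail = _split_sample_by_num_tokens(sample, tokens_available)
                yield [*samples, head]
                # Re-split oversized remainders into whole-batch chunks.
                while len(tail["input_ids"]) > self.max_tokens_per_batch:
                    chunk, tail = _split_sample_by_num_tokens(tail, self.max_tokens_per_batch)
                    yield [chunk]
                samples, current_length = [tail], len(tail["input_ids"])
            elif sample_length <= self.max_tokens_per_batch:
                # No room left in this batch: flush it and start a new one.
                yield samples
                samples, current_length = [sample], sample_length
            else:
                # Splitting disabled and the sample alone exceeds the budget:
                # emit it as its own batch rather than yielding an empty one.
                if samples:
                    yield samples
                yield [sample]
                samples, current_length = [], 0
        if samples:
            yield samples
```

Flushing before splitting keeps the zero-budget case and the oversized-sample case from ever producing an empty or invalid batch.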
In `@bionemo-recipes/models/mixtral/convert.py`:
- Around line 136-137: The current filtering uses
MixtralConfig.__init__.__code__.co_varnames which is fragile because co_varnames
contains locals too; update the logic that builds valid_keys (used to create
filtered_config from te_config_dict) to derive parameter names from
MixtralConfig.__init__ via inspect.signature (e.g.,
inspect.signature(MixtralConfig.__init__).parameters) and then filter
te_config_dict by those parameter names so only actual constructor args are
preserved.
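A small sketch of the `inspect.signature` approach suggested above; the contents of `te_config_dict` here are made up for illustration.

```python
import inspect

from transformers import MixtralConfig

# Stand-in for the TE-side config dict that convert.py builds.
te_config_dict = {"hidden_size": 1024, "num_local_experts": 8, "te_only_option": True}

# signature().parameters lists only real constructor parameters, unlike
# __code__.co_varnames, which also includes the function's local variables.
valid_keys = set(inspect.signature(MixtralConfig.__init__).parameters) - {"self"}
filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys}
hf_config = MixtralConfig(**filtered_config)  # te_only_option is dropped
```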
In `@bionemo-recipes/models/mixtral/export.py`:
- Line 53: The copy call uses a relative path so it breaks when the working
directory isn't the file's folder; change the source to be anchored to this
module by resolving Path(__file__).parent / "modeling_mixtral_te.py" and pass
that resolved path as the first argument to shutil.copy when copying to
export_path (the existing export_path / "modeling_mixtral_te.py" destination can
remain); update the code that calls shutil.copy accordingly (referencing
shutil.copy and the filename "modeling_mixtral_te.py"); a combined sketch follows the export.py items below.
- Around line 22-37: The export_hf_checkpoint function is creating a randomly
initialized model by calling AutoModelForCausalLM.from_config; replace that with
loading the actual pretrained weights by calling
AutoModelForCausalLM.from_pretrained using the same tag (keep or remove the
separate AutoConfig.from_pretrained as needed), so model_hf holds the real
checkpoint weights before export; update any related uses of model_hf and ensure
tokenizer/config are loaded from_pretrained as well if required.
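A sketch combining the two export.py fixes above (loading real weights via `from_pretrained`, and anchoring the copy to the module's own directory). The function signature is assumed for illustration.

```python
import shutil
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer


def export_hf_checkpoint(tag: str, export_path: Path) -> None:
    # from_pretrained loads the actual checkpoint weights; from_config would
    # only build a randomly initialized model from the architecture config.
    model_hf = AutoModelForCausalLM.from_pretrained(tag)
    tokenizer = AutoTokenizer.from_pretrained(tag)

    model_hf.save_pretrained(export_path)
    tokenizer.save_pretrained(export_path)

    # Anchor the source file to this module so the copy succeeds regardless
    # of the current working directory.
    src = Path(__file__).parent / "modeling_mixtral_te.py"
    shutil.copy(src, export_path / "modeling_mixtral_te.py")
```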
In `@bionemo-recipes/models/mixtral/modeling_mixtral_te.py`:
- Line 453: The module sets torch._dynamo.config.capture_scalar_outputs = True
at import time, which mutates global TorchDynamo state and can affect other
code; change this to a local, temporary setting or document it: in the functions
that require this behavior (e.g., the compile/optimization entry points in
modeling_mixtral_te), save the current
torch._dynamo.config.capture_scalar_outputs value, set it to True only for the
scope where you call torch.compile/torch._dynamo operations, then restore the
original value in a finally block (or use a small context manager) so the global
config isn't mutated at module load; alternatively, add a clear comment in the
module/top-level docstring explaining the global requirement if a global change
is unavoidable.
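One way to realize the suggested scoping is a small context manager; `_pad_input` below is just a stand-in for whatever function the module actually compiles.

```python
import contextlib

import torch


@contextlib.contextmanager
def capture_scalar_outputs():
    """Temporarily enable capture_scalar_outputs instead of mutating it at import time."""
    previous = torch._dynamo.config.capture_scalar_outputs
    torch._dynamo.config.capture_scalar_outputs = True
    try:
        yield
    finally:
        torch._dynamo.config.capture_scalar_outputs = previous


def _pad_input(x: torch.Tensor) -> torch.Tensor:  # stand-in for the real helper
    return x


# At the compile site, rather than at module scope:
with capture_scalar_outputs():
    pad_input = torch.compile(_pad_input)
```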
In `@bionemo-recipes/models/mixtral/requirements.txt`:
- Around line 1-5: The requirements file currently leaves torch, transformers,
transformer_engine[pytorch], torchao, and lm-eval unpinned which causes
reproducibility and compatibility issues; update the requirements (the entries
referencing torch, transformers, transformer_engine[pytorch], torchao!=0.14.0,
and lm-eval) to pin explicit compatible versions (or reference a constraints
file) that match the project's tested torch version matrix and known working
transformer_engine and torchao releases, ensuring transformer_engine is the
pytorch build and avoiding the excluded torchao 0.14.0 conflict; provide exact
version specs consistent with the root project strategy so installs are
deterministic.
In `@bionemo-recipes/models/mixtral/state.py`:
- Line 321: Fix the typo in the logger.error message: change "Enountered" to
"Encountered" in the error string used in the code that logs IndexError during
transform (the logger.error call that includes source_matches and
target_matches); update the f-string to read logger.error(f"Encountered
IndexError during transform.\n{source_matches=}\n{target_matches=}") so the
message is spelled correctly while preserving the variables and formatting.
- Around line 70-72: The dataclass/module-level defaults use mutable lists
(transforms, state_dict_ignored_entries); change their default values from [] to
None and initialize them to empty lists at construction (e.g., in __post_init__
of the dataclass or where the object is created) to avoid shared mutable state;
keep cast_dtype as Optional[torch.dtype] = None unchanged and ensure any code
referencing transforms or state_dict_ignored_entries handles the None-to-list
initialization (refer to the symbols transforms, state_dict_ignored_entries,
TransformCTX, cast_dtype).
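A sketch of the None-default pattern this item describes, using the field names it mentions; note that `dataclasses.field(default_factory=list)` is the more common idiom and would work equally well.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TransformCTX:
    transforms: Optional[list] = None
    state_dict_ignored_entries: Optional[list] = None
    cast_dtype: Optional[torch.dtype] = None

    def __post_init__(self):
        # Each instance gets its own fresh list; a `= []` default would be a
        # single list object shared by every instance.
        if self.transforms is None:
            self.transforms = []
        if self.state_dict_ignored_entries is None:
            self.state_dict_ignored_entries = []
```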
In `@bionemo-recipes/models/mixtral/tests/common/__init__.py`:
- Around line 1-29: Remove the duplicated license header blocks and leave a
single, consistent header: keep one SPDX copyright line with the correct year
(use 2026), set SPDX-License-Identifier to "Apache-2.0", and retain the standard
Apache-2.0 license text that follows; delete the other entire license block so
only one header remains at the top of the file.
In `@bionemo-recipes/models/mixtral/tests/common/fixtures.py`:
- Around line 141-143: Replace the direct environment deletions with safe
removals using os.environ.pop for keys "NVTE_FUSED_ATTN" and "NVTE_FLASH_ATTN"
to avoid KeyError if they were never set; keep the existing update to
_attention_backends["backend_selection_requires_update"] = True unchanged so the
backend refresh still occurs. Use os.environ.pop("NVTE_FUSED_ATTN", None) and
os.environ.pop("NVTE_FLASH_ATTN", None) in the teardown (where the current del
os.environ[...] calls are) to safely remove the variables.
- Line 66: Replace the unsafe deletion del os.environ["NVTE_DEBUG"] in the
fixtures cleanup with a safe pop call so removing the NVTE_DEBUG env var cannot
raise KeyError; locate the occurrence of del os.environ["NVTE_DEBUG"] in
bionemo-recipes/models/mixtral/tests/common/fixtures.py (the teardown/cleanup
for the test fixture) and use os.environ.pop("NVTE_DEBUG", None) to silently
handle the case where the variable is already absent.
- Around line 1-30: The file contains two SPDX/Apache-2.0 license header blocks
with conflicting years; remove the second duplicate header (the block that
begins with "SPDX-FileCopyrightText: Copyright (c) 2025" and its following
Apache-2.0 license text) so only the first header (the 2026 SPDX header)
remains; ensure no extra blank lines are left at the top after removal.
In `@bionemo-recipes/models/mixtral/tests/common/README.md`:
- Around line 7-13: The fenced code block in README.md under tests/common is
missing a language marker (MD040); update the snippet in that file (the
three-line tree block within the fenced code) to include a language tag such as
"text" or "plaintext" after the opening backticks so linting passes (e.g.,
change ``` to ```text).
In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py`:
- Around line 316-354: The HF API calls use the wrong kwarg name: replace the
"dtype" param with "torch_dtype" in both upstream_class.from_pretrained(...)
inside get_reference_model and AutoConfig.from_pretrained(...) inside
get_reference_model_no_weights so the requested precision (e.g., torch.bfloat16
or torch.float32) is honored; keep the existing attn_implementation and revision
logic and leave model.to("cuda") as-is.
- Around line 34-36: The HAS_DATA_CENTER_GPU probe calls
torch.cuda.get_device_name(0) unguarded which raises in CPU-only environments;
update the definition of HAS_DATA_CENTER_GPU to first check
torch.cuda.is_available() (and optionally wrap the probe in a try/except for
RuntimeError/AssertionError) and only call torch.cuda.get_device_name(0) when
CUDA is available, otherwise set HAS_DATA_CENTER_GPU to False; modify the
variable in the test module (the HAS_DATA_CENTER_GPU assignment) to implement
this guard so imports do not fail on CPU-only systems.
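A guarded version of the probe described in the last item; the device-name list here is illustrative, not the repository's actual set.

```python
import torch

try:
    HAS_DATA_CENTER_GPU = torch.cuda.is_available() and any(
        name in torch.cuda.get_device_name(0) for name in ("A100", "H100", "B200")
    )
except (RuntimeError, AssertionError):
    # Defensive: some broken CUDA installs raise even after is_available().
    HAS_DATA_CENTER_GPU = False
```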
🧹 Nitpick comments (6)
bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py (1)

104-105: Remove redundant pad_token check. This check duplicates lines 69-70 in `get_tokenizer()`. Since `get_test_input_data` calls `self.get_tokenizer()` at line 97, the pad_token is already set.

♻️ Proposed fix

```diff
 data_collator = DataCollatorForLanguageModeling(
     tokenizer=tokenizer,
     pad_to_multiple_of=pad_to_multiple_of,
     mlm=False,
 )
-if tokenizer.pad_token is None:
-    tokenizer.pad_token = tokenizer.eos_token
 if format == "thd":
```

bionemo-recipes/models/mixtral/modeling_mixtral_te.py (1)
380-401: Consider standard import for GenerationMixin. The dynamic import `__import__("transformers").GenerationMixin` works but is unconventional. A standard import would be clearer.

♻️ Proposed fix

```diff
+from transformers import GenerationMixin
+
 ...
-class NVMixtralForCausalLM(NVMixtralPreTrainedModel, __import__("transformers").GenerationMixin):
+class NVMixtralForCausalLM(NVMixtralPreTrainedModel, GenerationMixin):
```

bionemo-recipes/models/mixtral/convert.py (1)
60-71: Consider an alternative to `exec()` for dynamic function creation. Using `exec()` works here but is harder to debug and understand. A simpler approach uses `*args`:

♻️ Proposed alternative

```diff
 def _make_merge_experts_fn(num_experts: int):
-    """Create a merge function with the correct number of named parameters.
-
-    The state.py transform system maps function parameter names to source keys, so we need a function
-    with exactly `num_experts` named parameters (weight0, weight1, ...).
-    """
-    param_names = [f"weight{i}" for i in range(num_experts)]
-    code = f"def merge_experts({', '.join(param_names)}):\n    return torch.stack([{', '.join(param_names)}])"
-    local_ns = {"torch": torch}
-    exec(code, local_ns)
-    return local_ns["merge_experts"]
+    """Create a merge function that stacks expert weights."""
+
+    def merge_experts(*weights):
+        if len(weights) != num_experts:
+            raise ValueError(f"Expected {num_experts} weights, got {len(weights)}")
+        return torch.stack(weights)
+
+    return merge_experts
```

Note: This assumes the state transform system supports `*args`. If named parameters are strictly required by the transform system, the current `exec()` approach is acceptable but should be documented.

bionemo-recipes/models/mixtral/tests/common/fixtures.py (1)
62-63: Remove redundant `import os`. `os` is already imported at line 33.

♻️ Proposed fix

```diff
 def use_te_debug():
     """Auto-use fixture to enable TransformerEngine debugging.

     This fixture automatically enables debug mode for TransformerEngine in all tests for better error messages.
     """
-    import os
-
     os.environ["NVTE_DEBUG"] = "1"
```

bionemo-recipes/models/mixtral/state.py (2)
161-161: Use logger instead of print for unexpected keys. Consistency with the rest of the module, which uses `logger.debug`, `logger.warning`, etc.

♻️ Proposed fix

```diff
-    print(f"Unexpected key: {name} not in target model but is in source model.")
+    logger.warning(f"Unexpected key: {name} not in target model but is in source model.")
```
265-265: Avoid global side effect from `np.set_printoptions`. This modifies global numpy print settings on every transform call. If needed for debugging, consider using a context manager or removing it.

♻️ Proposed fix

```diff
-    np.set_printoptions(threshold=10)
```

If this is needed for debugging specific issues, use `np.printoptions(threshold=10)` as a context manager around the specific debug output instead.
@coderabbitai resolve

✅ Actions performed: Comments resolved.

Force-pushed from 6574f52 to 58f9af5
Force-pushed from 1dba9f4 to 24d1e2c
@coderabbitai run

🧠 Learnings used

✅ Actions performed: Review triggered.
Actionable comments posted: 8
🧹 Nitpick comments (6)
bionemo-recipes/models/mixtral/requirements.txt (1)
1-1: Move `lm-eval` to a separate test requirements file. `lm-eval` is annotated `# For testing` but lives in the primary `requirements.txt`, inflating the runtime install surface for all downstream consumers of this package. The repository already follows this pattern at the root level with separate `requirements-test.txt` and `requirements-dev.txt` files. Consider creating `requirements-test.txt` in this directory to align with established project conventions and isolate test-only dependencies.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/models/mixtral/requirements.txt` at line 1, remove the test-only dependency "lm-eval" from the mixtral requirements.txt and create a new test requirements file named requirements-test.txt that contains "lm-eval # For testing"; update any local references in the mixtral package (e.g., any setup or CI job that installed mixtral/requirements.txt) to install mixtral/requirements-test.txt for test runs instead of the main requirements.txt so runtime installs remain minimal.

bionemo-recipes/models/mixtral/tests/common/fixtures.py (1)
40-51: Redundant `import os` inside fixture. `os` is already imported at module level (line 18). The local re-import on line 47 is unnecessary.

♻️ Proposed fix

```diff
 @pytest.fixture(autouse=True)
 def use_te_debug():
     """Auto-use fixture to enable TransformerEngine debugging.

     This fixture automatically enables debug mode for TransformerEngine in all tests for better error messages.
     """
-    import os
-
     os.environ["NVTE_DEBUG"] = "1"
     yield
     os.environ.pop("NVTE_DEBUG", None)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/models/mixtral/tests/common/fixtures.py` around lines 40-51, the fixture use_te_debug unnecessarily re-imports os locally; remove the local "import os" statement inside the use_te_debug function so it uses the module-level os import, leaving the rest of the fixture (os.environ["NVTE_DEBUG"] = "1", yield, os.environ.pop("NVTE_DEBUG", None)) unchanged.

bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py (4)
600-622: Redundant local import of `AutoConfig`. Line 606 imports `AutoConfig` from `transformers`, but it's already imported at module level on line 31.

♻️ Proposed fix

```diff
 def test_convert_config(self):
     """Test that config can be converted between HF and TE formats."""
     upstream_id = self.get_upstream_model_id()
     revision = self.get_upstream_model_revision()

-    # Load HF config
-    from transformers import AutoConfig
-
     kwargs = {}
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py` around lines 600-622, the test_convert_config function has a redundant local import of AutoConfig; remove the line "from transformers import AutoConfig" inside test_convert_config and rely on the module-level AutoConfig import instead so the test uses the already-imported AutoConfig symbol; update test_convert_config to instantiate hf_config via AutoConfig.from_pretrained(upstream_id, **kwargs) without a local import.
85-99: Docstring example references non-existent `BioNeMoModelTester` class. The example on line 87 uses `ESM2ModelTester(BioNeMoModelTester)` but the actual class name is `BaseModelTest`. This mirrors the same issue in `__init__.py`'s docstring.

♻️ Proposed fix

```diff
-class ESM2ModelTester(BioNeMoModelTester):
+class ESM2ModelTester(BaseModelTest):
     def get_model_class(self):
         return NVEsmForMaskedLM
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py` around lines 85-99, update the docstring example that currently shows "class ESM2ModelTester(BioNeMoModelTester):" to use the correct base class name "BaseModelTest" (replace BioNeMoModelTester with BaseModelTest) so the sample subclass matches the actual abstract base class; adjust any mirrored occurrences in module docstrings (e.g., __init__.py) to use BaseModelTest as well to keep examples consistent with the implemented class.
630-667: `get_converted_te_model_checkpoint` re-runs full conversion on every call. `get_converted_te_model` calls `get_converted_te_model_checkpoint()` each time, which downloads the HF model, converts it, and saves to disk, even though the checkpoint path is deterministic within a class's tmp dir. Tests like `test_golden_values` and `test_golden_values_thd` both call `get_converted_te_model`, causing redundant work. Consider caching the checkpoint path (e.g., checking if the directory already exists before re-converting), or using `functools.lru_cache` / a class-level sentinel.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py` around lines 630-667, the test helper currently reconverts and re-saves the TE model every time get_converted_te_model_checkpoint() is called; update get_converted_te_model_checkpoint to check for an existing checkpoint under the deterministic path (self._tmp_dir / "converted_te_model") and return it immediately if present, otherwise perform the convert/save flow; alternatively, add a simple cache (e.g., a class or instance attribute sentinel like self._converted_te_checkpoint) or wrap the checkpoint getter with functools.lru_cache to avoid re-running get_reference_model(), get_hf_to_te_converter(), convert_fn, and model_te.save_pretrained on subsequent calls invoked by get_converted_te_model() and tests.
259-294: Add `GroupedLinear` initialization and FP8 quantization verification to the `verify_model_parameters_initialized_correctly` function. The Mixtral MoE block uses `transformer_engine.pytorch.GroupedLinear` for the `experts_gate_up` and `experts_down` modules, but these are not validated in the initialization checks. Since GroupedLinear modules are initialized with the same `init_method` as regular Linear modules and support FP8 quantization, they should be included in the verification logic to ensure MoE expert weights are correctly initialized and quantized.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py` around lines 259-294, the tests currently only check torch.nn.Embedding and transformer_engine.pytorch.Linear in verify_model_parameters_initialized_correctly; add support for transformer_engine.pytorch.GroupedLinear (used by experts_gate_up and experts_down) by treating it like transformer_engine.pytorch.Linear: verify weight mean is ~0, std equals config.initializer_range with the same tolerances, check bias zeros if present, and perform the FP8 quantization assertion (skip if name is in model._tied_weights_keys or matches model._do_not_quantize patterns). Update the logic that asserts isinstance(module.weight, QuantizedTensor) so it also accepts GroupedLinear modules.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@bionemo-recipes/models/esm2/src/esm/collator.py`:
- Around line 286-291: The ValueError message in the TokenPackingDataset code
path inside esm.collator.py concatenates adjacent string literals so "to" and
"ensure" run together; update the error string in the raise ValueError in the
block referencing self.max_tokens_per_batch (inside the token packing / collator
logic) to include a space between "to" and "ensure" (either by adding a trailing
space in the previous literal or combining into one f-string) so the message
reads "...dataset to ensure all samples fit within max_tokens_per_batch."
In `@bionemo-recipes/models/esm2/tests/common/test_modeling_common.py`:
- Around line 349-358: get_reference_model_no_weights currently passes explicit
dtype and revision plus **kwargs into AutoConfig.from_pretrained which can raise
TypeError if callers supply those keys; fix by copying kwargs and removing any
'dtype' and 'revision' keys (e.g., kwargs_copy = dict(kwargs);
kwargs_copy.pop("dtype", None); kwargs_copy.pop("revision", None)) then call
AutoConfig.from_pretrained(..., dtype=torch.float32,
revision=self.get_upstream_model_revision(), **kwargs_copy); also update the
ESM2 test subclass override of get_reference_model_no_weights to accept **kwargs
and forward them to super() (e.g., def get_reference_model_no_weights(self,
**kwargs): return super().get_reference_model_no_weights(**kwargs)) so
signatures match.
In `@bionemo-recipes/models/llama3/collator.py`:
- Around line 285-327: The validation currently uses raw sample_length; update
TokenPackingDataset logic to compute padded_len =
self._padded_len(sample_length) (taking pad_sequences_to_be_divisible_by into
account) and validate that padded_len <= self.max_tokens_per_batch before
proceeding, raising the same ValueError if it exceeds; then use padded_len (not
raw sample_length) for current_length accumulation, comparisons (==, >) and when
computing tokens_in_batch/tokens_available and when calling
_split_sample_by_num_tokens to ensure split logic and non-split branches always
respect max_tokens_per_batch and reuse the padded length consistently.
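A sketch of the padded-length validation described above; `_validated_padded_len` is a hypothetical helper, while `_padded_len` and the attribute names come from the comment itself.

```python
class TokenPackingDataset:
    # ... other methods elided ...

    def _validated_padded_len(self, sample_length: int) -> int:
        """Validate and account for the padded length, not the raw one."""
        padded_len = self._padded_len(sample_length)  # rounds up per pad_sequences_to_be_divisible_by
        if padded_len > self.max_tokens_per_batch:
            raise ValueError(
                f"Sample with padded length {padded_len} exceeds "
                f"max_tokens_per_batch={self.max_tokens_per_batch}; filter or "
                "truncate the dataset so all samples fit."
            )
        return padded_len
```

Everywhere `__iter__` previously compared or accumulated the raw `sample_length`, it would use this padded value instead, so capacity checks and split decisions stay consistent.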
In `@bionemo-recipes/models/mixtral/.ruff.toml`:
- Line 1: The child ruff config (the extend = "../.ruff.toml" in
mixtral/.ruff.toml) points at a missing parent config; create the parent
.ruff.toml in the models directory that the child extends and populate it with
the shared linting rules/settings (e.g., select/ignore codes, line-length,
target-python, etc.) so model subdirectories (mixtral and llama3) inherit
consistent rules; ensure the filename matches ".ruff.toml" and the child extend
path remains "../.ruff.toml".
In `@bionemo-recipes/models/mixtral/modeling_mixtral_te.py`:
- Around line 318-342: The code assumes input_ids exists when packing and when
computing cache lengths; change uses of input_ids to prefer inputs_embeds or
hidden_states so None input_ids won't crash and lengths are correct: when
computing padded_seq_len replace input_ids.size(1) with hidden_states.size(1)
(or inputs_embeds.size(1) if you prefer explicit check), and when computing
lengths for past_key_values use attention_mask.shape == hidden_states.shape (or
compare to inputs_embeds.shape when input_ids is None) and use input batch size
from hidden_states.shape[0] instead of input_ids.shape[0]; update the branches
around _unpad_input, padded_seq_len, and the lengths calculation that calls
past_key_values.pre_step to reference hidden_states/inputs_embeds accordingly.
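A hypothetical helper illustrating the None-safe direction of this fix: derive lengths from `hidden_states` (which exists whether the caller passed `input_ids` or `inputs_embeds`) and compare the 2D attention mask only against the leading dims.

```python
import torch


def _safe_lengths(hidden_states: torch.Tensor, attention_mask):
    """Derive batch/sequence sizes without touching a possibly-None input_ids."""
    batch_size = hidden_states.shape[0]
    padded_seq_len = hidden_states.size(1)
    # attention_mask is [batch, seq]; hidden_states is [batch, seq, hidden].
    mask_covers_input = (
        attention_mask is not None
        and tuple(attention_mask.shape) == tuple(hidden_states.shape[:2])
    )
    return batch_size, padded_seq_len, mask_covers_input
```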
In `@bionemo-recipes/models/mixtral/state.py`:
- Around line 458-481: The state_transform function's docstring uses NumPy-style
section separators (e.g., "Returns: -------", "Examples: --------") instead of
Google-style; update the docstring for state_transform to Google style by
removing the underline separators, using plain "Returns:" and "Examples:"
headings followed by an indented descriptive paragraph or block, ensure
parameter descriptions remain under "Args:" in Google format and example code
stays in an indented literal block under "Examples:", and keep the existing
content same while conforming to pydocstyle Google conventions.
In `@bionemo-recipes/models/mixtral/tests/common/__init__.py`:
- Around line 16-37: The module docstring references non-existent classes
BioNeMoModelTester and BioNeMoModelTest; update the docstring and example to use
the actual exported class BaseModelTest and keep TestTolerances, so change
occurrences of BioNeMoModelTester/BioNeMoModelTest in the descriptive list and
the example import/usage to BaseModelTest and TestTolerances (e.g., "from
tests.common import BaseModelTest, TestTolerances" and "class
ESM2ModelTester(BaseModelTest):") to ensure the example matches the real API.
In `@bionemo-recipes/recipes/esm2_peft_te/collator.py`:
- Around line 285-327: The initial sample length check uses raw sample_length
rather than its padded size, so replace the validation to compute padded =
self._padded_len(sample_length) (respecting pad_sequences_to_be_divisible_by)
and compare padded against self.max_tokens_per_batch; raise the ValueError if
padded > max_tokens_per_batch. Also ensure code paths that start a new batch or
set current_length (the branches using self._padded_len(sample_length), the
split_samples branch, and where you assign current_length =
self._padded_len(len(samples[0]["input_ids"]))) consistently use the same padded
value for this incoming sample to avoid creating oversized batches (adjust
tokens_available logic to use tokens capacity versus padded size).
---
Duplicate comments:
In `@bionemo-recipes/models/mixtral/modeling_mixtral_te.py`:
- Around line 453-458: Remove the module-level mutation of
torch._dynamo.config.capture_scalar_outputs and instead toggle it only around
the actual compile sites: when calling torch.compile for the functions like
_pad_input and _unpad_input, save the current
torch._dynamo.config.capture_scalar_outputs value, set it to True, call
torch.compile(...), then restore the original value; ensure no import-time side
effects remain and all references to torch.compile in this module use this
save/set/restore pattern.
In `@bionemo-recipes/models/mixtral/requirements.txt`:
- Around line 1-5: The requirements.txt entries (lm-eval, torch,
torchao!=0.14.0, transformer_engine[pytorch], transformers) are unpinned and
harm reproducibility; update each package line to a specific, tested version or
a strict version range (e.g., >= and < upper bound) instead of bare names or
only an exclusion—replace the current entries with pinned versions or ranges
based on the project’s tested environment (for example pin torch and
transformers to the exact tested versions and change torchao!=0.14.0 to a
concrete range like torchao>=X.Y.Z,<0.14.0 or >=0.14.1,<X.Y.Z) so CI/builds
reproduce reliably and avoid allowing incompatible older releases.
In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py`:
- Around line 349-358: The AutoConfig.from_pretrained call in
get_reference_model_no_weights is passing a non-existent dtype kwarg which just
sets an arbitrary config attribute; change the parameter name to torch_dtype so
the config correctly reflects the intended precision. Locate
get_reference_model_no_weights and replace dtype=torch.float32 with
torch_dtype=torch.float32 in the AutoConfig.from_pretrained(...) invocation to
match other usages and ensure consistency with model precision handling.
- Around line 338-347: The kwargs passed to upstream_class.from_pretrained use
the wrong key "dtype" instead of HuggingFace's expected "torch_dtype", so models
are loaded in default precision; update the kwargs construction before the
upstream_class.from_pretrained(upstream_id, **kwargs) call to set
kwargs["torch_dtype"] = dtype (and remove or stop setting kwargs["dtype"]) so
the model is instantiated with the requested precision; ensure this change
affects the same variables shown (kwargs, dtype, upstream_class.from_pretrained,
upstream_id).
---
Nitpick comments:
In `@bionemo-recipes/models/mixtral/requirements.txt`:
- Line 1: Remove the test-only dependency "lm-eval" from the mixtral
requirements.txt and create a new test requirements file named
requirements-test.txt that contains "lm-eval # For testing"; update any local
references in the mixtral package (e.g., any setup or CI job that installed
mixtral/requirements.txt) to install mixtral/requirements-test.txt for test runs
instead of the main requirements.txt so runtime installs remain minimal.
In `@bionemo-recipes/models/mixtral/tests/common/fixtures.py`:
- Around line 40-51: The fixture use_te_debug unnecessarily re-imports os
locally; remove the local "import os" statement inside the use_te_debug function
so it uses the module-level os import, leaving the rest of the fixture
(os.environ["NVTE_DEBUG"] = "1", yield, os.environ.pop("NVTE_DEBUG", None))
unchanged.
In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py`:
- Around line 600-622: The test_convert_config function has a redundant local
import of AutoConfig; remove the line "from transformers import AutoConfig"
inside test_convert_config and rely on the module-level AutoConfig import
instead so the test uses the already-imported AutoConfig symbol; update
test_convert_config to instantiate hf_config via
AutoConfig.from_pretrained(upstream_id, **kwargs) without a local import.
- Around line 85-99: Update the docstring example that currently shows "class
ESM2ModelTester(BioNeMoModelTester):" to use the correct base class name
"BaseModelTest" (replace BioNeMoModelTester with BaseModelTest) so the sample
subclass matches the actual abstract base class; adjust any mirrored occurrences
in module docstrings (e.g., __init__.py) to use BaseModelTest as well to keep
examples consistent with the implemented class.
- Around line 630-667: The test helper currently reconverts and re-saves the TE
model every time get_converted_te_model_checkpoint() is called; update
get_converted_te_model_checkpoint to check for an existing checkpoint under the
deterministic path (self._tmp_dir / "converted_te_model") and return it
immediately if present, otherwise perform the convert/save flow; alternatively,
add a simple cache (e.g., a class or instance attribute sentinel like
self._converted_te_checkpoint) or wrap the checkpoint getter with
functools.lru_cache to avoid re-running get_reference_model(),
get_hf_to_te_converter(), convert_fn, and model_te.save_pretrained on subsequent
calls invoked by get_converted_te_model() and tests.
- Around line 259-294: The tests currently only check torch.nn.Embedding and
transformer_engine.pytorch.Linear in
verify_model_parameters_initialized_correctly; add support for
transformer_engine.pytorch.GroupedLinear (used by experts_gate_up and
experts_down) by treating it like transformer_engine.pytorch.Linear: verify
weight mean is ~0, std equals config.initializer_range with the same tolerances,
check bias zeros if present, and perform the FP8 quantization assertion (skip if
name is in model._tied_weights_keys or matches model._do_not_quantize patterns).
Update the logic that asserts isinstance(module.weight, QuantizedTensor) so it
also accepts GroupedLinear modules.
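A sketch of how the isinstance extension could look; the tolerance values are placeholders and the FP8 `QuantizedTensor` assertion is omitted for brevity. `GroupedLinear` registers one parameter per expert (`weight0`, `weight1`, ...), so the check iterates over named parameters rather than assuming a single `.weight`.

```python
import transformer_engine.pytorch as te


def check_te_module_init(module, config):
    """Verify init statistics for TE Linear and GroupedLinear modules alike."""
    if isinstance(module, (te.Linear, te.GroupedLinear)):
        for param_name, param in module.named_parameters(recurse=False):
            if "weight" in param_name:
                assert abs(param.float().mean().item()) < 0.01  # placeholder tolerance
                assert abs(param.float().std().item() - config.initializer_range) < 0.01
            elif "bias" in param_name:
                assert param.abs().max().item() == 0.0
```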
Signed-off-by: Peter St. John <pstjohn@nvidia.com>

rebase on moe fix pr

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Force-pushed from b8c999e to bd5543c
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Force-pushed from bd5543c to eda3894
Summary
Adds a Mixtral Mixture-of-Experts model to bionemo-recipes/models/mixtral/ using TransformerEngine, following the same pattern as the existing llama3 model.
MoE Implementation
The core MoE block (NVMixtralSparseMoeBlock) uses TransformerEngine's GroupedLinear modules (experts_gate_up and experts_down) for the per-expert projections.
Weight conversion handles the structural difference between HF's stacked 3D expert tensors ([num_experts, out, in]) and TE's per-expert GroupedLinear weights (weight0, weight1, ...).
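A toy sketch of that structural mapping (shapes shrunk for illustration; real Mixtral dimensions are much larger):

```python
import torch

num_experts, out_features, in_features = 8, 6, 4

# HF side: one stacked 3D tensor per projection, [num_experts, out, in].
stacked = torch.randn(num_experts, out_features, in_features)

# TE side: GroupedLinear expects one 2D weight per expert (weight0, weight1, ...).
per_expert = {f"weight{i}": w for i, w in enumerate(torch.unbind(stacked, dim=0))}

# The reverse direction stacks them back for TE -> HF export.
restacked = torch.stack([per_expert[f"weight{i}"] for i in range(num_experts)])
assert torch.equal(stacked, restacked)
```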
Base test class improvements

Adds a shared BaseModelTest framework under tests/common/ for model validation and conversion testing, reused across model directories.

Summary by CodeRabbit

New Features
- Mixtral Mixture-of-Experts model implementation with Transformer Engine support.

Bug Fixes
- Mutable default arguments refactored across state modules; collator batch packing gains length validation and sample-splitting behavior.

Tests
- Comprehensive test framework for model validation and conversion testing.