[model] feat: add homogeneous MIMO mode for standard pretrain path #2695

Status: Open

aroshanghias-nvd wants to merge 10 commits into NVIDIA-NeMo:mimo/phase4-training from
Conversation
Align train_step_mimo() with train.py by computing increment = num_microbatches * micro_batch_size * data_parallel_size and passing it to scheduler.step(increment=increment) instead of the bare scheduler.step() call. Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
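The commit above can be illustrated with a minimal sketch. The helper name `step_scheduler` and the way the three quantities are obtained are assumptions for illustration; only the `increment` formula and the `scheduler.step(increment=...)` call follow the commit message.

```python
# Sketch of the train.py-aligned scheduler step described in the commit.
# The LR scheduler is advanced by the number of samples consumed this
# iteration, not by a bare step() call.
def step_scheduler(scheduler, num_microbatches: int, micro_batch_size: int,
                   data_parallel_size: int) -> int:
    # Samples consumed in one training iteration across all data-parallel ranks.
    increment = num_microbatches * micro_batch_size * data_parallel_size
    scheduler.step(increment=increment)
    return increment
```

With 4 microbatches of size 2 across 8 data-parallel ranks, the scheduler advances by 64 samples per iteration.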
Force-pushed from 0be4c80 to fa02c75.
… support; add training script for heterogeneous LLaVA Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
- Fix copyright year 2025 to 2026 in test_mimo_training_llava.py
- Replace hardcoded /tmp/claude-0/mimo_rank_logs with MIMO_LOG_DIR env var
- Restore _build_config defaults to train_iters=2/global_batch_size=1 in e2e test
- Remove redundant timer monkey-patch from setup_mimo; GlobalState.timers already handles it
- Replace fragile pickle-over-GPU P2P loss relay with broadcast_object_list

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
[MIMO] Add Heterogeneous LLaVA training script (PR #2)

- Loss mask support in MimoDataset and mimo_collate_fn
- Heterogeneous LLaVA e2e test (Vicuna-7B + CLIP ViT-L/14, 8 GPUs)
- Auto-create per-module LR schedulers when none provided
- Fix skipped_iter propagation and get_lr() call signature
- Guard training_log behind skip_train_metrics_log
- Log iteration-time directly for MIMO models
- HFMimoDatasetProvider: add hf_data_files and preprocess_fn fields
- Fix review issues: copyright year, log dir, e2e smoke test defaults, redundant timer patch, replace P2P pickle relay with broadcast_object_list
The mimo_collate_fn now requires loss_mask in each batch item. Update make_sample() to include it and add loss_mask shape assertions. Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
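A minimal sketch of the change above. The field names other than `loss_mask`, the helper `mimo_collate_sketch`, and the masking choice are illustrative assumptions; only the requirements that every batch item carry a `loss_mask` and that its shape be asserted come from the commit.

```python
import torch

# Hypothetical make_sample() updated to include loss_mask, per the commit.
def make_sample(seq_len: int = 8):
    tokens = torch.randint(0, 100, (seq_len,))
    labels = torch.roll(tokens, shifts=-1)
    loss_mask = torch.ones(seq_len, dtype=torch.float)
    loss_mask[-1] = 0.0  # e.g. exclude the wrapped-around final label
    # Shape assertion the commit adds: mask must align with the tokens.
    assert loss_mask.shape == tokens.shape, "loss_mask must match token shape"
    return {"tokens": tokens, "labels": labels, "loss_mask": loss_mask}

# Simplified stand-in for mimo_collate_fn: every item must now carry
# loss_mask; fields are stacked into [batch, seq] tensors.
def mimo_collate_sketch(batch):
    assert all("loss_mask" in item for item in batch), "loss_mask is required"
    return {k: torch.stack([item[k] for item in batch]) for k in batch[0]}
```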
… logging rank Tensors received via broadcast_object_list carry the source rank's CUDA device index. Move them to the current device so training_log arithmetic doesn't hit a cross-device RuntimeError on the logging rank. Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
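The device fix above can be sketched as follows. The helper name and the recursion over dicts are assumptions; the underlying point is from the commit: objects received via `broadcast_object_list` keep the sender's CUDA device index, so tensors must be moved to the current rank's device before mixing them into local arithmetic.

```python
import torch

# Move any tensors in a received object to this rank's device, so that
# training_log arithmetic on the logging rank never mixes devices.
def to_local_device(obj):
    device = (torch.device("cuda", torch.cuda.current_device())
              if torch.cuda.is_available() else torch.device("cpu"))
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: to_local_device(v) for k, v in obj.items()}
    return obj  # non-tensor payloads pass through unchanged
```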
Enable MimoModelProvider to work with the standard pretrain() entry point when mimo_parallelism_config=None (homogeneous mode), where all modules run on every rank as a single model.

Key changes:

- Add __getattr__/__setattr__ to proxy reads/writes to the nested TransformerConfig in homogeneous mode, with a heterogeneous guard that raises an actionable error if pretrain() is used incorrectly.
- Add _MimoConfigProxy to wrap model.config so get_model_config() resolves provider-level fields (seq_length, share_embeddings_and_output_weights, etc.) that don't exist on TransformerConfig.
- Add stack-compatibility fields: vocab_size, seq_length, make_vocab_size_divisible_by, should_pad_vocab, share_embeddings_and_output_weights.
- Route post-model writes in setup.py through get_model_config() instead of cfg.model so they reach the actual model config.
- Narrow loaders.py MIMO data path to heterogeneous mode only.
- Add guards in pretrain_mimo()/setup_mimo() rejecting homogeneous mode.
- Add homogeneous delegation in provide_distributed_model(), initialize_model_parallel(), provide(), and finalize() (PP=1).

Verified: both homogeneous and heterogeneous e2e tests pass on 8 GPUs.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
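The __getattr__/__setattr__ delegation described in the commit can be sketched as below. This is not the actual MimoModelProvider (its real fields and error text differ); it only illustrates the pattern of proxying unknown attributes to the nested language-model config in homogeneous mode and raising an actionable error in heterogeneous mode.

```python
# Illustrative provider: unknown attribute reads/writes are forwarded to the
# nested config when mimo_parallelism_config is None (homogeneous mode).
class ProviderSketch:
    _OWN = {"language_model_config", "mimo_parallelism_config"}

    def __init__(self, language_model_config, mimo_parallelism_config=None):
        object.__setattr__(self, "language_model_config", language_model_config)
        object.__setattr__(self, "mimo_parallelism_config", mimo_parallelism_config)

    def __getattr__(self, name):
        # Only invoked when normal lookup fails on the provider itself.
        if self.__dict__["mimo_parallelism_config"] is None:
            return getattr(self.__dict__["language_model_config"], name)
        raise AttributeError(
            f"{name!r} is only proxied in homogeneous mode; "
            "heterogeneous MIMO must not go through pretrain()")

    def __setattr__(self, name, value):
        # Provider-owned fields stay local; everything else is forwarded
        # to the nested config in homogeneous mode.
        if name in self._OWN or self.mimo_parallelism_config is not None:
            object.__setattr__(self, name, value)
        else:
            setattr(self.language_model_config, name, value)
```

Writes such as `provider.hidden_size = 8` then land on the nested config, so the rest of the homogeneous pretrain stack sees one consistent configuration object.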
Force-pushed from fa02c75 to 935839c.
… test Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
Add homogeneous MIMO LLaVA training support and e2e test
Summary
Enables MimoModelProvider to work with the standard pretrain() entry point when mimo_parallelism_config=None (homogeneous mode), where all modules (LLM + vision encoders) run on every rank as a single model. This lets research teams run ablation experiments comparing homogeneous vs heterogeneous MIMO without changing the model provider.

Key changes:
- __getattr__/__setattr__ on MimoModelProvider — proxy reads/writes to the nested TransformerConfig in homogeneous mode; raises an actionable AttributeError in heterogeneous mode if pretrain() is used incorrectly.
- _MimoConfigProxy — wraps model.config so get_model_config(model[0]) resolves provider-level fields (seq_length, make_vocab_size_divisible_by, share_embeddings_and_output_weights, etc.) that don't exist on TransformerConfig.
- MimoModelProvider — add stack-compatibility fields: vocab_size, seq_length, make_vocab_size_divisible_by, should_pad_vocab, share_embeddings_and_output_weights.
- setup.py — route post-model writes (timers, _update_model_config_funcs) through get_model_config(model[0]) instead of cfg.model. No behavior change for GPTModelProvider.
- loaders.py — narrow MIMO data loading path to heterogeneous mode only.
- pretrain_mimo.py — guard against mimo_parallelism_config=None.
- Homogeneous delegation in provide_distributed_model(), initialize_model_parallel(), provide(), and finalize() (enforces PP=1).

Test plan
- test_mimo_provider.py: __getattr__/__setattr__ proxy, heterogeneous guard, finalize PP=1 enforcement, homogeneous initialize_model_parallel delegation
- test_mimo_homogeneous_e2e.py: full pretrain() loop with GPTModel + CLIPViTModel, TP=4 DP=2 on 8 GPUs — PASSED on all ranks
- test_mimo_training_e2e.py — PASSED on all ranks
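The _MimoConfigProxy idea from the summary can be sketched as below. This is not the actual class (the real one wraps model.config inside get_model_config()); it only illustrates the fallback-lookup pattern: resolve an attribute on the wrapped TransformerConfig first, then fall back to the provider for fields like seq_length that only exist at the provider level.

```python
# Illustrative config proxy: attribute lookups try the inner config first,
# then fall back to provider-level fields the inner config lacks.
class ConfigProxySketch:
    def __init__(self, inner_config, provider):
        self._inner = inner_config
        self._provider = provider

    def __getattr__(self, name):
        # Only invoked when the attribute is not found on the proxy itself.
        inner = self.__dict__["_inner"]
        if hasattr(inner, name):
            return getattr(inner, name)
        return getattr(self.__dict__["_provider"], name)
```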