[model] feat: add homogeneous MIMO mode for standard pretrain path#2695

Open
aroshanghias-nvd wants to merge 10 commits into NVIDIA-NeMo:mimo/phase4-training from aroshanghias-nvd:mimo/homogeneous-baseline

Conversation

@aroshanghias-nvd
Contributor

Summary

Enables MimoModelProvider to work with the standard pretrain() entry point when mimo_parallelism_config=None (homogeneous mode), where all modules (LLM + vision encoders) run on every rank as a single model. This lets research teams run ablation experiments comparing homogeneous vs heterogeneous MIMO without changing the model provider.

Key changes:

  • __getattr__/__setattr__ on MimoModelProvider — proxy reads/writes to the nested TransformerConfig in homogeneous mode; raises an actionable AttributeError in heterogeneous mode if pretrain() is used incorrectly.
  • _MimoConfigProxy — wraps model.config so get_model_config(model[0]) resolves provider-level fields (seq_length, make_vocab_size_divisible_by, share_embeddings_and_output_weights, etc.) that don't exist on TransformerConfig.
  • Stack-compatibility fields on MimoModelProvider: vocab_size, seq_length, make_vocab_size_divisible_by, should_pad_vocab, share_embeddings_and_output_weights.
  • setup.py — route post-model writes (timers, _update_model_config_funcs) through get_model_config(model[0]) instead of cfg.model. No behavior change for GPTModelProvider.
  • loaders.py — narrow MIMO data loading path to heterogeneous mode only.
  • pretrain_mimo.py — guard against mimo_parallelism_config=None.
  • Homogeneous delegation in provide_distributed_model(), initialize_model_parallel(), provide(), and finalize() (enforces PP=1).
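
As an illustration of the proxy behavior in the first bullet, here is a minimal, self-contained sketch; the class bodies, field set, and error message are assumptions for illustration, not the PR's actual implementation:

```python
class FakeTransformerConfig:
    """Stand-in for the nested megatron TransformerConfig."""
    def __init__(self):
        self.hidden_size = 1024

class MimoModelProvider:
    # Fields stored on the provider itself rather than the nested config.
    _OWN_FIELDS = {"language_model_config", "mimo_parallelism_config",
                   "seq_length", "vocab_size"}

    def __init__(self, language_model_config, mimo_parallelism_config=None):
        object.__setattr__(self, "language_model_config", language_model_config)
        object.__setattr__(self, "mimo_parallelism_config", mimo_parallelism_config)

    def __getattr__(self, name):
        # Reached only when normal attribute lookup fails. In heterogeneous
        # mode the proxy is disabled, so raise an actionable error instead.
        d = object.__getattribute__(self, "__dict__")
        if d.get("mimo_parallelism_config") is not None:
            raise AttributeError(
                f"{name!r} is not proxied in heterogeneous mode; "
                "use the MIMO-specific pretrain path instead of pretrain()."
            )
        return getattr(d["language_model_config"], name)

    def __setattr__(self, name, value):
        if name in self._OWN_FIELDS or self.mimo_parallelism_config is not None:
            object.__setattr__(self, name, value)
        else:
            # Homogeneous mode: write through to the nested config.
            setattr(self.language_model_config, name, value)

provider = MimoModelProvider(FakeTransformerConfig())
print(provider.hidden_size)      # read proxied from the nested config
provider.hidden_size = 2048      # write proxied through as well
print(provider.language_model_config.hidden_size)
```

Using `object.__getattribute__` inside `__getattr__` avoids recursing back into the proxy while deciding which mode the provider is in.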

Test plan

  • Unit tests in test_mimo_provider.py: __getattr__/__setattr__ proxy, heterogeneous guard, finalize PP=1 enforcement, homogeneous initialize_model_parallel delegation
  • E2E test test_mimo_homogeneous_e2e.py: full pretrain() loop with GPTModel + CLIPViTModel, TP=4 DP=2 on 8 GPUs — PASSED on all ranks
  • E2E regression: existing heterogeneous test_mimo_training_e2e.py PASSED on all ranks

Align train_step_mimo() with train.py by computing
increment = num_microbatches * micro_batch_size * data_parallel_size
and passing it to scheduler.step(increment=increment) instead of the
bare scheduler.step() call.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
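
The increment computation above can be sketched as follows; the helper name and the example values are hypothetical, while the formula itself is taken from the commit message:

```python
def lr_scheduler_increment(num_microbatches: int,
                           micro_batch_size: int,
                           data_parallel_size: int) -> int:
    """Samples consumed per global step, i.e. the global batch size."""
    return num_microbatches * micro_batch_size * data_parallel_size

# With 8 microbatches of size 2 across 2 data-parallel replicas, one
# optimizer step consumes 32 samples.
increment = lr_scheduler_increment(num_microbatches=8,
                                   micro_batch_size=2,
                                   data_parallel_size=2)
# scheduler.step(increment=increment)  # instead of bare scheduler.step()
```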
@aroshanghias-nvd aroshanghias-nvd force-pushed the mimo/homogeneous-baseline branch 2 times, most recently from 0be4c80 to fa02c75 on March 9, 2026 16:50
kamran-nvidia and others added 6 commits March 9, 2026 10:36
… support; add training script for heterogeneous LLaVA

Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
- Fix copyright year 2025 to 2026 in test_mimo_training_llava.py
- Replace hardcoded /tmp/claude-0/mimo_rank_logs with MIMO_LOG_DIR env var
- Restore _build_config defaults to train_iters=2/global_batch_size=1 in e2e test
- Remove redundant timer monkey-patch from setup_mimo; GlobalState.timers already handles it
- Replace fragile pickle-over-GPU P2P loss relay with broadcast_object_list

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
[MIMO] Add Heterogeneous LLaVA training script (PR #2)

- Loss mask support in MimoDataset and mimo_collate_fn
- Heterogeneous LLaVA e2e test (Vicuna-7B + CLIP ViT-L/14, 8 GPUs)
- Auto-create per-module LR schedulers when none provided
- Fix skipped_iter propagation and get_lr() call signature
- Guard training_log behind skip_train_metrics_log
- Log iteration-time directly for MIMO models
- HFMimoDatasetProvider: add hf_data_files and preprocess_fn fields
- Fix review issues: copyright year, log dir, e2e smoke test defaults,
  redundant timer patch, replace P2P pickle relay with broadcast_object_list
The mimo_collate_fn now requires loss_mask in each batch item. Update
make_sample() to include it and add loss_mask shape assertions.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
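
The loss_mask contract described above can be sketched as follows; make_sample and the field set are simplified stand-ins for the real dataset code, which stacks tensors rather than plain lists:

```python
import random

def make_sample(seq_len):
    """Toy batch item carrying tokens plus the now-required loss_mask."""
    tokens = [random.randrange(32000) for _ in range(seq_len)]
    # All-ones mask for this toy sample; real data masks out prompt tokens.
    return {"tokens": tokens, "loss_mask": [1.0] * seq_len}

def mimo_collate_fn(batch):
    for item in batch:
        assert "loss_mask" in item, "each batch item must carry a loss_mask"
        assert len(item["loss_mask"]) == len(item["tokens"]), (
            "loss_mask must match the token sequence length"
        )
    return {
        "tokens": [b["tokens"] for b in batch],
        "loss_mask": [b["loss_mask"] for b in batch],
    }
```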
… logging rank

Tensors received via broadcast_object_list carry the source rank's CUDA
device index. Move them to the current device so training_log arithmetic
doesn't hit a cross-device RuntimeError on the logging rank.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
Enable MimoModelProvider to work with the standard pretrain() entry
point when mimo_parallelism_config=None (homogeneous mode), where all
modules run on every rank as a single model.

Key changes:
- Add __getattr__/__setattr__ to proxy reads/writes to the nested
  TransformerConfig in homogeneous mode, with a heterogeneous guard
  that raises an actionable error if pretrain() is used incorrectly.
- Add _MimoConfigProxy to wrap model.config so get_model_config()
  resolves provider-level fields (seq_length,
  share_embeddings_and_output_weights, etc.) that don't exist on
  TransformerConfig.
- Add stack-compatibility fields: vocab_size, seq_length,
  make_vocab_size_divisible_by, should_pad_vocab,
  share_embeddings_and_output_weights.
- Route post-model writes in setup.py through get_model_config()
  instead of cfg.model so they reach the actual model config.
- Narrow loaders.py MIMO data path to heterogeneous mode only.
- Add guards in pretrain_mimo()/setup_mimo() rejecting homogeneous.
- Add homogeneous delegation in provide_distributed_model(),
  initialize_model_parallel(), provide(), and finalize() (PP=1).

Verified: both homogeneous and heterogeneous e2e tests pass on 8 GPUs.
Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
@aroshanghias-nvd aroshanghias-nvd force-pushed the mimo/homogeneous-baseline branch from fa02c75 to 935839c on March 9, 2026 18:59
@copy-pr-bot

copy-pr-bot bot commented Mar 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yaoyu-33 yaoyu-33 added area:model Model implementations and HF bridge logic needs-review PR is ready for code review and waiting on a reviewer labels Mar 10, 2026
@yaoyu-33 yaoyu-33 changed the title feat(mimo): add homogeneous MIMO mode for standard pretrain() path [model] feat: add homogeneous MIMO mode for standard pretrain path Mar 10, 2026
@yaoyu-33 yaoyu-33 added area:model Model implementations and HF bridge logic needs-review PR is ready for code review and waiting on a reviewer and removed area:model Model implementations and HF bridge logic needs-review PR is ready for code review and waiting on a reviewer labels Mar 11, 2026
kamran-nvidia and others added 3 commits March 11, 2026 12:08
… test

Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
Add homogeneous MIMO LLaVA training support and e2e test