
[Research] FedUMM: Federated Learning for Unified Multimodal Models#4158

Merged
ZiyueXu77 merged 11 commits into NVIDIA:main from rollingsu:feat/fedumm
Mar 31, 2026

Conversation

@rollingsu
Contributor

@rollingsu rollingsu commented Feb 9, 2026

Fixes # .

Description

A few sentences describing the changes proposed in this pull request.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • [x] Quick tests passed locally by running ./runtest.sh.
  • [x] In-line docstrings updated.
  • [x] Documentation updated.

@rollingsu
Contributor Author

Official code for FedUMM — the first federated learning pipeline that uses NVFlare to train unified multimodal models.

@greptile-apps
Contributor

greptile-apps Bot commented Feb 9, 2026

Greptile Summary

This PR adds the FedUMM research example — a federated learning framework for fine-tuning vision-language models (BLIP-VQA and JanusPro) using LoRA adapters with FedAvg, implemented on top of NVFlare. The PR addresses two prior review comments: the empty-dataloader guard now correctly raises ValueError in common.py, and the evaluation metric key has been renamed to val_accuracy consistently across fl_client.py and job.py.
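The empty-dataloader guard mentioned above can be sketched as follows. This is a minimal illustration of the pattern, not the repo's actual `common.py`; the function name `check_dataloader` and the `site_name` parameter are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def check_dataloader(loader: DataLoader, site_name: str = "site") -> DataLoader:
    # With Dirichlet partitioning a client can end up with zero samples;
    # raising ValueError up front beats a ZeroDivisionError (or a silent
    # no-op round) deep inside the training loop.
    if len(loader) == 0:
        raise ValueError(
            f"Dataloader for {site_name} is empty; "
            "check the data partitioning settings."
        )
    return loader

# A dataloader over an empty dataset has zero batches.
empty_loader = DataLoader(TensorDataset(torch.empty(0, 3)), batch_size=4)
```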

Key findings:

  • JanusPro labels shape mismatch (P1): In januspro_backend.py, labels has shape [B, max_a_len] (default 32), but inputs_embeds from prepare_inputs_embeds has a much longer sequence dimension after image-token expansion. The CausalLM forward pass requires labels to match the input sequence length; the current mismatch causes a RuntimeError on the very first JanusPro training step. The BLIP path is unaffected.

  • weight_decay not set in AdamW (P1): Both fl_client.py and local_train.py create torch.optim.AdamW(...) without a weight_decay argument, silently using PyTorch's default of 0.01. The README's hyperparameter table states AdamW (wd=0.05) hardcoded, so paper results cannot be reproduced with the current code.

  • The aiohttp import in local_train.py without a corresponding requirements.txt entry was already flagged in a prior thread and remains unresolved.
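To make the second finding concrete, here is a minimal sketch of the fix: pass `weight_decay` explicitly so the optimizer matches the documented setting. The `torch.nn.Linear` model is a toy stand-in for the LoRA-adapted VLM.

```python
import torch

# Toy stand-in for the LoRA-adapted model; any nn.Module works here.
model = torch.nn.Linear(4, 2)

# Pass weight_decay explicitly: omitting it silently falls back to
# PyTorch's AdamW default of 0.01 rather than the 0.05 the README documents.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
```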

Confidence Score: 3/5

Not safe to merge — the JanusPro training path will crash at runtime due to a tensor shape mismatch, and the weight_decay discrepancy prevents paper result reproduction.

Two P1 findings remain: (1) a definite runtime crash in januspro_backend.py due to labels vs inputs_embeds sequence-length mismatch, and (2) weight_decay=0.01 (PyTorch default) used instead of documented 0.05 across both training entry points. The BLIP path works correctly, and previously flagged issues (empty dataloader, wrong metric key) have been addressed.

research/fedumm/src/januspro_backend.py (label shape bug — will crash), research/fedumm/src/fl_client.py and research/fedumm/src/local_train.py (weight_decay mismatch)

Important Files Changed

Filename Overview
research/fedumm/src/januspro_backend.py JanusPro backend with LoRA fine-tuning — labels tensor shape [B, max_a_len] will mismatch inputs_embeds shape [B, full_seq_len, hidden] causing a runtime crash in train_step
research/fedumm/src/fl_client.py Unified NVFlare FL client — metric key corrected to val_accuracy, but AdamW optimizer created without weight_decay, using PyTorch default 0.01 vs paper-documented 0.05
research/fedumm/src/common.py Shared helpers — empty-dataloader guard now raises ValueError as requested; Dirichlet partitioning and parameter exchange utilities look correct
research/fedumm/src/local_train.py Centralized baseline — aiohttp dependency still missing from requirements.txt (pre-existing); AdamW optimizer missing weight_decay=0.05 as documented in README
research/fedumm/src/blip_backend.py BLIP-VQA backend with LoRA on text_encoder/text_decoder — dataset, collate, train_step, and evaluate paths look correct
research/fedumm/job.py FedJob config — IntimeModelSelector updated to use val_accuracy key matching renamed metric; FedAvg + ScriptRunner setup looks correct
research/fedumm/src/model_registry.py Simple global registry with register/get/list helpers — no issues found
research/fedumm/scripts/launch_blip.sh Conda env activation wrapper for BLIP — CONDA_PREFIX fallback to HOME/miniconda3 is reasonable for most setups
research/fedumm/scripts/launch_januspro.sh Conda env activation wrapper for JanusPro — same CONDA_PREFIX pattern as launch_blip.sh
research/fedumm/scripts/setup_envs.sh One-click conda env setup — clones Janus repo to /tmp and installs in editable mode; looks correct
research/fedumm/scripts/slurm_run.sh SLURM batch job template — runs centralized baseline then FL simulator; template comments guide users to adapt cluster-specific settings
research/fedumm/src/init.py Auto-registers BLIP unconditionally and JanusPro optionally (ImportError swallowed) — correct approach for optional dependencies
research/fedumm/requirements.txt Minimal BLIP-only requirements — aiohttp still missing (flagged previously), all other core deps present
research/fedumm/envs/env_blip.yml Conda env spec for BLIP — all required deps present with reasonable version floors
research/fedumm/envs/env_januspro.yml Conda env spec for JanusPro — includes accelerate, sentencepiece, protobuf; Janus package installed separately by setup_envs.sh
research/fedumm/README.md Comprehensive README — claims AdamW wd=0.05 is hardcoded but actual code uses PyTorch default 0.01; otherwise thorough documentation
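The Dirichlet partitioning noted for `common.py` follows a standard recipe for label-skewed federated splits: per class, draw client proportions from a Dirichlet distribution and slice that class's indices accordingly. The sketch below is a generic version of the scheme (smaller `alpha` means more heterogeneous shards), not the repo's implementation; the function name and signature are hypothetical.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha=0.5, seed=0):
    """Split sample indices across clients with label skew controlled by alpha."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    shards = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = rng.permutation(np.where(labels == cls)[0])
        # Per-client proportions for this class; smaller alpha -> spikier draws.
        props = rng.dirichlet([alpha] * num_clients)
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for shard, part in zip(shards, np.split(idx, cuts)):
            shard.extend(part.tolist())
    return shards

labels = [i % 3 for i in range(300)]          # 3 balanced classes
shards = dirichlet_partition(labels, num_clients=4, alpha=0.1)
```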

Sequence Diagram

sequenceDiagram
    participant job as job.py
    participant server as Server (FedAvg + ModelSelector)
    participant client as FL Client (fl_client.py)
    participant backend as VLM Backend

    job->>server: configure FedAvg and IntimeModelSelector
    job->>client: ScriptRunner with script_args

    client->>client: load and shard dataset (Dirichlet or IID)
    client->>backend: build_model_and_processor with LoRA config
    backend-->>client: model and processor

    loop FL Rounds
        server->>client: FLModel with global LoRA weights
        client->>client: load_trainable_params into model

        alt Evaluate round
            client->>backend: evaluate on eval_loader
            backend-->>client: val_accuracy
            client->>server: FLModel metrics val_accuracy
            server->>server: ModelSelector tracks best checkpoint
        else Train round
            loop local_epochs
                client->>backend: train_one_epoch
                backend-->>client: avg loss
            end
            client->>backend: evaluate on eval_loader
            backend-->>client: local_acc
            client->>server: FLModel with LoRA delta and metrics
            server->>server: FedAvg aggregates LoRA weights
        end
    end

Comments Outside Diff (1)

  1. research/fedumm/src/januspro_backend.py, line 158-179 (link)

    labels shape mismatch causes runtime crash in JanusPro training

    labels is constructed from the answer-only tokenization with shape [B, max_a_len] (default 32), but inputs_embeds produced by model.prepare_inputs_embeds(...) has shape [B, full_seq_len, hidden_dim] — full_seq_len includes the image token expansion and is far larger than 32.

    When model.language_model(inputs_embeds=inputs_embeds, ..., labels=labels) is called, HuggingFace's causal LM loss shifts logits and labels along the sequence dimension before computing cross-entropy:

    shift_logits = logits[..., :-1, :]  # [B, full_seq_len-1, vocab_size]
    shift_labels = labels[..., 1:]      # [B, max_a_len-1]  ← different size

    This produces a size mismatch in CrossEntropyLoss and raises a RuntimeError immediately on the first training step with JanusPro.

    The fix is to build labels with the same sequence length as the full input, placing -100 (ignore index) at every position except the answer tokens. For example, construct a label tensor of shape [seq_len] filled with -100 and copy the answer token IDs into the tail positions corresponding to the <|Assistant|> turn.
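The suggested fix could look like the sketch below: build a full-length label tensor filled with -100 and copy the answer IDs into the tail positions. This is one possible reading of the reviewer's suggestion, not the code merged in the PR; `build_full_labels` and its arguments are hypothetical names, and it assumes the answer tokens sit at the very end of the sequence (the `<|Assistant|>` turn) with `max_a_len <= full_seq_len`.

```python
import torch

IGNORE_INDEX = -100  # CrossEntropyLoss's default ignore_index

def build_full_labels(answer_ids: torch.Tensor, full_seq_len: int) -> torch.Tensor:
    """Expand answer-only labels [B, max_a_len] to [B, full_seq_len].

    Every position outside the answer span is -100, so the causal-LM loss
    ignores the prompt and image tokens; the answer IDs occupy the tail,
    matching the sequence length of inputs_embeds after image-token expansion.
    """
    bsz, max_a_len = answer_ids.shape
    labels = torch.full((bsz, full_seq_len), IGNORE_INDEX, dtype=torch.long)
    labels[:, -max_a_len:] = answer_ids
    return labels

answers = torch.randint(0, 100, (2, 32))        # [B, max_a_len]
labels = build_full_labels(answers, full_seq_len=600)
```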

Reviews (7): Last reviewed commit: "Merge branch 'main' into feat/fedumm"

Contributor

@greptile-apps greptile-apps Bot left a comment


16 files reviewed, 3 comments


Comment thread research/fedumm/src/common.py
Comment thread research/fedumm/src/fl_client.py
Comment thread research/fedumm/job.py
@ZiyueXu77 ZiyueXu77 changed the title Add files via upload Implementation of research FedUMM: Federated Learning for Unified Multimodal Models Feb 9, 2026
Comment thread research/fedumm/scripts/launch_blip.sh
Collaborator

@ZiyueXu77 ZiyueXu77 left a comment


Thanks @rollingsu! Overall looks good and aligns well with the paper. Do we have experiments that have explicitly modeled the "missing modality" (as shown in Fig.3)? One further question: could you point me to the part of the code that handles the "shared alignment token is introduced to stabilize cross-client updates and maintain semantic consistency across modalities."?

Please move it from /examples/advanced/ to /research.

@ZiyueXu77 ZiyueXu77 requested a review from holgerroth February 9, 2026 18:14
@ZiyueXu77 ZiyueXu77 changed the title Implementation of research FedUMM: Federated Learning for Unified Multimodal Models [Research] FedUMM: Federated Learning for Unified Multimodal Models Feb 9, 2026
@holgerroth
Collaborator

Thanks for the great contribution! I agree, let's put this under /research to increase its visibility.

@holgerroth
Collaborator

@rollingsu please provide a PR description as well.

Contributor

@greptile-apps greptile-apps Bot left a comment


16 files reviewed, 3 comments


Comment thread research/fedumm/src/common.py
Comment thread research/fedumm/src/fl_client.py
Comment thread examples/advanced/fedumm/job.py Outdated
Comment thread research/fedumm/README.md
Collaborator

@ZiyueXu77 ZiyueXu77 left a comment


Some updates needed.

Comment thread examples/advanced/fedumm/src/local_train.py Outdated
Comment thread examples/advanced/fedumm/src/blip_backend.py Outdated
Comment thread research/fedumm/README.md
Added a check to raise an error if the dataloader is empty.
Removed the line that sets padding token IDs to -100 in labels.
Added aiohttp for timeout configuration in dataset loading.
Comment thread research/fedumm/src/local_train.py
Comment thread research/fedumm/scripts/launch_blip.sh
rollingsu and others added 3 commits March 24, 2026 00:23
@ZiyueXu77 ZiyueXu77 self-requested a review March 30, 2026 22:14
Collaborator

@ZiyueXu77 ZiyueXu77 left a comment


Good enough for now; I will further polish it.

@ZiyueXu77
Collaborator

/build

@ZiyueXu77 ZiyueXu77 enabled auto-merge (squash) March 31, 2026 13:39
@ZiyueXu77 ZiyueXu77 merged commit af4ff38 into NVIDIA:main Mar 31, 2026
29 checks passed