feat: ModelExpress - Add PRESHARDED LoadFormat for zero-disk P2P RDMA weight loading #12898

Open: KavinKrishnan wants to merge 1 commit into NVIDIA:main from KavinKrishnan:kavink/presharded-weight-loading

Conversation


@KavinKrishnan KavinKrishnan commented Apr 9, 2026

Summary

Add LoadFormat.PRESHARDED for loading model weights that are already sharded per TP rank, enabling zero-disk P2P RDMA weight transfers where each MPI worker receives only its own shard directly into GPU memory. This is the TRT-LLM integration point for ModelExpress P2P weight transfer.

How it works

When MODEL_EXPRESS_URL is set:

  • Source workers (no existing sources in MX): load from disk normally, then publish_from_worker() publishes GPU tensors to the MX server before post_load_weights()
  • Target workers (sources exist in MX): MxLiveCheckpointLoader receives weights via RDMA directly into model parameters, then runs post_load_weights() locally

Auto-detection: workers publish when MODEL_EXPRESS_URL is set and MODEL_EXPRESS_TARGET is not set (target mode is configured by the Dynamo engine).
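The auto-detection rule above can be sketched as a small helper. This is illustrative only: the real code checks the environment variables inline, and only the variable names `MODEL_EXPRESS_URL` / `MODEL_EXPRESS_TARGET` come from this PR; the helper itself is hypothetical.

```python
import os

def detect_model_express_role(env=os.environ):
    """Infer the ModelExpress role from environment variables.

    Mirrors the auto-detection rule described in the PR summary:
      - MODEL_EXPRESS_URL unset            -> ModelExpress disabled
      - URL set, MODEL_EXPRESS_TARGET unset -> source (publish after disk load)
      - URL and TARGET both set             -> target (receive via RDMA)
    """
    if not env.get("MODEL_EXPRESS_URL"):
        return "disabled"
    if env.get("MODEL_EXPRESS_TARGET"):
        return "target"
    return "source"
```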

Changes

tensorrt_llm/llmapi/llm_args.py (+5 lines):

  • Add PRESHARDED = 3 to LoadFormat enum
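As a sketch, the extended enum might look like the following. The names and values of the pre-existing members are assumptions; only `PRESHARDED = 3` comes from this PR.

```python
from enum import IntEnum

class LoadFormat(IntEnum):
    # Pre-existing members (names/values assumed for illustration).
    AUTO = 0
    SAFETENSORS = 1
    DUMMY = 2
    # New in this PR: weights are already sharded per TP rank, so the
    # loader skips tensor-parallel slicing but keeps weight
    # mapping/fusing operations.
    PRESHARDED = 3
```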

tensorrt_llm/_torch/pyexecutor/model_loader.py (+37 lines):

  • Add PRESHARDED branch: sets _weights_presharded flag, delegates to external checkpoint_loader.load_weights()
  • Publish hook before post_load_weights() — auto-detects via MODEL_EXPRESS_URL

tensorrt_llm/_torch/modules/linear.py (+40/-14 lines):

  • Override tp_size to 1 when _weights_presharded=True (skip TP slicing)
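The effect of the override can be shown with a toy sharding function (lists stand in for tensors; the function names here are hypothetical, not the actual helpers in linear.py):

```python
def shard_dim0(weight, tp_size, tp_rank):
    """Slice a full weight along dim 0 for one TP rank."""
    chunk = len(weight) // tp_size
    return weight[tp_rank * chunk:(tp_rank + 1) * chunk]

def load_linear_weight(incoming, tp_size, tp_rank, presharded):
    # When _weights_presharded is set on the module, the loader treats
    # the incoming tensor as already sliced: the effective tp_size is
    # forced to 1 so the shard is taken whole instead of cut again.
    if presharded:
        tp_size, tp_rank = 1, 0
    return shard_dim0(incoming, tp_size, tp_rank)
```

With `presharded=False`, rank 1 of a TP=4 group takes a quarter of the full weight; with `presharded=True`, the incoming quarter-sized shard is used as-is.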

tensorrt_llm/executor/worker.py (+13 lines):

  • publish_from_worker() hook in setup_engine() — auto-detects via MODEL_EXPRESS_URL

Why publish before post_load_weights?

post_load_weights() runs FP8 quantization scale recalculation and MoE load balancer init. Publishing before these transforms means each target runs its own post_load_weights() independently, producing correct results.
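A toy illustration of why the ordering matters. Here `scale_weights` is a hypothetical stand-in for the FP8 scale recalculation; the real `post_load_weights()` is far more involved, but the double-application hazard is the same.

```python
def scale_weights(w, s=0.5):
    """Stand-in for a post-load transform (e.g. FP8 scale recalculation)."""
    return [x * s for x in w]

raw = [1.0, 2.0]

# Publish BEFORE the transform: each target applies it exactly once.
target_sees = scale_weights(list(raw))

# Publish AFTER the transform: the target would transform already
# transformed weights, applying the scale twice.
target_double = scale_weights(scale_weights(list(raw)))

print(target_sees, target_double)  # [0.5, 1.0] [0.25, 0.5]
```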

Backward compatibility

Fully backward compatible — LoadFormat.AUTO and existing paths unchanged. PRESHARDED only activates when explicitly set.

Validated

Kimi K2.5 (TP=8, MoE, nvfp4) on GCP GB200: 90.75 GB/rank at 365-509 Gbps, end-to-end disaggregated inference verified.
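A back-of-the-envelope check of the reported numbers (assuming decimal gigabytes and gigabits per second):

```python
# 90.75 GB per rank transferred at 365-509 Gbps.
shard_gbits = 90.75 * 8           # gigabits per rank
t_fast = shard_gbits / 509        # seconds at the top of the range
t_slow = shard_gbits / 365        # seconds at the bottom
print(f"{t_fast:.1f}-{t_slow:.1f} s per rank")  # 1.4-2.0 s per rank
```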

Dependencies

  • ModelExpress (merged to main): MxLiveCheckpointLoader, publish_from_worker, publish_model_params
  • Dynamo PR #8037: engine-level --model-express-url integration


coderabbitai bot commented Apr 9, 2026

📝 Walkthrough

Support for presharded weight loading is introduced via a new LoadFormat.PRESHARDED enum value. When activated, the weight loading pipeline skips tensor-parallel slicing and loads weights as-is. The changes optimize data copying in the linear module, integrate ModelExpress for parameter publishing, and add conditional republishing in the worker process when weights are already presharded.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **LoadFormat Enum Extension**<br>`tensorrt_llm/llmapi/llm_args.py` | Added `PRESHARDED = 3` enum member with documentation explaining that presharded loading bypasses tensor-parallel slicing while maintaining weight mapping/fusing operations. |
| **Weight Loading Pipeline**<br>`tensorrt_llm/_torch/pyexecutor/model_loader.py` | Implemented new `LoadFormat.PRESHARDED` branch that loads weights without tensor-parallel slicing, conditionally calls `model.load_weights()`, and integrates ModelExpress publishing via the `MODEL_EXPRESS_SOURCE` environment variable. |
| **Linear Module Optimization**<br>`tensorrt_llm/_torch/modules/linear.py` | Enhanced `copy_weight()` to avoid unnecessary data copying by checking storage pointers; adjusted weight loading helpers to use `tp_size = 1` when the module has the `_weights_presharded=True` flag, instead of the module's tensor-parallel size. |
| **Worker Publishing Integration**<br>`tensorrt_llm/executor/worker.py` | Added conditional ModelExpress publishing in `worker_main` via `publish_from_worker()` after worker initialization, respecting the `_mx_source_published` flag to prevent republishing when already handled by the model loader. |
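The `copy_weight()` storage-pointer optimization can be illustrated with a pure-Python analogue. Memoryviews stand in for tensors here; the actual code compares CUDA tensor data pointers, so this sketch only shows the shape of the check.

```python
def copy_weight(dst: memoryview, src: memoryview) -> None:
    """Copy src into dst, skipping the copy when they alias one buffer.

    Analogue of checking tensor storage pointers: if RDMA already wrote
    the data into the destination buffer, there is nothing to copy.
    """
    if dst.obj is src.obj:
        return  # same underlying buffer: data is already in place
    dst[:] = src
```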

Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant ML as ModelLoader
    participant CM as Checkpoint<br/>Manager
    participant LM as Linear<br/>Module
    participant MX as ModelExpress
    participant W as Worker

    User->>ML: load(model, LoadFormat.PRESHARDED)
    ML->>CM: load_weights(checkpoint)
    CM-->>ML: weights dict
    ML->>LM: model.load_weights()
    LM->>LM: copy_weight() with tp_size=1<br/>(presharded=True)
    LM-->>ML: weights loaded

    alt MODEL_EXPRESS_SOURCE set
        ML->>MX: publish_model_params(model)
        MX-->>ML: published
        ML->>LM: set _mx_source_published=True
    end

    ML->>LM: post_load_weights()
    LM-->>ML: done

    ML-->>User: model ready

    Note over W: Worker initialization
    alt MODEL_EXPRESS_SOURCE set<br/>and not _mx_source_published
        W->>MX: publish_from_worker(worker)
        MX-->>W: published
    else already published
        W->>W: skip republishing
    end
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 16.67%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description check | ✅ Passed | The pull request description is comprehensive and well-structured, covering the summary, implementation details, rationale, validation, and backward compatibility. |
| Title check | ✅ Passed | The title directly summarizes the main feature being added: a new PRESHARDED LoadFormat for P2P RDMA weight loading. It accurately reflects the core change across all modified files. |



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 457-466: The try/except around importing and calling
publish_model_params currently catches all Exceptions which can mask unrelated
bugs; narrow the except to the specific errors you expect from the import/call
(e.g., ImportError, ModuleNotFoundError, AttributeError, or any custom
ModelexpressError thrown by modelexpress.trtllm_live_transfer) and handle them
explicitly, then set model._mx_source_published only on success and log with
logger.warning including the caught exception; if you intentionally want to
swallow all errors, add a brief comment explaining that design choice instead of
using a bare Exception.
- Around line 413-437: In the PRESHARDED branch (LoadFormat.PRESHARDED) remove
the unsupported model keyword from the checkpoint_loader.load_weights(...) call
(it currently passes ckpt_dir, mapping=self.mapping, model=model); change the
call to pass only the checkpoint_dir and mapping so it matches the weight loader
API, keep the surrounding logic that sets Linear._weights_presharded, obtains
self.weight_mapper via checkpoint_loader.get_initialized_weight_mapper(model,
config), and calls self._call_load_weights(model.load_weights, weights,
self.weight_mapper) when weights is truthy.

In `@tensorrt_llm/executor/worker.py`:
- Line 305: The env-var gate currently uses
os.environ.get("MODEL_EXPRESS_SOURCE") which treats any truthy string (including
"0") as enabled; change the condition to explicitly check for the value "1"
(e.g., os.environ.get("MODEL_EXPRESS_SOURCE") == "1") so only an explicit "1"
enables the source-publish path; update the branch that depends on this check
(the code using os.environ.get("MODEL_EXPRESS_SOURCE") in worker.py) to use the
explicit comparison.
- Around line 303-315: The worker-side fallback publish must not run if the
model_loader already attempted a pre-post_load_weights() publish; update the
guard in the worker publish block to check getattr(model,
'_mx_source_publish_attempted', False) and only call publish_from_worker(worker)
when that flag is False (meaning model_loader did not attempt publishing). Also
update model_loader to set model._mx_source_publish_attempted = True before it
tries the pre-post_load_weights() publish so the worker can reliably detect
attempted-but-failed publishes and avoid a late publish that can break
post-processing; reference the symbols model, _mx_source_published,
_mx_source_publish_attempted, publish_from_worker, model_loader, and
post_load_weights().
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6ca2816b-ff79-4ac1-a2e0-50d75421d8d8

📥 Commits

Reviewing files that changed from the base of the PR and between f8d2090 and 1d00ab6.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/modules/linear.py
  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/executor/worker.py
  • tensorrt_llm/llmapi/llm_args.py

Comment on lines +413 to +437
```python
elif load_format == LoadFormat.PRESHARDED:
    # P2P RDMA target: source published weights BEFORE
    # post_load_weights (pre-processed state). The checkpoint
    # loader injects directly into model params via RDMA.
    # If it returns empty dict, weights are already in GPU
    # memory — skip model.load_weights() but DO run
    # post_load_weights() to apply kernel-ready transforms.
    from tensorrt_llm._torch.modules.linear import Linear

    for m in model.modules():
        if isinstance(m, Linear):
            m._weights_presharded = True

    ckpt_dir = model.llm_checkpoint_dir if hasattr(
        model, 'llm_checkpoint_dir') else checkpoint_dir
    weights = checkpoint_loader.load_weights(
        ckpt_dir, mapping=self.mapping, model=model)

    if weights:
        self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
            model, config)
        self._call_load_weights(model.load_weights, weights,
                                self.weight_mapper)
    else:
        logger.info("PRESHARDED: weights injected via P2P RDMA, skipping load_weights()")
```
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical


Remove unsupported model parameter from checkpoint_loader.load_weights() call.

The underlying weight loader implementations (e.g., HfWeightLoader) do not accept a model parameter. While BaseCheckpointLoader.load_weights() accepts **kwargs, it passes them directly to self.weight_loader.load_weights(), which only accepts checkpoint_dir and mapping. Passing model=model will cause a TypeError at runtime.

The PRESHARDED logic correctly sets _weights_presharded on Linear modules before loading, but the model parameter should be removed from the load_weights() call.
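The failure mode described above is easy to reproduce in isolation. Here `inner_load_weights` and `outer_load_weights` are hypothetical stand-ins for `HfWeightLoader.load_weights` and `BaseCheckpointLoader.load_weights`:

```python
def inner_load_weights(checkpoint_dir, mapping=None):
    """Stand-in for a concrete weight loader with a fixed signature."""
    return {"checkpoint_dir": checkpoint_dir, "mapping": mapping}

def outer_load_weights(checkpoint_dir, **kwargs):
    # Forwards **kwargs blindly: an unexpected keyword surfaces as a
    # TypeError inside the inner call, not at the outer call site.
    return inner_load_weights(checkpoint_dir, **kwargs)

outer_load_weights("/ckpt", mapping="tp8")  # fine
try:
    outer_load_weights("/ckpt", mapping="tp8", model=object())
except TypeError as e:
    print("TypeError:", e)
```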


Comment on lines +457 to +466
```python
# ModelExpress source: publish pre-processed weights BEFORE
# post_load_weights so targets receive raw loaded state and can
# run their own post_load_weights() transforms.
if os.environ.get("MODEL_EXPRESS_SOURCE"):
    try:
        from modelexpress.trtllm_live_transfer import publish_model_params
        publish_model_params(model)
        model._mx_source_published = True
    except Exception as e:
        logger.warning("ModelExpress publish failed: %s", e)
```

⚠️ Potential issue | 🟡 Minor

Narrow the exception catch to specific exception types.

The bare Exception catch is too broad and could mask unrelated bugs. As per coding guidelines, catch only the specific exceptions expected here.

🛡️ Proposed fix to use specific exceptions
```diff
         if os.environ.get("MODEL_EXPRESS_SOURCE"):
             try:
                 from modelexpress.trtllm_live_transfer import publish_model_params
                 publish_model_params(model)
                 model._mx_source_published = True
-            except Exception as e:
+            except (ImportError, ModuleNotFoundError) as e:
+                logger.warning("ModelExpress module not available: %s", e)
+            except RuntimeError as e:
                 logger.warning("ModelExpress publish failed: %s", e)
```

If the publish_model_params function can raise other specific exceptions, those should be added to the catch clause. Alternatively, if the broad catch is intentional to ensure robustness against any ModelExpress failures, add a comment explaining this design choice.

🧰 Tools
🪛 Ruff (0.15.9)

[warning] 465-465: Do not catch blind exception: Exception

(BLE001)


Comment on lines +303 to +315
```python
# ModelExpress source: publish this rank's model params via NIXL.
# Skip if already published from ModelLoader.load() (pre-post_load_weights).
if os.environ.get("MODEL_EXPRESS_SOURCE"):
    model = getattr(getattr(getattr(worker, 'engine', None), 'model_engine', None), 'model', None)
    if model and getattr(model, '_mx_source_published', False):
        logger.info("ModelExpress: already published from model_loader, skipping worker publish")
    else:
        try:
            from modelexpress.trtllm_live_transfer import publish_from_worker
            publish_from_worker(worker)
        except Exception as e:
            logger.warning("ModelExpress publish_from_worker failed on rank %d: %s", mpi_rank(), e)
```


⚠️ Potential issue | 🟠 Major

Fallback worker publish can break the pre-post_load_weights() publish contract.

model_loader publishes before post_load_weights() and sets _mx_source_published only on success. If that publish fails, this block can publish later from worker after model transforms, which risks inconsistent target-side post-processing.

Suggested guard (worker-side)
```diff
-    if os.environ.get("MODEL_EXPRESS_SOURCE"):
+    if os.environ.get("MODEL_EXPRESS_SOURCE") == "1":
         model = getattr(getattr(getattr(worker, 'engine', None), 'model_engine', None), 'model', None)
         if model and getattr(model, '_mx_source_published', False):
             logger.info("ModelExpress: already published from model_loader, skipping worker publish")
+        elif model and getattr(model, "_mx_source_publish_attempted", False):
+            raise RuntimeError(
+                "ModelExpress publish during model loading failed; refusing "
+                "late publish after post_load_weights().")
         else:
             try:
                 from modelexpress.trtllm_live_transfer import publish_from_worker
                 publish_from_worker(worker)
```

You’ll also want model_loader to set _mx_source_publish_attempted=True before attempting pre-post-load publish so this guard is reliable.

🧰 Tools
🪛 Ruff (0.15.9)

[warning] 313-313: Do not catch blind exception: Exception

(BLE001)



```python
# ModelExpress source: publish this rank's model params via NIXL.
# Skip if already published from ModelLoader.load() (pre-post_load_weights).
if os.environ.get("MODEL_EXPRESS_SOURCE"):
```

⚠️ Potential issue | 🟠 Major

Gate source mode on explicit value ("1"), not env-var truthiness.

Using os.environ.get("MODEL_EXPRESS_SOURCE") treats "0" as enabled and can trigger unintended source publish behavior.
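A quick demonstration of the pitfall:

```python
import os

os.environ["MODEL_EXPRESS_SOURCE"] = "0"

# Truthiness check: any non-empty string, including "0", enables the path.
enabled_truthy = bool(os.environ.get("MODEL_EXPRESS_SOURCE"))
# Explicit comparison: only the literal "1" enables it.
enabled_strict = os.environ.get("MODEL_EXPRESS_SOURCE") == "1"
print(enabled_truthy, enabled_strict)  # True False
```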

Suggested fix
```diff
-    if os.environ.get("MODEL_EXPRESS_SOURCE"):
+    if os.environ.get("MODEL_EXPRESS_SOURCE") == "1":
```

@svc-trtllm-gh-bot svc-trtllm-gh-bot added the `Community want to contribute` label (PRs initiated from Community) on Apr 9, 2026
KavinKrishnan added a commit to ai-dynamo/modelexpress that referenced this pull request Apr 9, 2026
- Remove --model-express-role args from K8s manifests (source/target
  auto-detected by probing MX server for existing sources)
- Remove MODEL_EXPRESS_SOURCE env var from source DGD manifest
- Update Dockerfile verification to check model_express_url
- Add TODO comments for removing monkey patches once upstream PRs merge:
  - ai-dynamo/dynamo#8037 (--model-express-url native)
  - NVIDIA/TensorRT-LLM#12898 (LoadFormat.PRESHARDED native)

Signed-off-by: Kavin Krishnan <kavinkrishnan@gmail.com>
Made-with: Cursor
@KavinKrishnan KavinKrishnan force-pushed the kavink/presharded-weight-loading branch from 1d00ab6 to dba4208 on April 9, 2026 at 20:03
@KavinKrishnan KavinKrishnan changed the title feat: Add PRESHARDED LoadFormat for zero-disk P2P RDMA weight loading feat: ModelExpres - Add PRESHARDED LoadFormat for zero-disk P2P RDMA weight loading Apr 9, 2026
Add LoadFormat.PRESHARDED for loading model weights that are already
sharded per TP rank, enabling zero-disk P2P RDMA weight transfers
where each MPI worker receives only its own shard directly into GPU
memory via ModelExpress.

Changes:
- llm_args.py: Add PRESHARDED = 3 to LoadFormat enum
- model_loader.py: PRESHARDED branch with _weights_presharded flag,
  publish hook before post_load_weights (auto-detect via MODEL_EXPRESS_URL)
- linear.py: Override tp_size to 1 when _weights_presharded=True
- worker.py: publish_from_worker hook in setup_engine (auto-detect)

Source publishes weights before post_load_weights so targets receive
pre-processed weights and run their own transforms independently.
Auto-detects source role when MODEL_EXPRESS_URL is set and
MODEL_EXPRESS_TARGET is not set.

Validated: Kimi K2.5 (TP=8, MoE, nvfp4) on GCP GB200 at 365-509 Gbps.
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
@KavinKrishnan KavinKrishnan force-pushed the kavink/presharded-weight-loading branch from dba4208 to 3d274d2 on April 9, 2026 at 20:47

Labels

Community want to contribute PRs initiated from Community
