feat: ModelExpress - Add PRESHARDED LoadFormat for zero-disk P2P RDMA weight loading #12898

KavinKrishnan wants to merge 1 commit into NVIDIA:main
Conversation
📝 Walkthrough

Support for presharded weight loading is introduced via a new `LoadFormat.PRESHARDED` option.
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant ML as ModelLoader
    participant CM as Checkpoint<br/>Manager
    participant LM as Linear<br/>Module
    participant MX as ModelExpress
    participant W as Worker
    User->>ML: load(model, LoadFormat.PRESHARDED)
    ML->>CM: load_weights(checkpoint)
    CM-->>ML: weights dict
    ML->>LM: model.load_weights()
    LM->>LM: copy_weight() with tp_size=1<br/>(presharded=True)
    LM-->>ML: weights loaded
    alt MODEL_EXPRESS_SOURCE set
        ML->>MX: publish_model_params(model)
        MX-->>ML: published
        ML->>LM: set _mx_source_published=True
    end
    ML->>LM: post_load_weights()
    LM-->>ML: done
    ML-->>User: model ready
    Note over W: Worker initialization
    alt MODEL_EXPRESS_SOURCE set<br/>and not _mx_source_published
        W->>MX: publish_from_worker(worker)
        MX-->>W: published
    else already published
        W->>W: skip republishing
    end
```
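The ordering in the diagram (publish after loading but before post-load transforms) can be condensed into a minimal runnable sketch; the function names below are illustrative stand-ins, not the real TRT-LLM entry points:

```python
# Minimal sketch of the PRESHARDED load ordering shown in the diagram.
# All names here are illustrative stand-ins, not the real TRT-LLM API.
import os

calls = []

def load_weights():
    calls.append("load_weights")

def publish_model_params():
    calls.append("publish")

def post_load_weights():
    calls.append("post_load_weights")

def load_presharded():
    load_weights()
    # Publish BEFORE post_load_weights so targets receive the raw
    # loaded state and can run their own transforms independently.
    if os.environ.get("MODEL_EXPRESS_SOURCE") == "1":
        publish_model_params()
    post_load_weights()

os.environ["MODEL_EXPRESS_SOURCE"] = "1"
load_presharded()
print(calls)  # ['load_weights', 'publish', 'post_load_weights']
```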
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 2 passed
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 457-466: The try/except around importing and calling
publish_model_params currently catches all Exceptions which can mask unrelated
bugs; narrow the except to the specific errors you expect from the import/call
(e.g., ImportError, ModuleNotFoundError, AttributeError, or any custom
ModelexpressError thrown by modelexpress.trtllm_live_transfer) and handle them
explicitly, then set model._mx_source_published only on success and log with
logger.warning including the caught exception; if you intentionally want to
swallow all errors, add a brief comment explaining that design choice instead of
using a bare Exception.
- Around line 413-437: In the PRESHARDED branch (LoadFormat.PRESHARDED) remove
the unsupported model keyword from the checkpoint_loader.load_weights(...) call
(it currently passes ckpt_dir, mapping=self.mapping, model=model); change the
call to pass only the checkpoint_dir and mapping so it matches the weight loader
API, keep the surrounding logic that sets Linear._weights_presharded, obtains
self.weight_mapper via checkpoint_loader.get_initialized_weight_mapper(model,
config), and calls self._call_load_weights(model.load_weights, weights,
self.weight_mapper) when weights is truthy.
In `@tensorrt_llm/executor/worker.py`:
- Line 305: The env-var gate currently uses
os.environ.get("MODEL_EXPRESS_SOURCE") which treats any truthy string (including
"0") as enabled; change the condition to explicitly check for the value "1"
(e.g., os.environ.get("MODEL_EXPRESS_SOURCE") == "1") so only an explicit "1"
enables the source-publish path; update the branch that depends on this check
(the code using os.environ.get("MODEL_EXPRESS_SOURCE") in worker.py) to use the
explicit comparison.
- Around line 303-315: The worker-side fallback publish must not run if the
model_loader already attempted a pre-post_load_weights() publish; update the
guard in the worker publish block to check getattr(model,
'_mx_source_publish_attempted', False) and only call publish_from_worker(worker)
when that flag is False (meaning model_loader did not attempt publishing). Also
update model_loader to set model._mx_source_publish_attempted = True before it
tries the pre-post_load_weights() publish so the worker can reliably detect
attempted-but-failed publishes and avoid a late publish that can break
post-processing; reference the symbols model, _mx_source_published,
_mx_source_publish_attempted, publish_from_worker, model_loader, and
post_load_weights().
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6ca2816b-ff79-4ac1-a2e0-50d75421d8d8
📒 Files selected for processing (4)
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/pyexecutor/model_loader.py
- tensorrt_llm/executor/worker.py
- tensorrt_llm/llmapi/llm_args.py
```python
elif load_format == LoadFormat.PRESHARDED:
    # P2P RDMA target: source published weights BEFORE
    # post_load_weights (pre-processed state). The checkpoint
    # loader injects directly into model params via RDMA.
    # If it returns empty dict, weights are already in GPU
    # memory — skip model.load_weights() but DO run
    # post_load_weights() to apply kernel-ready transforms.
    from tensorrt_llm._torch.modules.linear import Linear

    for m in model.modules():
        if isinstance(m, Linear):
            m._weights_presharded = True

    ckpt_dir = model.llm_checkpoint_dir if hasattr(
        model, 'llm_checkpoint_dir') else checkpoint_dir
    weights = checkpoint_loader.load_weights(
        ckpt_dir, mapping=self.mapping, model=model)

    if weights:
        self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
            model, config)
        self._call_load_weights(model.load_weights, weights,
                                self.weight_mapper)
    else:
        logger.info("PRESHARDED: weights injected via P2P RDMA, skipping load_weights()")
```
🧩 Analysis chain
🏁 Scripts executed:

```bash
# Verify checkpoint_loader.load_weights() signature supports a model parameter
ast-grep --pattern $'def load_weights($$$, model=$_, $$$)'

# Find where checkpoint_loader is used in model_loader.py
rg "checkpoint_loader" tensorrt_llm/_torch/pyexecutor/model_loader.py -B 2 -A 2

# Search for checkpoint_loader class definitions and load_weights methods
rg "def load_weights" tensorrt_llm --type py -B 2 -A 5

# Find where checkpoint_loader is instantiated/initialized
rg "checkpoint_loader\s*=" tensorrt_llm/_torch/pyexecutor/model_loader.py -B 3 -A 1

# Find BaseCheckpointLoader and its load_weights method
find tensorrt_llm -name "*base_checkpoint_loader*" -type f
rg "class BaseCheckpointLoader" tensorrt_llm -A 100 | head -150
rg "def load_weights" tensorrt_llm -g "*checkpoint_loader*" -B 2 -A 8

# Find BaseWeightLoader and the weight loader implementations
rg "class BaseWeightLoader" tensorrt_llm -A 50 | head -100
rg "def load_weights" tensorrt_llm -g "*weight_loader*" -B 1 -A 5 | head -100
cat tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py | grep -A 10 "def load_weights"

# Check whether any weight loader accepts **kwargs or a model parameter
rg "def load_weights\(" tensorrt_llm/_torch/models/checkpoints -B 1 -A 8 | grep -A 8 "def load_weights"
rg "def load_weights" tensorrt_llm/_torch/models/checkpoints -A 3 | head -80
cat tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py | sed -n '50,80p'

# Search for RDMA/P2P/PRESHARDED-specific weight loaders
rg "RDMA|P2P|PRESHARDED" tensorrt_llm/_torch/models/checkpoints -l
find tensorrt_llm -name "*weight_loader*" -o -name "*checkpoint_loader*" | xargs grep -l "presharded\|rdma\|P2P" -i
rg "def load_weights.*model" tensorrt_llm -B 2 -A 5
rg "class.*WeightLoader" tensorrt_llm/_torch/models/checkpoints -A 20 | grep -A 15 "def load_weights"
cat tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py | head -100

# Re-check the actual call site in model_loader.py
cat tensorrt_llm/_torch/pyexecutor/model_loader.py | sed -n '413,437p'
find tensorrt_llm/_torch/models/checkpoints -name "*.py" -exec grep -l "class.*WeightLoader" {} \; | xargs -I {} basename {}
find tensorrt_llm/_torch/models/checkpoints -name "weight_loader.py" -o -name "*checkpoint_loader.py" | xargs grep -A 3 "def load_weights"
```
Remove unsupported model parameter from checkpoint_loader.load_weights() call.
The underlying weight loader implementations (e.g., HfWeightLoader) do not accept a model parameter. While BaseCheckpointLoader.load_weights() accepts **kwargs, it passes them directly to self.weight_loader.load_weights(), which only accepts checkpoint_dir and mapping. Passing model=model will cause a TypeError at runtime.
The PRESHARDED logic correctly sets _weights_presharded on Linear modules before loading, but the model parameter should be removed from the load_weights() call.
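The failure mode described here can be demonstrated with toy stand-ins that mirror the forwarding pattern: an outer loader passing `**kwargs` straight through to an inner loader that accepts only `checkpoint_dir` and `mapping`. These are NOT the real `BaseCheckpointLoader`/`HfWeightLoader` classes, just an illustration of why the extra keyword raises `TypeError`:

```python
# Toy stand-ins mirroring the **kwargs forwarding described above.
# NOT the real BaseCheckpointLoader / HfWeightLoader classes.
class ToyWeightLoader:
    def load_weights(self, checkpoint_dir, mapping=None):
        # Accepts only checkpoint_dir and mapping, like the HF loader.
        return {"w": 0}

class ToyCheckpointLoader:
    def __init__(self):
        self.weight_loader = ToyWeightLoader()

    def load_weights(self, checkpoint_dir, **kwargs):
        # kwargs are forwarded verbatim, so an unsupported key such as
        # model=... raises TypeError inside the inner call.
        return self.weight_loader.load_weights(checkpoint_dir, **kwargs)

loader = ToyCheckpointLoader()
print(loader.load_weights("/ckpt", mapping=None))  # {'w': 0}
try:
    loader.load_weights("/ckpt", mapping=None, model=object())
except TypeError as e:
    print("TypeError:", e)
```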
```python
# ModelExpress source: publish pre-processed weights BEFORE
# post_load_weights so targets receive raw loaded state and can
# run their own post_load_weights() transforms.
if os.environ.get("MODEL_EXPRESS_SOURCE"):
    try:
        from modelexpress.trtllm_live_transfer import publish_model_params
        publish_model_params(model)
        model._mx_source_published = True
    except Exception as e:
        logger.warning("ModelExpress publish failed: %s", e)
```
Narrow the exception catch to specific exception types.
The bare Exception catch is too broad and could mask unrelated bugs. As per coding guidelines, catch only the specific exceptions expected here.
🛡️ Proposed fix to use specific exceptions

```diff
 if os.environ.get("MODEL_EXPRESS_SOURCE"):
     try:
         from modelexpress.trtllm_live_transfer import publish_model_params
         publish_model_params(model)
         model._mx_source_published = True
-    except Exception as e:
+    except (ImportError, ModuleNotFoundError) as e:
+        logger.warning("ModelExpress module not available: %s", e)
+    except RuntimeError as e:
         logger.warning("ModelExpress publish failed: %s", e)
```

If `publish_model_params` can raise other specific exceptions, those should be added to the catch clause. Alternatively, if the broad catch is intentional to ensure robustness against any ModelExpress failure, add a comment explaining that design choice.
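A minimal runnable sketch of the narrowed handler, with the `modelexpress` call simulated by a plain function (all names here are illustrative, not the real module):

```python
# Sketch of the narrowed except the review asks for. The publish call
# is simulated; FakeModel and these names are illustrative only.
import logging

logger = logging.getLogger("mx")

class FakeModel:
    pass

def publish_model_params(model, fail_with=None):
    # Stand-in for modelexpress.trtllm_live_transfer.publish_model_params.
    if fail_with is not None:
        raise fail_with

def try_publish(model, fail_with=None):
    try:
        publish_model_params(model, fail_with)
        # The flag is set only after a successful publish.
        model._mx_source_published = True
    except (ImportError, AttributeError, RuntimeError) as e:
        # Narrow catch: expected failure modes only; anything else
        # (e.g. a typo-induced NameError) propagates instead of being masked.
        logger.warning("ModelExpress publish failed: %s", e)

m_ok = FakeModel()
try_publish(m_ok)
assert m_ok._mx_source_published

m_fail = FakeModel()
try_publish(m_fail, fail_with=ImportError("no modelexpress"))
assert not getattr(m_fail, "_mx_source_published", False)
```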
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 465-465: Do not catch blind exception: Exception
(BLE001)
```python
# ModelExpress source: publish this rank's model params via NIXL.
# Skip if already published from ModelLoader.load() (pre-post_load_weights).
if os.environ.get("MODEL_EXPRESS_SOURCE"):
    model = getattr(getattr(getattr(worker, 'engine', None), 'model_engine', None), 'model', None)
    if model and getattr(model, '_mx_source_published', False):
        logger.info("ModelExpress: already published from model_loader, skipping worker publish")
    else:
        try:
            from modelexpress.trtllm_live_transfer import publish_from_worker
            publish_from_worker(worker)
        except Exception as e:
            logger.warning("ModelExpress publish_from_worker failed on rank %d: %s", mpi_rank(), e)
```
Fallback worker publish can break the pre-post_load_weights() publish contract.
model_loader publishes before post_load_weights() and sets _mx_source_published only on success. If that publish fails, this block can publish later from worker after model transforms, which risks inconsistent target-side post-processing.
Suggested guard (worker-side)

```diff
-if os.environ.get("MODEL_EXPRESS_SOURCE"):
+if os.environ.get("MODEL_EXPRESS_SOURCE") == "1":
     model = getattr(getattr(getattr(worker, 'engine', None), 'model_engine', None), 'model', None)
     if model and getattr(model, '_mx_source_published', False):
         logger.info("ModelExpress: already published from model_loader, skipping worker publish")
+    elif model and getattr(model, "_mx_source_publish_attempted", False):
+        raise RuntimeError(
+            "ModelExpress publish during model loading failed; refusing "
+            "late publish after post_load_weights().")
     else:
         try:
             from modelexpress.trtllm_live_transfer import publish_from_worker
             publish_from_worker(worker)
```

You'll also want model_loader to set `_mx_source_publish_attempted = True` before attempting the pre-post-load publish so this guard is reliable.
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 313-313: Do not catch blind exception: Exception
(BLE001)
tensorrt_llm/executor/worker.py (Outdated)

```python
# ModelExpress source: publish this rank's model params via NIXL.
# Skip if already published from ModelLoader.load() (pre-post_load_weights).
if os.environ.get("MODEL_EXPRESS_SOURCE"):
```
Gate source mode on explicit value ("1"), not env-var truthiness.
Using os.environ.get("MODEL_EXPRESS_SOURCE") treats "0" as enabled and can trigger unintended source publish behavior.
Suggested fix

```diff
-if os.environ.get("MODEL_EXPRESS_SOURCE"):
+if os.environ.get("MODEL_EXPRESS_SOURCE") == "1":
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
if os.environ.get("MODEL_EXPRESS_SOURCE") == "1":
```
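The truthiness pitfall flagged here can be demonstrated directly; any non-empty string, including `"0"`, is truthy in Python:

```python
# Why plain truthiness is the wrong gate: any non-empty string,
# including "0", is truthy in Python.
import os

os.environ["MODEL_EXPRESS_SOURCE"] = "0"

truthy_gate = bool(os.environ.get("MODEL_EXPRESS_SOURCE"))
explicit_gate = os.environ.get("MODEL_EXPRESS_SOURCE") == "1"

print(truthy_gate)    # True: "0" would still enable the publish path
print(explicit_gate)  # False: only an explicit "1" enables it
```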
- Remove `--model-express-role` args from K8s manifests (source/target auto-detected by probing the MX server for existing sources)
- Remove `MODEL_EXPRESS_SOURCE` env var from the source DGD manifest
- Update Dockerfile verification to check `model_express_url`
- Add TODO comments for removing monkey patches once upstream PRs merge:
  - ai-dynamo/dynamo#8037 (`--model-express-url` native)
  - NVIDIA/TensorRT-LLM#12898 (`LoadFormat.PRESHARDED` native)

Signed-off-by: Kavin Krishnan <kavinkrishnan@gmail.com>
Made-with: Cursor
Force-pushed 1d00ab6 to dba4208
Add LoadFormat.PRESHARDED for loading model weights that are already sharded per TP rank, enabling zero-disk P2P RDMA weight transfers where each MPI worker receives only its own shard directly into GPU memory via ModelExpress.

Changes:
- llm_args.py: Add PRESHARDED = 3 to LoadFormat enum
- model_loader.py: PRESHARDED branch with _weights_presharded flag, publish hook before post_load_weights (auto-detect via MODEL_EXPRESS_URL)
- linear.py: Override tp_size to 1 when _weights_presharded=True
- worker.py: publish_from_worker hook in setup_engine (auto-detect)

Source publishes weights before post_load_weights so targets receive pre-processed weights and run their own transforms independently. Auto-detects source role when MODEL_EXPRESS_URL is set and MODEL_EXPRESS_TARGET is not set.

Validated: Kimi K2.5 (TP=8, MoE, nvfp4) on GCP GB200 at 365-509 Gbps.

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Force-pushed dba4208 to 3d274d2
Summary
Add `LoadFormat.PRESHARDED` for loading model weights that are already sharded per TP rank, enabling zero-disk P2P RDMA weight transfers where each MPI worker receives only its own shard directly into GPU memory. This is the TRT-LLM integration point for ModelExpress P2P weight transfer.

How it works
When `MODEL_EXPRESS_URL` is set:
- Source workers: `publish_from_worker()` publishes GPU tensors to the MX server before `post_load_weights()`
- Targets: `MxLiveCheckpointLoader` receives weights via RDMA directly into model parameters, then runs `post_load_weights()` locally

Auto-detection: workers publish when `MODEL_EXPRESS_URL` is set and `MODEL_EXPRESS_TARGET` is not set (target mode is configured by the Dynamo engine).

Changes
- `tensorrt_llm/llmapi/llm_args.py` (+5 lines): add `PRESHARDED = 3` to the `LoadFormat` enum
- `tensorrt_llm/_torch/pyexecutor/model_loader.py` (+37 lines): `PRESHARDED` branch sets the `_weights_presharded` flag and delegates to the external `checkpoint_loader.load_weights()`; publish hook before `post_load_weights()`, auto-detected via `MODEL_EXPRESS_URL`
- `tensorrt_llm/_torch/modules/linear.py` (+40/-14 lines): override `tp_size` to 1 when `_weights_presharded=True` (skip TP slicing)
- `tensorrt_llm/executor/worker.py` (+13 lines): `publish_from_worker()` hook in `setup_engine()`, auto-detected via `MODEL_EXPRESS_URL`

Why publish before post_load_weights?

`post_load_weights()` runs FP8 quantization scale recalculation and MoE load balancer init. Publishing before these transforms means each target runs its own `post_load_weights()` independently, producing correct results.

Backward compatibility
Fully backward compatible: `LoadFormat.AUTO` and existing paths are unchanged; `PRESHARDED` only activates when explicitly set.

Validated
Kimi K2.5 (TP=8, MoE, nvfp4) on GCP GB200: 90.75 GB/rank at 365-509 Gbps, end-to-end disaggregated inference verified.
Dependencies
- `MxLiveCheckpointLoader`, `publish_from_worker`, `publish_model_params`
- `--model-express-url` integration
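The `linear.py` behavior described above (treating presharded weights as `tp_size=1` so TP slicing is skipped) can be sketched with a toy `copy_weight`; the function name, signature, and slicing scheme are hypothetical, not the real module:

```python
# Toy illustration of skipping TP slicing for presharded checkpoints.
# Name, signature, and slicing scheme are hypothetical, not linear.py.
def copy_weight(full_weight, tp_size, tp_rank, presharded=False):
    if presharded:
        # Checkpoint is already sharded per rank: behave as tp_size=1
        # and copy the whole tensor instead of slicing out a TP shard.
        tp_size, tp_rank = 1, 0
    shard = len(full_weight) // tp_size
    return full_weight[tp_rank * shard:(tp_rank + 1) * shard]

w = list(range(8))
print(copy_weight(w, tp_size=4, tp_rank=1))                   # [2, 3]
print(copy_weight(w, tp_size=4, tp_rank=1, presharded=True))  # [0, 1, 2, 3, 4, 5, 6, 7]
```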