Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/models/checkpoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .mistral.config_loader import MistralConfigLoader
from .mistral.weight_mapper import (MistralLarge3WeightMapper,
MistralWeightMapper)
from .mx.checkpoint_loader import MxCheckpointLoader

__all__ = [
"HfConfigLoader", "HfWeightLoader", "HfWeightMapper", "MistralConfigLoader",
Expand Down
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/mx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from .checkpoint_loader import MxCheckpointLoader

__all__ = ["MxCheckpointLoader"]
79 changes: 79 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
ModelExpress checkpoint loader for P2P RDMA weight transfer.

When checkpoint_format="MX", this loader auto-detects source vs target:
- If an MX source exists: receive weights via GPU-to-GPU RDMA (target mode)
- If no source exists: fall back to disk load, then publish as source

The config is always loaded from the local HF checkpoint (on PVC/disk).
"""

from typing import Any, Optional

from torch import nn

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper
from tensorrt_llm._torch.models.checkpoints.mx.config_loader import MxConfigLoader
from tensorrt_llm._torch.models.checkpoints.mx.weight_loader import MxWeightLoader
from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_loader
from tensorrt_llm.mapping import Mapping


@register_checkpoint_loader("MX")
class MxCheckpointLoader(BaseCheckpointLoader):
    """Checkpoint loader for the ModelExpress ("MX") P2P RDMA format.

    Weights are obtained by :class:`MxWeightLoader` (RDMA transfer when an
    MX source is available, disk fallback otherwise), while the model config
    is always read from the local HF checkpoint via :class:`MxConfigLoader`.
    """

    def __init__(
        self,
        *,
        mx_server_url: str | None = None,
        weight_loader: Optional[BaseWeightLoader] = None,
        weight_mapper: Optional[BaseWeightMapper] = None,
        config_loader: Optional[BaseConfigLoader] = None,
    ):
        """Build the loader, defaulting any component not supplied.

        Args:
            mx_server_url: Address of the MX server; None defers to the
                environment variables consulted by MxWeightLoader.
            weight_loader/weight_mapper/config_loader: Optional overrides
                for the default MX components.
        """
        # Keep the URL so get_default_weight_loader() builds a loader that
        # targets the same server as this instance (previously the URL was
        # dropped and the default loader always probed the fallback address).
        self._mx_server_url = mx_server_url
        self._weight_loader = weight_loader or MxWeightLoader(
            mx_server_url=mx_server_url)
        self._config_loader = config_loader or self.get_default_config_loader()
        self._weight_mapper = weight_mapper
        self._checkpoint_format = "MX"

    def cleanup(self) -> None:
        """Release each component's resources and drop the references."""
        if self._weight_mapper is not None:
            self._weight_mapper.cleanup()
            self._weight_mapper = None
        if self._weight_loader is not None:
            self._weight_loader.cleanup()
            self._weight_loader = None
        if self._config_loader is not None:
            self._config_loader.cleanup()
            self._config_loader = None

    def get_default_weight_loader(self) -> MxWeightLoader:
        """Return a fresh MxWeightLoader bound to this loader's server URL."""
        return MxWeightLoader(
            mx_server_url=getattr(self, '_mx_server_url', None))

    def get_default_config_loader(self) -> MxConfigLoader:
        """Return a fresh MxConfigLoader (HF config read from disk)."""
        return MxConfigLoader()

    @property
    def weight_loader(self) -> BaseWeightLoader:
        return self._weight_loader

    @property
    def weight_mapper(self) -> Optional[BaseWeightMapper]:
        return self._weight_mapper

    @weight_mapper.setter
    def weight_mapper(self, value: BaseWeightMapper) -> None:
        self._weight_mapper = value

    @property
    def config_loader(self) -> Optional[BaseConfigLoader]:
        return self._config_loader

    @property
    def checkpoint_format(self) -> str:
        return self._checkpoint_format
19 changes: 19 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/mx/config_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Config loader for MX checkpoint format — delegates to HF config loader."""

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.hf.config_loader import HfConfigLoader


class MxConfigLoader(BaseConfigLoader):
    """Config loader for the MX checkpoint format.

    MX checkpoints keep their config in standard HF layout on disk, so
    every call is simply forwarded to an internal :class:`HfConfigLoader`.
    """

    def __init__(self):
        self._delegate = HfConfigLoader()

    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
        """Load the model config from *checkpoint_dir* (HF layout)."""
        return self._delegate.load(checkpoint_dir, **kwargs)

    def cleanup(self) -> None:
        """Forward cleanup to the underlying HF config loader."""
        self._delegate.cleanup()
81 changes: 81 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/mx/weight_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""ModelExpress P2P weight loader — receives weights via RDMA from an MX source."""

import logging
import os
from typing import Any, Dict, Union

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import (
BaseWeightLoader,
ConsumableWeightsDict,
)
from tensorrt_llm.mapping import Mapping

logger = logging.getLogger(__name__)


class MxWeightLoader(BaseWeightLoader):
    """Weight loader for the ModelExpress ("MX") P2P transfer path.

    Probes the MX server for a published weight source:
    - source found: receive weights over GPU-to-GPU RDMA (target mode);
    - no source: load from disk via HfWeightLoader and export
      ``MODEL_EXPRESS_URL`` so this process can later publish as a source.
    """

    def __init__(self, mx_server_url: str | None = None):
        # An explicit URL wins; otherwise environment variables are
        # consulted at load time (see load_weights()).
        self._mx_server_url = mx_server_url
        self._received_via_rdma = False

    @property
    def received_via_rdma(self) -> bool:
        """True iff the last load_weights() received weights via RDMA."""
        return self._received_via_rdma

    def load_weights(
        self,
        checkpoint_dir: str,
        mapping: Mapping,
        **kwargs,
    ) -> Union[Dict[str, Any], ConsumableWeightsDict]:
        """Return model weights, preferring RDMA transfer from an MX source.

        Args:
            checkpoint_dir: Local HF checkpoint directory (disk fallback).
            mapping: Parallelism mapping for this rank.
            **kwargs: May carry ``model`` so the live loader can write
                directly into the parameter buffers.

        Raises:
            ImportError: If the ``modelexpress`` package is not installed.
        """
        mx_server = (
            self._mx_server_url
            or os.environ.get("MODEL_EXPRESS_URL")
            or os.environ.get("MX_SERVER_ADDRESS", "localhost:8001")
        )

        try:
            from modelexpress.client import MxClient
        except ImportError as err:
            raise ImportError(
                "checkpoint_format='MX' requires the 'modelexpress' package. "
                "Install with: pip install modelexpress"
            ) from err

        # Probe the MX server for a published source. Any failure (server
        # down, connection refused, bad URL, ...) means "no source" and we
        # fall back to disk — but log it so misconfigurations are diagnosable
        # instead of being silently swallowed.
        # TODO(review): MX_SOURCE_PROBE_TIMEOUT was previously read here but
        # never passed anywhere; wire a timeout into this probe if the
        # MxClient API supports one.
        has_sources = False
        try:
            mx_client = MxClient(server_url=mx_server)
            try:
                has_sources = len(mx_client.list_sources().instances) > 0
            finally:
                mx_client.close()
        except Exception as exc:
            logger.debug("MX source probe against %s failed: %s", mx_server,
                         exc)

        if has_sources:
            logger.info("MX source found — loading weights via RDMA")
            from modelexpress.trtllm_live_transfer import MxLiveWeightLoader
            live_loader = MxLiveWeightLoader(mx_server=mx_server)
            result = live_loader.load_weights(
                checkpoint_dir, mapping=mapping, model=kwargs.get("model"))
            self._received_via_rdma = True
            # Mark this process as a target so it does not re-publish itself.
            os.environ["MODEL_EXPRESS_TARGET"] = "1"
            return result

        logger.info(
            "No MX source found — loading from disk, will publish as source"
        )
        # Export the resolved URL so the later publish step targets the
        # same server that was probed here.
        os.environ["MODEL_EXPRESS_URL"] = mx_server
        from tensorrt_llm._torch.models.checkpoints.hf.weight_loader import (
            HfWeightLoader,
        )
        return HfWeightLoader().load_weights(checkpoint_dir, mapping=mapping)

    def cleanup(self) -> None:
        """No persistent resources are held; nothing to release."""
        pass
40 changes: 26 additions & 14 deletions tensorrt_llm/_torch/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,17 @@ def maybe_convert_to_torch_tensor(


def copy_weight(dst: Parameter, src: torch.Tensor):
# TODO check that is it a reasonable change or not
if dst.dtype != src.dtype:
src = src.to(dst.dtype)
assert dst.dtype == src.dtype, f"Incompatible dtype. dst: {dst.dtype}, src: {src.dtype}"
dst.data.copy_(src)
# Zero-copy pointer swap when source is already on the correct device with matching shape
if (src.data_ptr() == dst.data_ptr()):
return # Already in place (e.g., NIXL wrote directly into param buffer)
if (src.device == dst.device and src.shape == dst.shape and src.is_contiguous()
and dst.is_contiguous()):
dst.data = src
else:
dst.data.copy_(src)


def copy_weight_shard(dst: Parameter, src: torch.Tensor, shard_offset: int,
Expand All @@ -183,8 +189,10 @@ def load_weights_vanilla_helper(module: Linear,
if module.bias is not None:
assert "bias" in weights[0]
device = torch.device('cuda')
# Skip TP slicing for pre-sharded weights (e.g., from P2P RDMA)
tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size

weight = load_weight_shard(weights[0]['weight'], module.tp_size,
weight = load_weight_shard(weights[0]['weight'], tp_size,
module.tp_rank, module.tp_mode,
device) if "weight" in weights[0] else None

Expand All @@ -201,7 +209,7 @@ def load_weights_vanilla_helper(module: Linear,
copy_weight(module.weight, weight_transform(weight))

if module.bias is not None:
bias = load_weight_shard(weights[0]['bias'], module.tp_size,
bias = load_weight_shard(weights[0]['bias'], tp_size,
module.tp_rank, module.tp_mode,
device) if "bias" in weights[0] else None
if bias is not None:
Expand All @@ -224,25 +232,27 @@ def load_weights_fused_qkv_helper(
module, "fused_weight_shard_indices_mapping", None
) is not None, "Fused weight shard indices mapping is required in partial loading"
device = torch.device('cuda')
# Skip TP slicing for pre-sharded weights (e.g., from P2P RDMA)
tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size

q_weight = load_weight_shard(weights[0]['weight'], module.tp_size,
q_weight = load_weight_shard(weights[0]['weight'], tp_size,
module.tp_rank, module.tp_mode,
device) if "weight" in weights[0] else None
k_weight = load_weight_shard(weights[1]['weight'], module.tp_size,
k_weight = load_weight_shard(weights[1]['weight'], tp_size,
module.tp_rank, module.tp_mode,
device) if "weight" in weights[1] else None
v_weight = load_weight_shard(weights[2]['weight'], module.tp_size,
v_weight = load_weight_shard(weights[2]['weight'], tp_size,
module.tp_rank, module.tp_mode,
device) if "weight" in weights[2] else None

if module.bias is not None:
q_bias = load_weight_shard(weights[0]['bias'], module.tp_size,
q_bias = load_weight_shard(weights[0]['bias'], tp_size,
module.tp_rank, module.tp_mode,
device) if "bias" in weights[0] else None
k_bias = load_weight_shard(weights[1]['bias'], module.tp_size,
k_bias = load_weight_shard(weights[1]['bias'], tp_size,
module.tp_rank, module.tp_mode,
device) if "bias" in weights[1] else None
v_bias = load_weight_shard(weights[2]['bias'], module.tp_size,
v_bias = load_weight_shard(weights[2]['bias'], tp_size,
module.tp_rank, module.tp_mode,
device) if "bias" in weights[2] else None
if not allow_partial_loading:
Expand Down Expand Up @@ -277,18 +287,20 @@ def load_weights_fused_gate_up_helper(
module, "fused_weight_shard_indices_mapping", None
) is not None, "Fused weight shard indices mapping is required in partial loading"
device = torch.device('cuda')
# Skip TP slicing for pre-sharded weights (e.g., from P2P RDMA)
tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size

gate_weight = load_weight_shard(weights[0]['weight'], module.tp_size,
gate_weight = load_weight_shard(weights[0]['weight'], tp_size,
module.tp_rank, module.tp_mode,
device) if "weight" in weights[0] else None
up_weight = load_weight_shard(weights[1]['weight'], module.tp_size,
up_weight = load_weight_shard(weights[1]['weight'], tp_size,
module.tp_rank, module.tp_mode,
device) if "weight" in weights[1] else None
if module.bias is not None:
gate_bias = load_weight_shard(weights[0]['bias'], module.tp_size,
gate_bias = load_weight_shard(weights[0]['bias'], tp_size,
module.tp_rank, module.tp_mode,
device) if "bias" in weights[0] else None
up_bias = load_weight_shard(weights[1]['bias'], module.tp_size,
up_bias = load_weight_shard(weights[1]['bias'], tp_size,
module.tp_rank, module.tp_mode,
device) if "bias" in weights[1] else None
if not allow_partial_loading:
Expand Down
37 changes: 37 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,32 @@ def init_meta_tensor(t: torch.Tensor):
self._call_load_weights(model.load_draft_weights, weights,
draft_weight_mapper)

elif load_format == LoadFormat.PRESHARDED:
# P2P RDMA target: source published weights BEFORE
# post_load_weights (pre-processed state). The checkpoint
# loader injects directly into model params via RDMA.
# If it returns empty dict, weights are already in GPU
# memory — skip model.load_weights() but DO run
# post_load_weights() to apply kernel-ready transforms.
from tensorrt_llm._torch.modules.linear import Linear

for m in model.modules():
if isinstance(m, Linear):
m._weights_presharded = True

ckpt_dir = model.llm_checkpoint_dir if hasattr(
model, 'llm_checkpoint_dir') else checkpoint_dir
weights = checkpoint_loader.load_weights(
ckpt_dir, mapping=self.mapping, model=model)

if weights:
self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
model, config)
self._call_load_weights(model.load_weights, weights,
self.weight_mapper)
else:
logger.info("PRESHARDED: weights injected via P2P RDMA, skipping load_weights()")

elif load_format == LoadFormat.DUMMY:
self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
model, config)
Expand All @@ -428,6 +454,17 @@ def init_meta_tensor(t: torch.Tensor):
raise NotImplementedError(
f"No load support for load format: {load_format}")

# ModelExpress source: publish pre-processed weights BEFORE
# post_load_weights so targets receive raw loaded state and can
# run their own post_load_weights() transforms.
if os.environ.get("MODEL_EXPRESS_URL") and not os.environ.get("MODEL_EXPRESS_TARGET"):
try:
from modelexpress.trtllm_live_transfer import publish_model_params
publish_model_params(model)
model._mx_source_published = True
except Exception as e:
logger.warning("ModelExpress publish failed: %s", e)

for module in model.modules():
if hasattr(module, 'post_load_weights') and not getattr(
module, '_weights_removed', False):
Expand Down
27 changes: 27 additions & 0 deletions tensorrt_llm/executor/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,33 @@ def notify_proxy_threads_to_quit():
logger.error("Failed to deliver error message to proxy")
return

# ModelExpress source: publish this rank's model params via NIXL.
# Skip if already published from ModelLoader.load() (pre-post_load_weights).
if os.environ.get("MODEL_EXPRESS_URL") and not os.environ.get("MODEL_EXPRESS_TARGET"):
try:
model = worker.engine.model_engine.model
except AttributeError:
model = None
if model and getattr(model, '_mx_source_published', False):
logger.info("ModelExpress: already published from model_loader, skipping worker publish")
else:
try:
from modelexpress.trtllm_live_transfer import publish_from_worker
publish_from_worker(worker)
except Exception as e:
logger.warning("ModelExpress publish_from_worker failed on rank %d: %s", mpi_rank(), e)

# ModelExpress source publish: if MX URL is configured and this worker
# was not set up as a target (RDMA receiver), publish model weights.
if os.environ.get("MODEL_EXPRESS_URL") and not os.environ.get("MODEL_EXPRESS_TARGET"):
try:
from modelexpress.trtllm_live_transfer import publish_from_worker
publish_from_worker(worker)
except ImportError:
pass
except Exception as e:
logger.warning("ModelExpress publish_from_worker failed: %s", e)

# Optionally disable GC (default: not disabled)
if os.getenv("TRTLLM_WORKER_DISABLE_GC", "0") == "1":
gc.disable()
Expand Down
5 changes: 5 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3413,6 +3413,11 @@ class LoadFormat(Enum):
DUMMY = 1
# Only load the multimodal(vision) encoder weights
VISION_ONLY = 2
# Weights are already sharded per TP rank — skip TP slicing during loading.
# The weight mapper still handles name mapping and fusing (q+k+v → qkv),
# but load_weight_shard() returns weights as-is without TP slicing.
# Use case: P2P RDMA transfers where each worker receives its own shard.
PRESHARDED = 3


class SamplerType(StrEnum):
Expand Down