From 139c4b1f1b11b7bc2b03cc4c6ba7bbcf48198150 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 9 Apr 2026 13:03:14 -0700 Subject: [PATCH 1/2] feat: Add PRESHARDED LoadFormat for zero-disk P2P RDMA weight loading Add LoadFormat.PRESHARDED for loading model weights that are already sharded per TP rank, enabling zero-disk P2P RDMA weight transfers where each MPI worker receives only its own shard directly into GPU memory via ModelExpress. Changes: - llm_args.py: Add PRESHARDED = 3 to LoadFormat enum - model_loader.py: PRESHARDED branch with _weights_presharded flag, publish hook before post_load_weights (auto-detect via MODEL_EXPRESS_URL) - linear.py: Override tp_size to 1 when _weights_presharded=True - worker.py: publish_from_worker hook in setup_engine (auto-detect) Source publishes weights before post_load_weights so targets receive pre-processed weights and run their own transforms independently. Auto-detects source role when MODEL_EXPRESS_URL is set and MODEL_EXPRESS_TARGET is not set. Validated: Kimi K2.5 (TP=8, MoE, nvfp4) on GCP GB200 at 365-509 Gbps. Signed-off-by: Kavin Krishnan Made-with: Cursor Signed-off-by: Kavin Krishnan Made-with: Cursor --- tensorrt_llm/_torch/modules/linear.py | 40 ++++++++++++------- .../_torch/pyexecutor/model_loader.py | 37 +++++++++++++++++ tensorrt_llm/executor/worker.py | 16 ++++++++ tensorrt_llm/llmapi/llm_args.py | 5 +++ 4 files changed, 84 insertions(+), 14 deletions(-) diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 1efb539fa016..d830e49fdc13 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -157,11 +157,17 @@ def maybe_convert_to_torch_tensor( def copy_weight(dst: Parameter, src: torch.Tensor): - # TODO check that is it a reasonable change or not if dst.dtype != src.dtype: src = src.to(dst.dtype) assert dst.dtype == src.dtype, f"Incompatible dtype. 
dst: {dst.dtype}, src: {src.dtype}" - dst.data.copy_(src) + # Zero-copy pointer swap when source is already on the correct device with matching shape + if (src.data_ptr() == dst.data_ptr()): + return # Already in place (e.g., NIXL wrote directly into param buffer) + if (src.device == dst.device and src.shape == dst.shape and src.is_contiguous() + and dst.is_contiguous()): + dst.data = src + else: + dst.data.copy_(src) def copy_weight_shard(dst: Parameter, src: torch.Tensor, shard_offset: int, @@ -183,8 +189,10 @@ def load_weights_vanilla_helper(module: Linear, if module.bias is not None: assert "bias" in weights[0] device = torch.device('cuda') + # Skip TP slicing for pre-sharded weights (e.g., from P2P RDMA) + tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size - weight = load_weight_shard(weights[0]['weight'], module.tp_size, + weight = load_weight_shard(weights[0]['weight'], tp_size, module.tp_rank, module.tp_mode, device) if "weight" in weights[0] else None @@ -201,7 +209,7 @@ def load_weights_vanilla_helper(module: Linear, copy_weight(module.weight, weight_transform(weight)) if module.bias is not None: - bias = load_weight_shard(weights[0]['bias'], module.tp_size, + bias = load_weight_shard(weights[0]['bias'], tp_size, module.tp_rank, module.tp_mode, device) if "bias" in weights[0] else None if bias is not None: @@ -224,25 +232,27 @@ def load_weights_fused_qkv_helper( module, "fused_weight_shard_indices_mapping", None ) is not None, "Fused weight shard indices mapping is required in partial loading" device = torch.device('cuda') + # Skip TP slicing for pre-sharded weights (e.g., from P2P RDMA) + tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size - q_weight = load_weight_shard(weights[0]['weight'], module.tp_size, + q_weight = load_weight_shard(weights[0]['weight'], tp_size, module.tp_rank, module.tp_mode, device) if "weight" in weights[0] else None - k_weight = load_weight_shard(weights[1]['weight'], 
module.tp_size, + k_weight = load_weight_shard(weights[1]['weight'], tp_size, module.tp_rank, module.tp_mode, device) if "weight" in weights[1] else None - v_weight = load_weight_shard(weights[2]['weight'], module.tp_size, + v_weight = load_weight_shard(weights[2]['weight'], tp_size, module.tp_rank, module.tp_mode, device) if "weight" in weights[2] else None if module.bias is not None: - q_bias = load_weight_shard(weights[0]['bias'], module.tp_size, + q_bias = load_weight_shard(weights[0]['bias'], tp_size, module.tp_rank, module.tp_mode, device) if "bias" in weights[0] else None - k_bias = load_weight_shard(weights[1]['bias'], module.tp_size, + k_bias = load_weight_shard(weights[1]['bias'], tp_size, module.tp_rank, module.tp_mode, device) if "bias" in weights[1] else None - v_bias = load_weight_shard(weights[2]['bias'], module.tp_size, + v_bias = load_weight_shard(weights[2]['bias'], tp_size, module.tp_rank, module.tp_mode, device) if "bias" in weights[2] else None if not allow_partial_loading: @@ -277,18 +287,20 @@ def load_weights_fused_gate_up_helper( module, "fused_weight_shard_indices_mapping", None ) is not None, "Fused weight shard indices mapping is required in partial loading" device = torch.device('cuda') + # Skip TP slicing for pre-sharded weights (e.g., from P2P RDMA) + tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size - gate_weight = load_weight_shard(weights[0]['weight'], module.tp_size, + gate_weight = load_weight_shard(weights[0]['weight'], tp_size, module.tp_rank, module.tp_mode, device) if "weight" in weights[0] else None - up_weight = load_weight_shard(weights[1]['weight'], module.tp_size, + up_weight = load_weight_shard(weights[1]['weight'], tp_size, module.tp_rank, module.tp_mode, device) if "weight" in weights[1] else None if module.bias is not None: - gate_bias = load_weight_shard(weights[0]['bias'], module.tp_size, + gate_bias = load_weight_shard(weights[0]['bias'], tp_size, module.tp_rank, module.tp_mode, 
device) if "bias" in weights[0] else None - up_bias = load_weight_shard(weights[1]['bias'], module.tp_size, + up_bias = load_weight_shard(weights[1]['bias'], tp_size, module.tp_rank, module.tp_mode, device) if "bias" in weights[1] else None if not allow_partial_loading: diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index b3c57a29f75a..16557b813b8c 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -410,6 +410,32 @@ def init_meta_tensor(t: torch.Tensor): self._call_load_weights(model.load_draft_weights, weights, draft_weight_mapper) + elif load_format == LoadFormat.PRESHARDED: + # P2P RDMA target: source published weights BEFORE + # post_load_weights (pre-processed state). The checkpoint + # loader injects directly into model params via RDMA. + # If it returns empty dict, weights are already in GPU + # memory — skip model.load_weights() but DO run + # post_load_weights() to apply kernel-ready transforms. 
+ from tensorrt_llm._torch.modules.linear import Linear + + for m in model.modules(): + if isinstance(m, Linear): + m._weights_presharded = True + + ckpt_dir = model.llm_checkpoint_dir if hasattr( + model, 'llm_checkpoint_dir') else checkpoint_dir + weights = checkpoint_loader.load_weights( + ckpt_dir, mapping=self.mapping, model=model) + + if weights: + self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper( + model, config) + self._call_load_weights(model.load_weights, weights, + self.weight_mapper) + else: + logger.info("PRESHARDED: weights injected via P2P RDMA, skipping load_weights()") + elif load_format == LoadFormat.DUMMY: self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper( model, config) @@ -428,6 +454,17 @@ def init_meta_tensor(t: torch.Tensor): raise NotImplementedError( f"No load support for load format: {load_format}") + # ModelExpress source: publish pre-processed weights BEFORE + # post_load_weights so targets receive raw loaded state and can + # run their own post_load_weights() transforms. + if os.environ.get("MODEL_EXPRESS_URL") and not os.environ.get("MODEL_EXPRESS_TARGET"): + try: + from modelexpress.trtllm_live_transfer import publish_model_params + publish_model_params(model) + model._mx_source_published = True + except Exception as e: + logger.warning("ModelExpress publish failed: %s", e) + for module in model.modules(): if hasattr(module, 'post_load_weights') and not getattr( module, '_weights_removed', False): diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 8e13f8c63612..dc191c86d122 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -300,6 +300,22 @@ def notify_proxy_threads_to_quit(): logger.error("Failed to deliver error message to proxy") return + # ModelExpress source: publish this rank's model params via NIXL. + # Skip if already published from ModelLoader.load() (pre-post_load_weights). 
+ if os.environ.get("MODEL_EXPRESS_URL") and not os.environ.get("MODEL_EXPRESS_TARGET"): + try: + model = worker.engine.model_engine.model + except AttributeError: + model = None + if model and getattr(model, '_mx_source_published', False): + logger.info("ModelExpress: already published from model_loader, skipping worker publish") + else: + try: + from modelexpress.trtllm_live_transfer import publish_from_worker + publish_from_worker(worker) + except Exception as e: + logger.warning("ModelExpress publish_from_worker failed on rank %d: %s", mpi_rank(), e) + # Optionally disable GC (default: not disabled) if os.getenv("TRTLLM_WORKER_DISABLE_GC", "0") == "1": gc.disable() diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 987fdd2ac69d..ef48bf88ab40 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3413,6 +3413,11 @@ class LoadFormat(Enum): DUMMY = 1 # Only load the multimodal(vision) encoder weights VISION_ONLY = 2 + # Weights are already sharded per TP rank — skip TP slicing during loading. + # The weight mapper still handles name mapping and fusing (q+k+v → qkv), + # but load_weight_shard() returns weights as-is without TP slicing. + # Use case: P2P RDMA transfers where each worker receives its own shard. + PRESHARDED = 3 class SamplerType(StrEnum): From 9c2ec33582676ec467c3e83b56b08567ea390470 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Mon, 13 Apr 2026 13:12:45 -0700 Subject: [PATCH 2/2] feat: Add MX checkpoint loader for ModelExpress P2P RDMA weight transfer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add @register_checkpoint_loader("MX") for GPU-to-GPU weight transfer via ModelExpress. 
Uses TRT-LLM's existing checkpoint_loader architecture: New files: - _torch/models/checkpoints/mx/checkpoint_loader.py: MxCheckpointLoader - _torch/models/checkpoints/mx/weight_loader.py: MxWeightLoader with auto-detect source/target via MX server probe - _torch/models/checkpoints/mx/config_loader.py: delegates to HfConfigLoader Modified files: - _torch/models/checkpoints/__init__.py: register MxCheckpointLoader - _torch/models/checkpoints/base_weight_loader.py: load_weights() accepts **kwargs - _torch/models/checkpoints/hf/weight_loader.py: load_weights() accepts **kwargs - _torch/pyexecutor/model_loader.py: pass model= to checkpoint_loader.load_weights() in the AUTO path - executor/worker.py: publish_from_worker hook for source auto-publish Auto-detects role: if MX sources exist → RDMA receive (target); if none → fall back to HF disk load, then publish as source. Validated: Kimi K2.5 (TP=8, MoE, nvfp4) on GCP GB200 at 365-509 Gbps. Signed-off-by: Kavin Krishnan Made-with: Cursor Signed-off-by: Kavin Krishnan Made-with: Cursor Signed-off-by: Kavin Krishnan Made-with: Cursor --- .../_torch/models/checkpoints/__init__.py | 3 + .../models/checkpoints/base_weight_loader.py | 3 +- .../models/checkpoints/hf/weight_loader.py | 2 +- .../_torch/models/checkpoints/mx/__init__.py | 8 ++ .../checkpoints/mx/checkpoint_loader.py | 79 ++++++++++++++++++ .../models/checkpoints/mx/config_loader.py | 21 +++++ .../models/checkpoints/mx/weight_loader.py | 83 +++++++++++++++++++ .../_torch/pyexecutor/model_loader.py | 6 +- tensorrt_llm/executor/worker.py | 11 +++ 9 files changed, 212 insertions(+), 4 deletions(-) create mode 100644 tensorrt_llm/_torch/models/checkpoints/mx/__init__.py create mode 100644 tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py create mode 100644 tensorrt_llm/_torch/models/checkpoints/mx/config_loader.py create mode 100644 tensorrt_llm/_torch/models/checkpoints/mx/weight_loader.py diff --git a/tensorrt_llm/_torch/models/checkpoints/__init__.py b/tensorrt_llm/_torch/models/checkpoints/__init__.py index d9f3f0e7e7c4..98b691f66b60 100644 --- a/tensorrt_llm/_torch/models/checkpoints/__init__.py +++ b/tensorrt_llm/_torch/models/checkpoints/__init__.py @@ -19,6 +19,9 @@ from .mistral.config_loader import 
MistralConfigLoader from .mistral.weight_mapper import (MistralLarge3WeightMapper, MistralWeightMapper) +from .mx.checkpoint_loader import MxCheckpointLoader +from .mx.config_loader import MxConfigLoader +from .mx.weight_loader import MxWeightLoader __all__ = [ "HfConfigLoader", "HfWeightLoader", "HfWeightMapper", "MistralConfigLoader", diff --git a/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py index 2cd78e3f9df3..c4ab4bdd158d 100644 --- a/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py @@ -93,7 +93,8 @@ class BaseWeightLoader(ABC): @abstractmethod def load_weights( self, checkpoint_dir: str, - mapping: Mapping) -> Union[Dict[str, Any], ConsumableWeightsDict]: + mapping: Mapping, + **kwargs) -> Union[Dict[str, Any], ConsumableWeightsDict]: """ Loads weights from a checkpoint directory. diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index f47e77a81661..4fb362d5d262 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -27,7 +27,7 @@ class HfWeightLoader(BaseWeightLoader): """ def load_weights(self, checkpoint_dir: str, - mapping: Mapping) -> dict[str, Any]: + mapping: Mapping, **kwargs) -> dict[str, Any]: weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors") # Some model checkpoint directories contain not only the sharded safetensors, but one # consolidated tensor. 
In the presence of both, we favor the former, as there really is no need diff --git a/tensorrt_llm/_torch/models/checkpoints/mx/__init__.py b/tensorrt_llm/_torch/models/checkpoints/mx/__init__.py new file mode 100644 index 000000000000..80606a69ffa4 --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/mx/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from .checkpoint_loader import MxCheckpointLoader +from .config_loader import MxConfigLoader +from .weight_loader import MxWeightLoader + +__all__ = ["MxCheckpointLoader", "MxConfigLoader", "MxWeightLoader"] diff --git a/tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py b/tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py new file mode 100644 index 000000000000..c834ed69c4ba --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +ModelExpress checkpoint loader for P2P RDMA weight transfer. + +When checkpoint_format="MX", this loader auto-detects source vs target: +- If an MX source exists: receive weights via GPU-to-GPU RDMA (target mode) +- If no source exists: fall back to disk load, then publish as source + +The config is always loaded from the local HF checkpoint (on PVC/disk). 
+""" + +from typing import Any, Optional + +from torch import nn + +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper +from tensorrt_llm._torch.models.checkpoints.mx.config_loader import MxConfigLoader +from tensorrt_llm._torch.models.checkpoints.mx.weight_loader import MxWeightLoader +from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_loader +from tensorrt_llm.mapping import Mapping + + +@register_checkpoint_loader("MX") +class MxCheckpointLoader(BaseCheckpointLoader): + + def __init__( + self, + *, + mx_server_url: str | None = None, + weight_loader: Optional[BaseWeightLoader] = None, + weight_mapper: Optional[BaseWeightMapper] = None, + config_loader: Optional[BaseConfigLoader] = None, + ): + self._weight_loader = weight_loader or MxWeightLoader(mx_server_url=mx_server_url) + self._config_loader = config_loader or self.get_default_config_loader() + self._weight_mapper = weight_mapper + self._checkpoint_format = "MX" + + def cleanup(self) -> None: + if self._weight_mapper is not None: + self._weight_mapper.cleanup() + self._weight_mapper = None + if self._weight_loader is not None: + self._weight_loader.cleanup() + self._weight_loader = None + if self._config_loader is not None: + self._config_loader.cleanup() + self._config_loader = None + + def get_default_weight_loader(self) -> MxWeightLoader: + return MxWeightLoader() + + def get_default_config_loader(self) -> MxConfigLoader: + return MxConfigLoader() + + @property + def weight_loader(self) -> BaseWeightLoader: + return self._weight_loader + + @property + def weight_mapper(self) -> Optional[BaseWeightMapper]: + return 
self._weight_mapper + + @weight_mapper.setter + def weight_mapper(self, value: BaseWeightMapper) -> None: + self._weight_mapper = value + + @property + def config_loader(self) -> Optional[BaseConfigLoader]: + return self._config_loader + + @property + def checkpoint_format(self) -> str: + return self._checkpoint_format diff --git a/tensorrt_llm/_torch/models/checkpoints/mx/config_loader.py b/tensorrt_llm/_torch/models/checkpoints/mx/config_loader.py new file mode 100644 index 000000000000..a8e8c88bb04f --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/mx/config_loader.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Config loader for MX checkpoint format — delegates to HF config loader.""" + +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader +from tensorrt_llm._torch.models.checkpoints.hf.config_loader import HfConfigLoader +from tensorrt_llm._torch.models.modeling_utils import register_config_loader + + +@register_config_loader("MX") +class MxConfigLoader(BaseConfigLoader): + + def __init__(self): + self._hf_loader = HfConfigLoader() + + def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: + return self._hf_loader.load(checkpoint_dir, **kwargs) + + def cleanup(self) -> None: + self._hf_loader.cleanup() diff --git a/tensorrt_llm/_torch/models/checkpoints/mx/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/mx/weight_loader.py new file mode 100644 index 000000000000..57932778c1d8 --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/mx/weight_loader.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""ModelExpress P2P weight loader — receives weights via RDMA from an MX source.""" + +import logging +import os +from typing import Any, Dict, Union + +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ( + BaseWeightLoader, + ConsumableWeightsDict, +) +from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_weight_loader +from tensorrt_llm.mapping import Mapping + +logger = logging.getLogger(__name__) + + +@register_checkpoint_weight_loader("MX") +class MxWeightLoader(BaseWeightLoader): + + def __init__(self, mx_server_url: str | None = None): + self._mx_server_url = mx_server_url + self._received_via_rdma = False + + @property + def received_via_rdma(self) -> bool: + return self._received_via_rdma + + def load_weights( + self, + checkpoint_dir: str, + mapping: Mapping, + **kwargs, + ) -> Union[Dict[str, Any], ConsumableWeightsDict]: + mx_server = ( + self._mx_server_url + or os.environ.get("MODEL_EXPRESS_URL") + or os.environ.get("MX_SERVER_ADDRESS", "localhost:8001") + ) + + try: + from modelexpress.client import MxClient + except ImportError as err: + raise ImportError( + "checkpoint_format='MX' requires the 'modelexpress' package. 
" + "Install with: pip install modelexpress" + ) from err + + probe_timeout = int(os.environ.get("MX_SOURCE_PROBE_TIMEOUT", "30")) + try: + mx_client = MxClient(server_url=mx_server) + try: + resp = mx_client.list_sources() + has_sources = len(resp.instances) > 0 + finally: + mx_client.close() + except Exception: + has_sources = False + + if has_sources: + logger.info("MX source found — loading weights via RDMA") + from modelexpress.trtllm_live_transfer import MxLiveWeightLoader + live_loader = MxLiveWeightLoader(mx_server=mx_server) + model = kwargs.get("model") + result = live_loader.load_weights( + checkpoint_dir, mapping=mapping, model=model + ) + self._received_via_rdma = True + os.environ["MODEL_EXPRESS_TARGET"] = "1" + return result + else: + logger.info( + "No MX source found — loading from disk, will publish as source" + ) + os.environ["MODEL_EXPRESS_URL"] = mx_server + from tensorrt_llm._torch.models.checkpoints.hf.weight_loader import ( + HfWeightLoader, + ) + return HfWeightLoader().load_weights(checkpoint_dir, mapping=mapping) + + def cleanup(self) -> None: + pass diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 16557b813b8c..fe862b9577fa 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -384,10 +384,12 @@ def init_meta_tensor(t: torch.Tensor): if load_format == LoadFormat.AUTO: if hasattr(model, 'llm_checkpoint_dir'): weights = checkpoint_loader.load_weights( - model.llm_checkpoint_dir, mapping=self.mapping) + model.llm_checkpoint_dir, mapping=self.mapping, + model=model) else: weights = checkpoint_loader.load_weights( - checkpoint_dir, mapping=self.mapping) + checkpoint_dir, mapping=self.mapping, + model=model) self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper( model, config) diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index dc191c86d122..9f995102ff36 100644 --- 
a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -316,6 +316,17 @@ def notify_proxy_threads_to_quit(): except Exception as e: logger.warning("ModelExpress publish_from_worker failed on rank %d: %s", mpi_rank(), e) + # ModelExpress source publish: if MX URL is configured and this worker + # was not set up as a target (RDMA receiver), publish model weights. + if os.environ.get("MODEL_EXPRESS_URL") and not os.environ.get("MODEL_EXPRESS_TARGET"): + try: + from modelexpress.trtllm_live_transfer import publish_from_worker + publish_from_worker(worker) + except ImportError: + pass + except Exception as e: + logger.warning("ModelExpress publish_from_worker failed: %s", e) + # Optionally disable GC (default: not disabled) if os.getenv("TRTLLM_WORKER_DISABLE_GC", "0") == "1": gc.disable()