Skip to content

Commit dba4208

Browse files
committed
feat: Add PRESHARDED LoadFormat for zero-disk P2P RDMA weight loading
Add LoadFormat.PRESHARDED for loading model weights that are already sharded
per TP rank, enabling zero-disk P2P RDMA weight transfers where each MPI
worker receives only its own shard directly into GPU memory via ModelExpress.

Changes:
- llm_args.py: Add PRESHARDED = 3 to LoadFormat enum
- model_loader.py: PRESHARDED branch with _weights_presharded flag, publish
  hook before post_load_weights (auto-detect via MODEL_EXPRESS_URL)
- linear.py: Override tp_size to 1 when _weights_presharded=True
- worker.py: publish_from_worker hook in setup_engine (auto-detect)

Source publishes weights before post_load_weights so targets receive
pre-processed weights and run their own transforms independently.
Auto-detects source role when MODEL_EXPRESS_URL is set and
MODEL_EXPRESS_TARGET is not set.

Validated: Kimi K2.5 (TP=8, MoE, nvfp4) on GCP GB200 at 365-509 Gbps.

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
1 parent 889b81c commit dba4208

4 files changed

Lines changed: 81 additions & 14 deletions

File tree

tensorrt_llm/_torch/modules/linear.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,17 @@ def maybe_convert_to_torch_tensor(
157157

158158

159159
def copy_weight(dst: Parameter, src: torch.Tensor):
    """Transfer ``src`` into the parameter ``dst``, avoiding copies when possible.

    Converts ``src`` to ``dst``'s dtype if they differ, then picks the
    cheapest transfer path:

    1. No-op when both tensors already share the same storage pointer
       (e.g., a transport such as NIXL wrote directly into the param buffer).
    2. Zero-copy pointer swap when ``src`` is already on the correct device
       with a matching shape and both tensors are contiguous. The swapped-in
       tensor is detached so the parameter never inherits the source's
       autograd graph; note the storage is shared afterwards, so the caller
       must not mutate ``src`` in place.
    3. Element-wise ``copy_`` otherwise.

    Args:
        dst: Destination parameter whose data is replaced or overwritten.
        src: Source tensor holding the weight values.
    """
    if dst.dtype != src.dtype:
        # ``to`` allocates a converted tensor; the original ``src`` is untouched.
        src = src.to(dst.dtype)
    assert dst.dtype == src.dtype, f"Incompatible dtype. dst: {dst.dtype}, src: {src.dtype}"
    if src.data_ptr() == dst.data_ptr():
        return  # Already in place (e.g., NIXL wrote directly into param buffer)
    if (src.device == dst.device and src.shape == dst.shape and src.is_contiguous()
            and dst.is_contiguous()):
        # Zero-copy pointer swap. ``detach()`` shares storage (still no copy)
        # but strips any autograd history the source tensor might carry.
        dst.data = src.detach()
    else:
        dst.data.copy_(src)
165171

166172

167173
def copy_weight_shard(dst: Parameter, src: torch.Tensor, shard_offset: int,
@@ -183,8 +189,10 @@ def load_weights_vanilla_helper(module: Linear,
183189
if module.bias is not None:
184190
assert "bias" in weights[0]
185191
device = torch.device('cuda')
192+
# Skip TP slicing for pre-sharded weights (e.g., from P2P RDMA)
193+
tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size
186194

187-
weight = load_weight_shard(weights[0]['weight'], module.tp_size,
195+
weight = load_weight_shard(weights[0]['weight'], tp_size,
188196
module.tp_rank, module.tp_mode,
189197
device) if "weight" in weights[0] else None
190198

@@ -201,7 +209,7 @@ def load_weights_vanilla_helper(module: Linear,
201209
copy_weight(module.weight, weight_transform(weight))
202210

203211
if module.bias is not None:
204-
bias = load_weight_shard(weights[0]['bias'], module.tp_size,
212+
bias = load_weight_shard(weights[0]['bias'], tp_size,
205213
module.tp_rank, module.tp_mode,
206214
device) if "bias" in weights[0] else None
207215
if bias is not None:
@@ -224,25 +232,27 @@ def load_weights_fused_qkv_helper(
224232
module, "fused_weight_shard_indices_mapping", None
225233
) is not None, "Fused weight shard indices mapping is required in partial loading"
226234
device = torch.device('cuda')
235+
# Skip TP slicing for pre-sharded weights (e.g., from P2P RDMA)
236+
tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size
227237

228-
q_weight = load_weight_shard(weights[0]['weight'], module.tp_size,
238+
q_weight = load_weight_shard(weights[0]['weight'], tp_size,
229239
module.tp_rank, module.tp_mode,
230240
device) if "weight" in weights[0] else None
231-
k_weight = load_weight_shard(weights[1]['weight'], module.tp_size,
241+
k_weight = load_weight_shard(weights[1]['weight'], tp_size,
232242
module.tp_rank, module.tp_mode,
233243
device) if "weight" in weights[1] else None
234-
v_weight = load_weight_shard(weights[2]['weight'], module.tp_size,
244+
v_weight = load_weight_shard(weights[2]['weight'], tp_size,
235245
module.tp_rank, module.tp_mode,
236246
device) if "weight" in weights[2] else None
237247

238248
if module.bias is not None:
239-
q_bias = load_weight_shard(weights[0]['bias'], module.tp_size,
249+
q_bias = load_weight_shard(weights[0]['bias'], tp_size,
240250
module.tp_rank, module.tp_mode,
241251
device) if "bias" in weights[0] else None
242-
k_bias = load_weight_shard(weights[1]['bias'], module.tp_size,
252+
k_bias = load_weight_shard(weights[1]['bias'], tp_size,
243253
module.tp_rank, module.tp_mode,
244254
device) if "bias" in weights[1] else None
245-
v_bias = load_weight_shard(weights[2]['bias'], module.tp_size,
255+
v_bias = load_weight_shard(weights[2]['bias'], tp_size,
246256
module.tp_rank, module.tp_mode,
247257
device) if "bias" in weights[2] else None
248258
if not allow_partial_loading:
@@ -277,18 +287,20 @@ def load_weights_fused_gate_up_helper(
277287
module, "fused_weight_shard_indices_mapping", None
278288
) is not None, "Fused weight shard indices mapping is required in partial loading"
279289
device = torch.device('cuda')
290+
# Skip TP slicing for pre-sharded weights (e.g., from P2P RDMA)
291+
tp_size = 1 if getattr(module, '_weights_presharded', False) else module.tp_size
280292

281-
gate_weight = load_weight_shard(weights[0]['weight'], module.tp_size,
293+
gate_weight = load_weight_shard(weights[0]['weight'], tp_size,
282294
module.tp_rank, module.tp_mode,
283295
device) if "weight" in weights[0] else None
284-
up_weight = load_weight_shard(weights[1]['weight'], module.tp_size,
296+
up_weight = load_weight_shard(weights[1]['weight'], tp_size,
285297
module.tp_rank, module.tp_mode,
286298
device) if "weight" in weights[1] else None
287299
if module.bias is not None:
288-
gate_bias = load_weight_shard(weights[0]['bias'], module.tp_size,
300+
gate_bias = load_weight_shard(weights[0]['bias'], tp_size,
289301
module.tp_rank, module.tp_mode,
290302
device) if "bias" in weights[0] else None
291-
up_bias = load_weight_shard(weights[1]['bias'], module.tp_size,
303+
up_bias = load_weight_shard(weights[1]['bias'], tp_size,
292304
module.tp_rank, module.tp_mode,
293305
device) if "bias" in weights[1] else None
294306
if not allow_partial_loading:

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,32 @@ def init_meta_tensor(t: torch.Tensor):
410410
self._call_load_weights(model.load_draft_weights, weights,
411411
draft_weight_mapper)
412412

413+
elif load_format == LoadFormat.PRESHARDED:
414+
# P2P RDMA target: source published weights BEFORE
415+
# post_load_weights (pre-processed state). The checkpoint
416+
# loader injects directly into model params via RDMA.
417+
# If it returns empty dict, weights are already in GPU
418+
# memory — skip model.load_weights() but DO run
419+
# post_load_weights() to apply kernel-ready transforms.
420+
from tensorrt_llm._torch.modules.linear import Linear
421+
422+
for m in model.modules():
423+
if isinstance(m, Linear):
424+
m._weights_presharded = True
425+
426+
ckpt_dir = model.llm_checkpoint_dir if hasattr(
427+
model, 'llm_checkpoint_dir') else checkpoint_dir
428+
weights = checkpoint_loader.load_weights(
429+
ckpt_dir, mapping=self.mapping, model=model)
430+
431+
if weights:
432+
self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
433+
model, config)
434+
self._call_load_weights(model.load_weights, weights,
435+
self.weight_mapper)
436+
else:
437+
logger.info("PRESHARDED: weights injected via P2P RDMA, skipping load_weights()")
438+
413439
elif load_format == LoadFormat.DUMMY:
414440
self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
415441
model, config)
@@ -428,6 +454,17 @@ def init_meta_tensor(t: torch.Tensor):
428454
raise NotImplementedError(
429455
f"No load support for load format: {load_format}")
430456

457+
# ModelExpress source: publish pre-processed weights BEFORE
458+
# post_load_weights so targets receive raw loaded state and can
459+
# run their own post_load_weights() transforms.
460+
if os.environ.get("MODEL_EXPRESS_URL") and not os.environ.get("MODEL_EXPRESS_TARGET"):
461+
try:
462+
from modelexpress.trtllm_live_transfer import publish_model_params
463+
publish_model_params(model)
464+
model._mx_source_published = True
465+
except Exception as e:
466+
logger.warning("ModelExpress publish failed: %s", e)
467+
431468
for module in model.modules():
432469
if hasattr(module, 'post_load_weights') and not getattr(
433470
module, '_weights_removed', False):

tensorrt_llm/executor/worker.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,19 @@ def notify_proxy_threads_to_quit():
300300
logger.error("Failed to deliver error message to proxy")
301301
return
302302

303+
# ModelExpress source: publish this rank's model params via NIXL.
304+
# Skip if already published from ModelLoader.load() (pre-post_load_weights).
305+
if os.environ.get("MODEL_EXPRESS_URL") and not os.environ.get("MODEL_EXPRESS_TARGET"):
306+
model = getattr(getattr(getattr(worker, 'engine', None), 'model_engine', None), 'model', None)
307+
if model and getattr(model, '_mx_source_published', False):
308+
logger.info("ModelExpress: already published from model_loader, skipping worker publish")
309+
else:
310+
try:
311+
from modelexpress.trtllm_live_transfer import publish_from_worker
312+
publish_from_worker(worker)
313+
except Exception as e:
314+
logger.warning("ModelExpress publish_from_worker failed on rank %d: %s", mpi_rank(), e)
315+
303316
# Optionally disable GC (default: not disabled)
304317
if os.getenv("TRTLLM_WORKER_DISABLE_GC", "0") == "1":
305318
gc.disable()

tensorrt_llm/llmapi/llm_args.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3306,6 +3306,11 @@ class LoadFormat(Enum):
33063306
DUMMY = 1
33073307
# Only load the multimodal(vision) encoder weights
33083308
VISION_ONLY = 2
3309+
# Weights are already sharded per TP rank — skip TP slicing during loading.
3310+
# The weight mapper still handles name mapping and fusing (q+k+v → qkv),
3311+
# but load_weight_shard() returns weights as-is without TP slicing.
3312+
# Use case: P2P RDMA transfers where each worker receives its own shard.
3313+
PRESHARDED = 3
33093314

33103315

33113316
class SamplerType(StrEnum):

0 commit comments

Comments
 (0)