From d6a601b106c1fb02e5ac1da3f1bf94dba2426f81 Mon Sep 17 00:00:00 2001
From: SXX
Date: Fri, 9 Jan 2026 19:48:25 +0800
Subject: [PATCH] fix: set current CUDA device in _inplace_pin_memory function

---
 checkpoint_engine/pin_memory.py | 3 +++
 examples/update.py              | 3 ++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/checkpoint_engine/pin_memory.py b/checkpoint_engine/pin_memory.py
index 3edcb12..3caa934 100644
--- a/checkpoint_engine/pin_memory.py
+++ b/checkpoint_engine/pin_memory.py
@@ -191,6 +191,8 @@ class TPMeta(BaseModel):
 
 
 def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
+    device_index = torch.cuda.current_device()
+
     def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
         """
         safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
@@ -204,6 +206,7 @@ def _pin(t: torch.Tensor):
             Pin the memory of tensor in-place.
             See: https://github.com/pytorch/pytorch/issues/32167
             """
+            torch.cuda.set_device(device_index)
             cudart = torch.cuda.cudart()
             r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
             assert r == 0, f"pin memory error, error code: {r}"
diff --git a/examples/update.py b/examples/update.py
index 51cb189..f5605cf 100644
--- a/examples/update.py
+++ b/examples/update.py
@@ -14,7 +14,8 @@
 from loguru import logger
 from safetensors import safe_open
 
-from checkpoint_engine.ps import ParameterServer, request_inference_to_update
+from checkpoint_engine import request_inference_to_update
+from checkpoint_engine.ps import ParameterServer
 
 
 @contextmanager
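
Context for the fix, as a minimal sketch rather than the project's actual code: _inplace_pin_memory now captures the caller's current CUDA device and re-applies it inside _pin before calling cudaHostRegister, presumably because the pinning work can run on worker threads whose thread-local current device would otherwise default to device 0. The pin_inplace helper and the thread pool below are hypothetical stand-ins used only to illustrate that pattern.

# Hedged sketch of the pattern the patch adds to _pin(); pin_inplace() is a
# hypothetical helper, not part of checkpoint_engine.
from concurrent.futures import ThreadPoolExecutor

import torch


def pin_inplace(t: torch.Tensor, device_index: int) -> None:
    # The current CUDA device is thread-local in PyTorch; a fresh pool thread
    # defaults to device 0, so re-apply the index captured by the caller.
    torch.cuda.set_device(device_index)
    cudart = torch.cuda.cudart()
    # Register the tensor's existing storage as page-locked (pinned) memory,
    # mirroring the cudaHostRegister call in the patch.
    r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
    assert r == 0, f"pin memory error, error code: {r}"


if __name__ == "__main__":
    if torch.cuda.is_available():
        device_index = torch.cuda.current_device()  # captured on the calling thread
        tensors = [torch.empty(1 << 20) for _ in range(4)]
        with ThreadPoolExecutor(max_workers=2) as pool:
            list(pool.map(lambda t: pin_inplace(t, device_index), tensors))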