diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py index 0322c94bc7..989072e5fd 100644 --- a/skyrl/backends/skyrl_train_backend.py +++ b/skyrl/backends/skyrl_train_backend.py @@ -1112,6 +1112,22 @@ def _validate_model_state(self, model_id: str) -> None: if self._dispatch is None: raise RuntimeError("Model not initialized") + def _staging_root(self, reference_path) -> str: + """Return a directory for checkpoint staging on the same filesystem as + ``reference_path``. + + Tar archives are written/read in the engine process, but the actual + model files are produced/consumed by remote Ray worker actors that may + run on a different node. Staging on local /tmp (``tempfile``'s default) + therefore breaks on multi-node deployments because the worker and the + engine do not share that path. ``reference_path`` lives under + ``checkpoints_base`` (expected to be shared storage), so staging next to + it keeps the directory visible to both processes. + """ + staging_root = str(os.path.dirname(reference_path)) + os.makedirs(staging_root, exist_ok=True) + return staging_root + def _create_tar_from_directory(self, source_dir: str, output_path: str) -> None: """Create an uncompressed tar archive from a directory.""" # Ensure parent directory exists @@ -1126,8 +1142,10 @@ def save_checkpoint(self, output_path, model_id: str) -> None: self._validate_model_state(model_id) role = self._get_role(model_id) - # Create temp directory for checkpoint - with tempfile.TemporaryDirectory() as temp_dir: + # Create temp directory for checkpoint on the same (shared) filesystem + # as output_path so the remote worker that writes the files and the + # engine that tars them both see the same path. + with tempfile.TemporaryDirectory(dir=self._staging_root(output_path)) as temp_dir: ckpt_dir = os.path.join(temp_dir, "checkpoint") # Save checkpoint directory (includes optimizer state automatically) @@ -1143,8 +1161,10 @@ def load_checkpoint(self, checkpoint_path, model_id: str) -> None: self._validate_model_state(model_id) role = self._get_role(model_id) - # Extract tar to temp directory (filter='data' prevents path traversal attacks) - with tempfile.TemporaryDirectory() as temp_dir: + # Extract tar to temp directory on the same (shared) filesystem as + # checkpoint_path so the remote worker that loads the files can see it. + # (filter='data' prevents path traversal attacks) + with tempfile.TemporaryDirectory(dir=self._staging_root(checkpoint_path)) as temp_dir: with tarfile.open(checkpoint_path, "r") as tar: tar.extractall(temp_dir, filter="data") @@ -1182,7 +1202,10 @@ def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = Tr if persist: # TODO(tyler): For LoRA, only save the adapters instead of the full merged model - with tempfile.TemporaryDirectory() as temp_dir: + # Stage on the same (shared) filesystem as output_path so the remote + # worker that exports the HF model and the engine that tars it agree + # on the path (they may run on different nodes). + with tempfile.TemporaryDirectory(dir=self._staging_root(output_path)) as temp_dir: hf_dir = os.path.join(temp_dir, "model") self._dispatch.save_hf_model(model="policy", export_dir=hf_dir, tokenizer=self._tokenizer) self._create_tar_from_directory(hf_dir, output_path)