Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions skyrl/backends/skyrl_train_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,22 @@ def _validate_model_state(self, model_id: str) -> None:
if self._dispatch is None:
raise RuntimeError("Model not initialized")

def _staging_root(self, reference_path) -> str:
"""Return a directory for checkpoint staging on the same filesystem as
``reference_path``.

Tar archives are written/read in the engine process, but the actual
model files are produced/consumed by remote Ray worker actors that may
run on a different node. Staging on local /tmp (``tempfile``'s default)
therefore breaks on multi-node deployments because the worker and the
engine do not share that path. ``reference_path`` lives under
``checkpoints_base`` (expected to be shared storage), so staging next to
it keeps the directory visible to both processes.
"""
staging_root = str(os.path.dirname(reference_path))
os.makedirs(staging_root, exist_ok=True)
return staging_root
Comment on lines +1127 to +1129
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If reference_path is a relative path (e.g., a simple filename like "checkpoint.tar"), os.path.dirname(reference_path) will return an empty string "". Calling os.makedirs("", exist_ok=True) will then raise a FileNotFoundError: [Errno 2] No such file or directory: ''.

To prevent this and ensure robustness for relative paths, resolve the absolute path using os.path.abspath before extracting the directory name.

Suggested change
staging_root = str(os.path.dirname(reference_path))
os.makedirs(staging_root, exist_ok=True)
return staging_root
staging_root = os.path.dirname(os.path.abspath(reference_path))
os.makedirs(staging_root, exist_ok=True)
return staging_root


def _create_tar_from_directory(self, source_dir: str, output_path: str) -> None:
"""Create an uncompressed tar archive from a directory."""
# Ensure parent directory exists
Expand All @@ -1126,8 +1142,10 @@ def save_checkpoint(self, output_path, model_id: str) -> None:
self._validate_model_state(model_id)
role = self._get_role(model_id)

# Create temp directory for checkpoint
with tempfile.TemporaryDirectory() as temp_dir:
# Create temp directory for checkpoint on the same (shared) filesystem
# as output_path so the remote worker that writes the files and the
# engine that tars them both see the same path.
with tempfile.TemporaryDirectory(dir=self._staging_root(output_path)) as temp_dir:
ckpt_dir = os.path.join(temp_dir, "checkpoint")

# Save checkpoint directory (includes optimizer state automatically)
Expand All @@ -1143,8 +1161,10 @@ def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
self._validate_model_state(model_id)
role = self._get_role(model_id)

# Extract tar to temp directory (filter='data' prevents path traversal attacks)
with tempfile.TemporaryDirectory() as temp_dir:
# Extract tar to temp directory on the same (shared) filesystem as
# checkpoint_path so the remote worker that loads the files can see it.
# (filter='data' prevents path traversal attacks)
with tempfile.TemporaryDirectory(dir=self._staging_root(checkpoint_path)) as temp_dir:
with tarfile.open(checkpoint_path, "r") as tar:
tar.extractall(temp_dir, filter="data")

Expand Down Expand Up @@ -1182,7 +1202,10 @@ def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = Tr

if persist:
# TODO(tyler): For LoRA, only save the adapters instead of the full merged model
with tempfile.TemporaryDirectory() as temp_dir:
# Stage on the same (shared) filesystem as output_path so the remote
# worker that exports the HF model and the engine that tars it agree
# on the path (they may run on different nodes).
with tempfile.TemporaryDirectory(dir=self._staging_root(output_path)) as temp_dir:
hf_dir = os.path.join(temp_dir, "model")
self._dispatch.save_hf_model(model="policy", export_dir=hf_dir, tokenizer=self._tokenizer)
self._create_tar_from_directory(hf_dir, output_path)
Expand Down
Loading