Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion scripts/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import argparse
import logging
import os
import re
import sys
from pathlib import Path

Expand All @@ -25,10 +27,53 @@
from post_training.utils.guardrails import run_guardrails
from post_training.utils.logging import setup_logging
from post_training.utils.paths import setup_run_directory
from post_training.utils.prefetch import prefetch_assets

logger = logging.getLogger(__name__)

_HF_CACHE_VARS = frozenset(
{
"HF_HOME",
"HF_HUB_CACHE",
"HUGGINGFACE_HUB_CACHE",
"HF_DATASETS_CACHE",
"TRANSFORMERS_CACHE",
}
)


def _apply_hf_env_from_file(env_file: str) -> None:
"""Parse HF cache vars from a shell env file and apply them to os.environ.

Ensures prefetch_assets() downloads to the same cache root that the
container will use (set via container.env_file sourced in the SLURM script).
"""
path = Path(env_file)
if not path.exists():
raise FileNotFoundError(f"container.env_file '{env_file}' not found.")

parsed: dict[str, str] = {}
export_re = re.compile(r"^export\s+([A-Za-z_][A-Za-z0-9_]*)=(.*)$")
with path.open() as f:
for line in f:
m = export_re.match(line.strip())
if not m:
continue
key, value = m.group(1), m.group(2).strip("\"'")
value = re.sub(
r"\$\{?([A-Za-z_][A-Za-z0-9_]*)\}?",
lambda mv: parsed.get(mv.group(1), os.environ.get(mv.group(1), "")),
value,
)
parsed[key] = value

applied = []
for key in _HF_CACHE_VARS:
if key in parsed:
os.environ[key] = parsed[key]
applied.append(f"{key}={parsed[key]}")
if applied:
logger.info("Applied HF cache vars from %s: %s", env_file, ", ".join(applied))


def _parse_args() -> tuple[str, list[str], bool]:
parser = argparse.ArgumentParser(description="Submit a SLURM training job.")
Expand Down Expand Up @@ -64,10 +109,16 @@ def main() -> None:
config.slurm.gpus_per_node = 1

if config.offline:
if config.container.env_file:
_apply_hf_env_from_file(config.container.env_file)
logger.info(
"offline=True: pre-fetching models and datasets on the login node "
"before submitting the job."
)
# Lazy import: must come after _apply_hf_env_from_file so that
# huggingface_hub and datasets read the correct HF_HOME on first import.
from post_training.utils.prefetch import prefetch_assets

prefetch_assets(config)

# Set up the run directory (so the SLURM script can reference it).
Expand Down
Loading