diff --git a/scripts/submit.py b/scripts/submit.py
index b86cbc2..5087eac 100644
--- a/scripts/submit.py
+++ b/scripts/submit.py
@@ -13,6 +13,8 @@
 import argparse
 import logging
+import os
+import re
 import sys
 
 from pathlib import Path
 
@@ -25,10 +27,53 @@
 from post_training.utils.guardrails import run_guardrails
 from post_training.utils.logging import setup_logging
 from post_training.utils.paths import setup_run_directory
-from post_training.utils.prefetch import prefetch_assets
 
 logger = logging.getLogger(__name__)
 
+_HF_CACHE_VARS = frozenset(
+    {
+        "HF_HOME",
+        "HF_HUB_CACHE",
+        "HUGGINGFACE_HUB_CACHE",
+        "HF_DATASETS_CACHE",
+        "TRANSFORMERS_CACHE",
+    }
+)
+
+
+def _apply_hf_env_from_file(env_file: str) -> None:
+    """Parse HF cache vars from a shell env file and apply them to os.environ.
+
+    Ensures prefetch_assets() downloads to the same cache root that the
+    container will use (set via container.env_file sourced in the SLURM script).
+    """
+    path = Path(env_file)
+    if not path.exists():
+        raise FileNotFoundError(f"container.env_file '{env_file}' not found.")
+
+    parsed: dict[str, str] = {}
+    export_re = re.compile(r"^export\s+([A-Za-z_][A-Za-z0-9_]*)=(.*)$")
+    with path.open() as f:
+        for line in f:
+            m = export_re.match(line.strip())
+            if not m:
+                continue
+            key, value = m.group(1), m.group(2).strip("\"'")
+            value = re.sub(
+                r"\$\{?([A-Za-z_][A-Za-z0-9_]*)\}?",
+                lambda mv: parsed.get(mv.group(1), os.environ.get(mv.group(1), "")),
+                value,
+            )
+            parsed[key] = value
+
+    applied = []
+    for key in _HF_CACHE_VARS:
+        if key in parsed:
+            os.environ[key] = parsed[key]
+            applied.append(f"{key}={parsed[key]}")
+    if applied:
+        logger.info("Applied HF cache vars from %s: %s", env_file, ", ".join(applied))
+
 
 def _parse_args() -> tuple[str, list[str], bool]:
     parser = argparse.ArgumentParser(description="Submit a SLURM training job.")
@@ -64,10 +109,16 @@ def main() -> None:
         config.slurm.gpus_per_node = 1
 
     if config.offline:
+        if config.container.env_file:
+            _apply_hf_env_from_file(config.container.env_file)
         logger.info(
             "offline=True: pre-fetching models and datasets on the login node "
             "before submitting the job."
         )
+        # Lazy import: must come after _apply_hf_env_from_file so that
+        # huggingface_hub and datasets read the correct HF_HOME on first import.
+        from post_training.utils.prefetch import prefetch_assets
+
        prefetch_assets(config)
 
     # Set up the run directory (so the SLURM script can reference it).