diff --git a/src/post_training/config.py b/src/post_training/config.py index 075e549..91c7aa8 100644 --- a/src/post_training/config.py +++ b/src/post_training/config.py @@ -181,6 +181,7 @@ class SlurmConfig: """SLURM job scheduler parameters.""" partition: str = "gpu" + account: str | None = None qos: str | None = None num_nodes: int = 1 gpus_per_node: int = 4 diff --git a/src/post_training/slurm/job.sh.jinja b/src/post_training/slurm/job.sh.jinja index 121e015..690ae0f 100644 --- a/src/post_training/slurm/job.sh.jinja +++ b/src/post_training/slurm/job.sh.jinja @@ -6,6 +6,9 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} +{% if account -%} +#SBATCH --account={{ account }} +{% endif -%} {% if qos -%} #SBATCH --qos={{ qos }} {% endif -%} diff --git a/src/post_training/slurm/job_llamafactory.sh.jinja b/src/post_training/slurm/job_llamafactory.sh.jinja index 49be0b3..38ddd4a 100644 --- a/src/post_training/slurm/job_llamafactory.sh.jinja +++ b/src/post_training/slurm/job_llamafactory.sh.jinja @@ -6,6 +6,9 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} +{% if account -%} +#SBATCH --account={{ account }} +{% endif -%} {% if qos -%} #SBATCH --qos={{ qos }} {% endif -%} diff --git a/src/post_training/slurm/job_trl_container.sh.jinja b/src/post_training/slurm/job_trl_container.sh.jinja index 2f68236..6688d6d 100644 --- a/src/post_training/slurm/job_trl_container.sh.jinja +++ b/src/post_training/slurm/job_trl_container.sh.jinja @@ -6,6 +6,9 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} +{% if account -%} +#SBATCH --account={{ account }} +{% endif -%} {% if qos -%} #SBATCH --qos={{ qos }} {% endif -%} diff --git a/src/post_training/slurm/launcher.py b/src/post_training/slurm/launcher.py index 5694423..f838462 100644 --- a/src/post_training/slurm/launcher.py +++ b/src/post_training/slurm/launcher.py @@ -57,6 +57,7 @@ def render_trl_slurm_script( # SLURM parameters job_name=config.slurm.job_name, partition=config.slurm.partition, + account=config.slurm.account, qos=config.slurm.qos, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, @@ -109,6 +110,7 @@ def render_trl_container_slurm_script( # SLURM parameters job_name=config.slurm.job_name, partition=config.slurm.partition, + account=config.slurm.account, qos=config.slurm.qos, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, @@ -159,6 +161,7 @@ def render_llamafactory_slurm_script( # SLURM parameters job_name=config.slurm.job_name, partition=config.slurm.partition, + account=config.slurm.account, qos=config.slurm.qos, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, diff --git a/tests/test_slurm_render.py b/tests/test_slurm_render.py index 35ce4fc..26ed7c2 100644 --- a/tests/test_slurm_render.py +++ b/tests/test_slurm_render.py @@ -6,18 +6,88 @@ from post_training.slurm.launcher import ( render_llamafactory_slurm_script, render_trl_container_slurm_script, + render_trl_slurm_script, ) @pytest.fixture def config(): cfg = PostTrainingConfig() - cfg.container.image = "/shared/containers/llamafactory.sif" + cfg.container.image = "/shared/containers/trl.sif" cfg.container.bind_mounts = ["/scratch:/scratch"] cfg.container.env_file = "/shared/env/cluster.env" return cfg +# --------------------------------------------------------------------------- +# account field — all three templates +# --------------------------------------------------------------------------- + + +def test_trl_account_rendered(tmp_path, config): + """account appears as #SBATCH directive when set.""" + config.slurm.account = "my_project" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "#SBATCH --account=my_project" in content + + +def test_trl_account_absent_when_none(tmp_path, config): + """account directive is suppressed when not set.""" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "--account" not in content + + +def test_trl_container_account_rendered(tmp_path, config): + config.slurm.account = "my_project" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_container_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "#SBATCH --account=my_project" in content + + +def test_trl_container_account_absent_when_none(tmp_path, config): + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_container_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "--account" not in content + + +def test_llamafactory_account_rendered(tmp_path, config): + config.slurm.account = "my_project" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_llamafactory_slurm_script(config, run_dir).read_text() + + assert "#SBATCH --account=my_project" in content + + +def test_llamafactory_account_absent_when_none(tmp_path, config): + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_llamafactory_slurm_script(config, run_dir).read_text() + + assert "--account" not in content + + +# --------------------------------------------------------------------------- +# qos / mem fields — LlamaFactory and TRL container templates +# --------------------------------------------------------------------------- + + def test_llamafactory_qos_mem_rendered(tmp_path, config): """qos and mem appear as #SBATCH directives when set.""" config.slurm.qos = "boost_qos_dbg" @@ -42,11 +112,6 @@ def test_llamafactory_qos_mem_absent_when_none(tmp_path, config): assert "--mem" not in content -# --------------------------------------------------------------------------- -# TRL container template -# --------------------------------------------------------------------------- - - def test_trl_container_qos_mem_rendered(tmp_path, config): """qos and mem appear as #SBATCH directives when set.""" config.slurm.qos = "boost_qos_dbg"