From 0512c13843673bf6ec90eeb5bcc0b36187eb8d42 Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Tue, 31 Mar 2026 08:52:41 +0200 Subject: [PATCH 1/3] feat: add account name to slurm config --- src/post_training/config.py | 1 + src/post_training/slurm/job.sh.jinja | 3 ++- src/post_training/slurm/job_llamafactory.sh.jinja | 3 ++- src/post_training/slurm/job_trl_container.sh.jinja | 3 ++- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/post_training/config.py b/src/post_training/config.py index a0bde16..b2db875 100644 --- a/src/post_training/config.py +++ b/src/post_training/config.py @@ -180,6 +180,7 @@ class SlurmConfig: """SLURM job scheduler parameters.""" partition: str = "gpu" + account: str | None = "OELLM_prod2026" num_nodes: int = 1 gpus_per_node: int = 4 cpus_per_task: int = 32 diff --git a/src/post_training/slurm/job.sh.jinja b/src/post_training/slurm/job.sh.jinja index b78a561..710b31d 100644 --- a/src/post_training/slurm/job.sh.jinja +++ b/src/post_training/slurm/job.sh.jinja @@ -6,7 +6,8 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} -#SBATCH --nodes={{ num_nodes }} +{% if account %}#SBATCH --account={{ account }} +{% endif %}#SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} #SBATCH --cpus-per-task={{ cpus_per_task }} diff --git a/src/post_training/slurm/job_llamafactory.sh.jinja b/src/post_training/slurm/job_llamafactory.sh.jinja index 832bf86..3deb215 100644 --- a/src/post_training/slurm/job_llamafactory.sh.jinja +++ b/src/post_training/slurm/job_llamafactory.sh.jinja @@ -6,7 +6,8 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} -#SBATCH --nodes={{ num_nodes }} +{% if account %}#SBATCH --account={{ account }} +{% endif %}#SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%} diff --git a/src/post_training/slurm/job_trl_container.sh.jinja b/src/post_training/slurm/job_trl_container.sh.jinja index 7192667..23e9122 100644 --- a/src/post_training/slurm/job_trl_container.sh.jinja +++ b/src/post_training/slurm/job_trl_container.sh.jinja @@ -6,7 +6,8 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} -#SBATCH --nodes={{ num_nodes }} +{% if account %}#SBATCH --account={{ account }} +{% endif %}#SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%} From 2e0a780a16e1a3c7adb393fe511bdfbb437fbe97 Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Thu, 30 Apr 2026 17:11:23 +0200 Subject: [PATCH 2/3] feat: wire account field through all SLURM templates Default to None (not a hardcoded account name), forward account=config.slurm.account in all three renderers, and fix template style to multi-line {% if %} blocks. --- src/post_training/config.py | 2 +- src/post_training/slurm/job.sh.jinja | 6 ++++-- src/post_training/slurm/job_llamafactory.sh.jinja | 6 ++++-- src/post_training/slurm/job_trl_container.sh.jinja | 6 ++++-- src/post_training/slurm/launcher.py | 3 +++ 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/post_training/config.py b/src/post_training/config.py index b2db875..5f79d8b 100644 --- a/src/post_training/config.py +++ b/src/post_training/config.py @@ -180,7 +180,7 @@ class SlurmConfig: """SLURM job scheduler parameters.""" partition: str = "gpu" - account: str | None = "OELLM_prod2026" + account: str | None = None num_nodes: int = 1 gpus_per_node: int = 4 cpus_per_task: int = 32 diff --git a/src/post_training/slurm/job.sh.jinja b/src/post_training/slurm/job.sh.jinja index 710b31d..47a7e3d 100644 --- a/src/post_training/slurm/job.sh.jinja +++ b/src/post_training/slurm/job.sh.jinja @@ -6,8 +6,10 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} -{% if account %}#SBATCH --account={{ account }} -{% endif %}#SBATCH --nodes={{ num_nodes }} +{% if account -%} +#SBATCH --account={{ account }} +{% endif -%} +#SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} #SBATCH --cpus-per-task={{ cpus_per_task }} diff --git a/src/post_training/slurm/job_llamafactory.sh.jinja b/src/post_training/slurm/job_llamafactory.sh.jinja index 3deb215..c84311a 100644 --- a/src/post_training/slurm/job_llamafactory.sh.jinja +++ b/src/post_training/slurm/job_llamafactory.sh.jinja @@ -6,8 +6,10 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} -{% if account %}#SBATCH --account={{ account }} -{% endif %}#SBATCH --nodes={{ num_nodes }} +{% if account -%} +#SBATCH --account={{ account }} +{% endif -%} +#SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%} diff --git a/src/post_training/slurm/job_trl_container.sh.jinja b/src/post_training/slurm/job_trl_container.sh.jinja index 23e9122..6d44172 100644 --- a/src/post_training/slurm/job_trl_container.sh.jinja +++ b/src/post_training/slurm/job_trl_container.sh.jinja @@ -6,8 +6,10 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} -{% if account %}#SBATCH --account={{ account }} -{% endif %}#SBATCH --nodes={{ num_nodes }} +{% if account -%} +#SBATCH --account={{ account }} +{% endif -%} +#SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%} diff --git a/src/post_training/slurm/launcher.py b/src/post_training/slurm/launcher.py index d23e15a..1e5e651 100644 --- a/src/post_training/slurm/launcher.py +++ b/src/post_training/slurm/launcher.py @@ -57,6 +57,7 @@ def render_trl_slurm_script( # SLURM parameters job_name=config.slurm.job_name, partition=config.slurm.partition, + account=config.slurm.account, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, cpus_per_task=config.slurm.cpus_per_task, @@ -107,6 +108,7 @@ def render_trl_container_slurm_script( # SLURM parameters job_name=config.slurm.job_name, partition=config.slurm.partition, + account=config.slurm.account, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, cpus_per_gpu=config.slurm.cpus_per_gpu, @@ -154,6 +156,7 @@ def render_llamafactory_slurm_script( # SLURM parameters job_name=config.slurm.job_name, partition=config.slurm.partition, + account=config.slurm.account, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, cpus_per_gpu=config.slurm.cpus_per_gpu, From 43af49f2d77cd9589c2fa7d696c6ee97181efd25 Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Thu, 30 Apr 2026 17:11:23 +0200 Subject: [PATCH 3/3] test: add SLURM account rendering tests for all three templates --- tests/test_slurm_render.py | 83 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 tests/test_slurm_render.py diff --git a/tests/test_slurm_render.py b/tests/test_slurm_render.py new file mode 100644 index 0000000..de3e38c --- /dev/null +++ b/tests/test_slurm_render.py @@ -0,0 +1,83 @@ +"""Automated tests for SLURM template rendering.""" + +import pytest + +from post_training.config import PostTrainingConfig +from post_training.slurm.launcher import ( + render_llamafactory_slurm_script, + render_trl_container_slurm_script, + render_trl_slurm_script, +) + + +@pytest.fixture +def config(): + cfg = PostTrainingConfig() + cfg.container.image = "/shared/containers/trl.sif" + cfg.container.bind_mounts = ["/scratch:/scratch"] + cfg.container.env_file = "/shared/env/cluster.env" + return cfg + + +# --------------------------------------------------------------------------- +# account field — all three templates +# --------------------------------------------------------------------------- + + +def test_trl_account_rendered(tmp_path, config): + """account appears as #SBATCH directive when set.""" + config.slurm.account = "my_project" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "#SBATCH --account=my_project" in content + + +def test_trl_account_absent_when_none(tmp_path, config): + """account directive is suppressed when not set.""" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "--account" not in content + + +def test_trl_container_account_rendered(tmp_path, config): + config.slurm.account = "my_project" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_container_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "#SBATCH --account=my_project" in content + + +def test_trl_container_account_absent_when_none(tmp_path, config): + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_container_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "--account" not in content + + +def test_llamafactory_account_rendered(tmp_path, config): + config.slurm.account = "my_project" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_llamafactory_slurm_script(config, run_dir).read_text() + + assert "#SBATCH --account=my_project" in content + + +def test_llamafactory_account_absent_when_none(tmp_path, config): + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_llamafactory_slurm_script(config, run_dir).read_text() + + assert "--account" not in content