From 1843dde99a3bf65745d7487080774b66bd07cdef Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Tue, 28 Apr 2026 22:06:19 +0200 Subject: [PATCH 1/4] feat: add qos, mem, and save_strategy config fields Add optional qos and mem fields to SlurmConfig so jobs can specify a SLURM QoS class and memory limit without hardcoding them in the template. Add save_strategy to CheckpointingConfig (default "steps") to allow epoch or disabled checkpointing. Wire all three through the SLURM templates and TrainingArguments. --- src/post_training/config.py | 3 +++ src/post_training/methods/common.py | 2 +- src/post_training/slurm/job.sh.jinja | 6 ++++++ src/post_training/slurm/job_trl_container.sh.jinja | 6 ++++++ src/post_training/slurm/launcher.py | 4 ++++ 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/post_training/config.py b/src/post_training/config.py index a0bde16..075e549 100644 --- a/src/post_training/config.py +++ b/src/post_training/config.py @@ -107,6 +107,7 @@ class DPOMethodConfig: class CheckpointingConfig: """Checkpoint saving strategy.""" + save_strategy: str = "steps" save_steps: int = 500 save_total_limit: int = 2 # When set to ``None`` (or a non-positive value via CLI overrides), inference @@ -180,10 +181,12 @@ class SlurmConfig: """SLURM job scheduler parameters.""" partition: str = "gpu" + qos: str | None = None num_nodes: int = 1 gpus_per_node: int = 4 cpus_per_task: int = 32 cpus_per_gpu: int | None = None + mem: str | None = None wall_time: str = "02:00:00" job_name: str = "post-training" signal_time_seconds: int = 300 diff --git a/src/post_training/methods/common.py b/src/post_training/methods/common.py index 4cf0db4..6b1c17a 100644 --- a/src/post_training/methods/common.py +++ b/src/post_training/methods/common.py @@ -113,7 +113,7 @@ def build_common_training_kwargs( bf16=t.bf16, seed=t.seed, # Checkpointing - save_strategy="steps", + save_strategy=config.checkpointing.save_strategy, save_steps=config.checkpointing.save_steps, save_total_limit=config.checkpointing.save_total_limit, # Logging diff --git a/src/post_training/slurm/job.sh.jinja b/src/post_training/slurm/job.sh.jinja index b78a561..121e015 100644 --- a/src/post_training/slurm/job.sh.jinja +++ b/src/post_training/slurm/job.sh.jinja @@ -6,6 +6,9 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} +{% if qos -%} +#SBATCH --qos={{ qos }} +{% endif -%} #SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} @@ -13,6 +16,9 @@ {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%} #SBATCH --cpus-per-gpu={{ cpus_per_gpu }} {% endif -%} +{% if mem -%} +#SBATCH --mem={{ mem }} +{% endif -%} #SBATCH --time={{ wall_time }} #SBATCH --signal=B:USR1@{{ signal_time_seconds }} #SBATCH --output={{ run_dir }}/slurm/slurm-%j.out diff --git a/src/post_training/slurm/job_trl_container.sh.jinja b/src/post_training/slurm/job_trl_container.sh.jinja index d9cbb89..98de555 100644 --- a/src/post_training/slurm/job_trl_container.sh.jinja +++ b/src/post_training/slurm/job_trl_container.sh.jinja @@ -6,6 +6,9 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} +{% if qos -%} +#SBATCH --qos={{ qos }} +{% endif -%} #SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} @@ -13,6 +16,9 @@ {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%} #SBATCH --cpus-per-gpu={{ cpus_per_gpu }} {% endif -%} +{% if mem -%} +#SBATCH --mem={{ mem }} +{% endif -%} #SBATCH --time={{ wall_time }} #SBATCH --signal=B:USR1@{{ signal_time_seconds }} #SBATCH --output={{ run_dir }}/slurm/slurm-%j.out diff --git a/src/post_training/slurm/launcher.py b/src/post_training/slurm/launcher.py index d16b4a2..bc215bc 100644 --- a/src/post_training/slurm/launcher.py +++ b/src/post_training/slurm/launcher.py @@ -57,10 +57,12 @@ def render_trl_slurm_script( # SLURM parameters job_name=config.slurm.job_name, partition=config.slurm.partition, + qos=config.slurm.qos, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, cpus_per_task=config.slurm.cpus_per_task, cpus_per_gpu=config.slurm.cpus_per_gpu, + mem=config.slurm.mem, wall_time=config.slurm.wall_time, signal_time_seconds=config.slurm.signal_time_seconds, max_failures=config.slurm.max_failures, @@ -107,10 +109,12 @@ def render_trl_container_slurm_script( # SLURM parameters job_name=config.slurm.job_name, partition=config.slurm.partition, + qos=config.slurm.qos, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, cpus_per_task=config.slurm.cpus_per_task, cpus_per_gpu=config.slurm.cpus_per_gpu, + mem=config.slurm.mem, wall_time=config.slurm.wall_time, signal_time_seconds=config.slurm.signal_time_seconds, max_failures=config.slurm.max_failures, From af0a070f75e93d6efe40a2e2b02af9e0d664caea Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Thu, 30 Apr 2026 16:11:28 +0200 Subject: [PATCH 2/4] feat: forward qos and mem to LlamaFactory SLURM template --- src/post_training/slurm/job_llamafactory.sh.jinja | 6 ++++++ src/post_training/slurm/launcher.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/src/post_training/slurm/job_llamafactory.sh.jinja b/src/post_training/slurm/job_llamafactory.sh.jinja index a532b6e..c03db45 100644 --- a/src/post_training/slurm/job_llamafactory.sh.jinja +++ b/src/post_training/slurm/job_llamafactory.sh.jinja @@ -6,6 +6,9 @@ #SBATCH --job-name={{ job_name }} #SBATCH --partition={{ partition }} +{% if qos -%} +#SBATCH --qos={{ qos }} +{% endif -%} #SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} @@ -13,6 +16,9 @@ {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%} #SBATCH --cpus-per-gpu={{ cpus_per_gpu }} {% endif -%} +{% if mem -%} +#SBATCH --mem={{ mem }} +{% endif -%} #SBATCH --time={{ wall_time }} #SBATCH --signal=B:USR1@{{ signal_time_seconds }} #SBATCH --output={{ run_dir }}/slurm/slurm-%j.out diff --git a/src/post_training/slurm/launcher.py b/src/post_training/slurm/launcher.py index bc215bc..81b1334 100644 --- a/src/post_training/slurm/launcher.py +++ b/src/post_training/slurm/launcher.py @@ -159,10 +159,12 @@ def render_llamafactory_slurm_script( # SLURM parameters job_name=config.slurm.job_name, partition=config.slurm.partition, + qos=config.slurm.qos, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, cpus_per_task=config.slurm.cpus_per_task, cpus_per_gpu=config.slurm.cpus_per_gpu, + mem=config.slurm.mem, wall_time=config.slurm.wall_time, signal_time_seconds=config.slurm.signal_time_seconds, max_failures=config.slurm.max_failures, From b789d2fdf414487176b6a03f6ab19b9d8fdecf08 Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Thu, 30 Apr 2026 16:11:28 +0200 Subject: [PATCH 3/4] test: add SLURM template rendering tests for qos and mem --- tests/test_slurm_render.py | 71 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 tests/test_slurm_render.py diff --git a/tests/test_slurm_render.py b/tests/test_slurm_render.py new file mode 100644 index 0000000..35ce4fc --- /dev/null +++ b/tests/test_slurm_render.py @@ -0,0 +1,71 @@ +"""Automated tests for SLURM template rendering.""" + +import pytest + +from post_training.config import PostTrainingConfig +from post_training.slurm.launcher import ( + render_llamafactory_slurm_script, + render_trl_container_slurm_script, +) + + +@pytest.fixture +def config(): + cfg = PostTrainingConfig() + cfg.container.image = "/shared/containers/llamafactory.sif" + cfg.container.bind_mounts = ["/scratch:/scratch"] + cfg.container.env_file = "/shared/env/cluster.env" + return cfg + + +def test_llamafactory_qos_mem_rendered(tmp_path, config): + """qos and mem appear as #SBATCH directives when set.""" + config.slurm.qos = "boost_qos_dbg" + config.slurm.mem = "64G" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_llamafactory_slurm_script(config, run_dir).read_text() + + assert "#SBATCH --qos=boost_qos_dbg" in content + assert "#SBATCH --mem=64G" in content + + +def test_llamafactory_qos_mem_absent_when_none(tmp_path, config): + """qos and mem directives are suppressed when not set.""" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_llamafactory_slurm_script(config, run_dir).read_text() + + assert "--qos" not in content + assert "--mem" not in content + + +# --------------------------------------------------------------------------- +# TRL container template +# --------------------------------------------------------------------------- + + +def test_trl_container_qos_mem_rendered(tmp_path, config): + """qos and mem appear as #SBATCH directives when set.""" + config.slurm.qos = "boost_qos_dbg" + config.slurm.mem = "64G" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_container_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "#SBATCH --qos=boost_qos_dbg" in content + assert "#SBATCH --mem=64G" in content + + +def test_trl_container_qos_mem_absent_when_none(tmp_path, config): + """qos and mem directives are suppressed when not set.""" + run_dir = tmp_path / "outputs" / "my-run" + run_dir.mkdir(parents=True) + + content = render_trl_container_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text() + + assert "--qos" not in content + assert "--mem" not in content From 23ab828a486f76451ebb4afeb5b02c9684481316 Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Thu, 30 Apr 2026 16:13:43 +0200 Subject: [PATCH 4/4] chore: add pytest to dev dependencies --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5700081..b2e0d43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ trl = [ "tensorboard", ] flash-attn-2 = ["flash-attn"] -dev = ["pre-commit"] +dev = ["pre-commit", "pytest"] [tool.uv] extra-index-url = ["https://download.pytorch.org/whl/cu126"]