diff --git a/pyproject.toml b/pyproject.toml
index 5700081..b2e0d43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ trl = [
     "tensorboard",
 ]
 flash-attn-2 = ["flash-attn"]
-dev = ["pre-commit"]
+dev = ["pre-commit", "pytest"]
 
 [tool.uv]
 extra-index-url = ["https://download.pytorch.org/whl/cu126"]
diff --git a/src/post_training/config.py b/src/post_training/config.py
index a0bde16..075e549 100644
--- a/src/post_training/config.py
+++ b/src/post_training/config.py
@@ -107,6 +107,7 @@ class DPOMethodConfig:
 class CheckpointingConfig:
     """Checkpoint saving strategy."""
 
+    save_strategy: str = "steps"
     save_steps: int = 500
     save_total_limit: int = 2
     # When set to ``None`` (or a non-positive value via CLI overrides), inference
@@ -180,10 +181,12 @@ class SlurmConfig:
     """SLURM job scheduler parameters."""
 
     partition: str = "gpu"
+    qos: str | None = None
     num_nodes: int = 1
     gpus_per_node: int = 4
     cpus_per_task: int = 32
     cpus_per_gpu: int | None = None
+    mem: str | None = None
     wall_time: str = "02:00:00"
     job_name: str = "post-training"
     signal_time_seconds: int = 300
diff --git a/src/post_training/methods/common.py b/src/post_training/methods/common.py
index 4cf0db4..6b1c17a 100644
--- a/src/post_training/methods/common.py
+++ b/src/post_training/methods/common.py
@@ -113,7 +113,7 @@ def build_common_training_kwargs(
         bf16=t.bf16,
         seed=t.seed,
         # Checkpointing
-        save_strategy="steps",
+        save_strategy=config.checkpointing.save_strategy,
         save_steps=config.checkpointing.save_steps,
         save_total_limit=config.checkpointing.save_total_limit,
         # Logging
diff --git a/src/post_training/slurm/job.sh.jinja b/src/post_training/slurm/job.sh.jinja
index b78a561..121e015 100644
--- a/src/post_training/slurm/job.sh.jinja
+++ b/src/post_training/slurm/job.sh.jinja
@@ -6,6 +6,9 @@
 
 #SBATCH --job-name={{ job_name }}
 #SBATCH --partition={{ partition }}
+{% if qos -%}
+#SBATCH --qos={{ qos }}
+{% endif -%}
 #SBATCH --nodes={{ num_nodes }}
 #SBATCH --ntasks-per-node=1
 #SBATCH --gres=gpu:{{ gpus_per_node }}
@@ -13,6 +16,9 @@
 {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%}
 #SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
 {% endif -%}
+{% if mem -%}
+#SBATCH --mem={{ mem }}
+{% endif -%}
 #SBATCH --time={{ wall_time }}
 #SBATCH --signal=B:USR1@{{ signal_time_seconds }}
 #SBATCH --output={{ run_dir }}/slurm/slurm-%j.out
diff --git a/src/post_training/slurm/job_llamafactory.sh.jinja b/src/post_training/slurm/job_llamafactory.sh.jinja
index a532b6e..c03db45 100644
--- a/src/post_training/slurm/job_llamafactory.sh.jinja
+++ b/src/post_training/slurm/job_llamafactory.sh.jinja
@@ -6,6 +6,9 @@
 
 #SBATCH --job-name={{ job_name }}
 #SBATCH --partition={{ partition }}
+{% if qos -%}
+#SBATCH --qos={{ qos }}
+{% endif -%}
 #SBATCH --nodes={{ num_nodes }}
 #SBATCH --ntasks-per-node=1
 #SBATCH --gres=gpu:{{ gpus_per_node }}
@@ -13,6 +16,9 @@
 {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%}
 #SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
 {% endif -%}
+{% if mem -%}
+#SBATCH --mem={{ mem }}
+{% endif -%}
 #SBATCH --time={{ wall_time }}
 #SBATCH --signal=B:USR1@{{ signal_time_seconds }}
 #SBATCH --output={{ run_dir }}/slurm/slurm-%j.out
diff --git a/src/post_training/slurm/job_trl_container.sh.jinja b/src/post_training/slurm/job_trl_container.sh.jinja
index 0dda5ef..d57fd53 100644
--- a/src/post_training/slurm/job_trl_container.sh.jinja
+++ b/src/post_training/slurm/job_trl_container.sh.jinja
@@ -6,6 +6,9 @@
 
 #SBATCH --job-name={{ job_name }}
 #SBATCH --partition={{ partition }}
+{% if qos -%}
+#SBATCH --qos={{ qos }}
+{% endif -%}
 #SBATCH --nodes={{ num_nodes }}
 #SBATCH --ntasks-per-node=1
 #SBATCH --gres=gpu:{{ gpus_per_node }}
@@ -13,6 +16,9 @@
 {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%}
 #SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
 {% endif -%}
+{% if mem -%}
+#SBATCH --mem={{ mem }}
+{% endif -%}
 #SBATCH --time={{ wall_time }}
 #SBATCH --signal=B:USR1@{{ signal_time_seconds }}
 #SBATCH --output={{ run_dir }}/slurm/slurm-%j.out
diff --git a/src/post_training/slurm/launcher.py b/src/post_training/slurm/launcher.py
index d16b4a2..81b1334 100644
--- a/src/post_training/slurm/launcher.py
+++ b/src/post_training/slurm/launcher.py
@@ -57,10 +57,12 @@ def render_trl_slurm_script(
         # SLURM parameters
         job_name=config.slurm.job_name,
         partition=config.slurm.partition,
+        qos=config.slurm.qos,
         num_nodes=config.slurm.num_nodes,
         gpus_per_node=config.slurm.gpus_per_node,
         cpus_per_task=config.slurm.cpus_per_task,
         cpus_per_gpu=config.slurm.cpus_per_gpu,
+        mem=config.slurm.mem,
         wall_time=config.slurm.wall_time,
         signal_time_seconds=config.slurm.signal_time_seconds,
         max_failures=config.slurm.max_failures,
@@ -107,10 +109,12 @@ def render_trl_container_slurm_script(
         # SLURM parameters
         job_name=config.slurm.job_name,
         partition=config.slurm.partition,
+        qos=config.slurm.qos,
         num_nodes=config.slurm.num_nodes,
         gpus_per_node=config.slurm.gpus_per_node,
         cpus_per_task=config.slurm.cpus_per_task,
         cpus_per_gpu=config.slurm.cpus_per_gpu,
+        mem=config.slurm.mem,
         wall_time=config.slurm.wall_time,
         signal_time_seconds=config.slurm.signal_time_seconds,
         max_failures=config.slurm.max_failures,
@@ -155,10 +159,12 @@ def render_llamafactory_slurm_script(
         # SLURM parameters
         job_name=config.slurm.job_name,
         partition=config.slurm.partition,
+        qos=config.slurm.qos,
         num_nodes=config.slurm.num_nodes,
         gpus_per_node=config.slurm.gpus_per_node,
         cpus_per_task=config.slurm.cpus_per_task,
         cpus_per_gpu=config.slurm.cpus_per_gpu,
+        mem=config.slurm.mem,
         wall_time=config.slurm.wall_time,
         signal_time_seconds=config.slurm.signal_time_seconds,
         max_failures=config.slurm.max_failures,
diff --git a/tests/test_slurm_render.py b/tests/test_slurm_render.py
new file mode 100644
index 0000000..35ce4fc
--- /dev/null
+++ b/tests/test_slurm_render.py
@@ -0,0 +1,71 @@
+"""Automated tests for SLURM template rendering."""
+
+import pytest
+
+from post_training.config import PostTrainingConfig
+from post_training.slurm.launcher import (
+    render_llamafactory_slurm_script,
+    render_trl_container_slurm_script,
+)
+
+
+@pytest.fixture
+def config():
+    cfg = PostTrainingConfig()
+    cfg.container.image = "/shared/containers/llamafactory.sif"
+    cfg.container.bind_mounts = ["/scratch:/scratch"]
+    cfg.container.env_file = "/shared/env/cluster.env"
+    return cfg
+
+
+def test_llamafactory_qos_mem_rendered(tmp_path, config):
+    """qos and mem appear as #SBATCH directives when set."""
+    config.slurm.qos = "boost_qos_dbg"
+    config.slurm.mem = "64G"
+    run_dir = tmp_path / "outputs" / "my-run"
+    run_dir.mkdir(parents=True)
+
+    content = render_llamafactory_slurm_script(config, run_dir).read_text()
+
+    assert "#SBATCH --qos=boost_qos_dbg" in content
+    assert "#SBATCH --mem=64G" in content
+
+
+def test_llamafactory_qos_mem_absent_when_none(tmp_path, config):
+    """qos and mem directives are suppressed when not set."""
+    run_dir = tmp_path / "outputs" / "my-run"
+    run_dir.mkdir(parents=True)
+
+    content = render_llamafactory_slurm_script(config, run_dir).read_text()
+
+    assert "--qos" not in content
+    assert "--mem" not in content
+
+
+# ---------------------------------------------------------------------------
+# TRL container template
+# ---------------------------------------------------------------------------
+
+
+def test_trl_container_qos_mem_rendered(tmp_path, config):
+    """qos and mem appear as #SBATCH directives when set."""
+    config.slurm.qos = "boost_qos_dbg"
+    config.slurm.mem = "64G"
+    run_dir = tmp_path / "outputs" / "my-run"
+    run_dir.mkdir(parents=True)
+
+    content = render_trl_container_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text()
+
+    assert "#SBATCH --qos=boost_qos_dbg" in content
+    assert "#SBATCH --mem=64G" in content
+
+
+def test_trl_container_qos_mem_absent_when_none(tmp_path, config):
+    """qos and mem directives are suppressed when not set."""
+    run_dir = tmp_path / "outputs" / "my-run"
+    run_dir.mkdir(parents=True)
+
+    content = render_trl_container_slurm_script(config, run_dir, "configs/trl/sft.yaml").read_text()
+
+    assert "--qos" not in content
+    assert "--mem" not in content