Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trl = [
"tensorboard",
]
flash-attn-2 = ["flash-attn"]
dev = ["pre-commit"]
dev = ["pre-commit", "pytest"]

[tool.uv]
extra-index-url = ["https://download.pytorch.org/whl/cu126"]
Expand Down
3 changes: 3 additions & 0 deletions src/post_training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class DPOMethodConfig:
class CheckpointingConfig:
"""Checkpoint saving strategy."""

save_strategy: str = "steps"
save_steps: int = 500
save_total_limit: int = 2
# When set to ``None`` (or a non-positive value via CLI overrides), inference
Expand Down Expand Up @@ -180,10 +181,12 @@ class SlurmConfig:
"""SLURM job scheduler parameters."""

partition: str = "gpu"
qos: str | None = None
num_nodes: int = 1
gpus_per_node: int = 4
cpus_per_task: int = 32
cpus_per_gpu: int | None = None
mem: str | None = None
wall_time: str = "02:00:00"
job_name: str = "post-training"
signal_time_seconds: int = 300
Expand Down
2 changes: 1 addition & 1 deletion src/post_training/methods/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def build_common_training_kwargs(
bf16=t.bf16,
seed=t.seed,
# Checkpointing
save_strategy="steps",
save_strategy=config.checkpointing.save_strategy,
save_steps=config.checkpointing.save_steps,
save_total_limit=config.checkpointing.save_total_limit,
# Logging
Expand Down
6 changes: 6 additions & 0 deletions src/post_training/slurm/job.sh.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@

#SBATCH --job-name={{ job_name }}
#SBATCH --partition={{ partition }}
{% if qos -%}
#SBATCH --qos={{ qos }}
{% endif -%}
#SBATCH --nodes={{ num_nodes }}
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:{{ gpus_per_node }}
#SBATCH --cpus-per-task={{ cpus_per_task }}
{% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%}
#SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
{% endif -%}
{% if mem -%}
#SBATCH --mem={{ mem }}
{% endif -%}
#SBATCH --time={{ wall_time }}
#SBATCH --signal=B:USR1@{{ signal_time_seconds }}
#SBATCH --output={{ run_dir }}/slurm/slurm-%j.out
Expand Down
6 changes: 6 additions & 0 deletions src/post_training/slurm/job_llamafactory.sh.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@

#SBATCH --job-name={{ job_name }}
#SBATCH --partition={{ partition }}
{% if qos -%}
#SBATCH --qos={{ qos }}
{% endif -%}
#SBATCH --nodes={{ num_nodes }}
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:{{ gpus_per_node }}
#SBATCH --cpus-per-task={{ cpus_per_task }}
{% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%}
#SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
{% endif -%}
{% if mem -%}
#SBATCH --mem={{ mem }}
{% endif -%}
#SBATCH --time={{ wall_time }}
#SBATCH --signal=B:USR1@{{ signal_time_seconds }}
#SBATCH --output={{ run_dir }}/slurm/slurm-%j.out
Expand Down
6 changes: 6 additions & 0 deletions src/post_training/slurm/job_trl_container.sh.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@

#SBATCH --job-name={{ job_name }}
#SBATCH --partition={{ partition }}
{% if qos -%}
#SBATCH --qos={{ qos }}
{% endif -%}
#SBATCH --nodes={{ num_nodes }}
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:{{ gpus_per_node }}
#SBATCH --cpus-per-task={{ cpus_per_task }}
{% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%}
#SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
{% endif -%}
{% if mem -%}
#SBATCH --mem={{ mem }}
{% endif -%}
#SBATCH --time={{ wall_time }}
#SBATCH --signal=B:USR1@{{ signal_time_seconds }}
#SBATCH --output={{ run_dir }}/slurm/slurm-%j.out
Expand Down
6 changes: 6 additions & 0 deletions src/post_training/slurm/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ def render_trl_slurm_script(
# SLURM parameters
job_name=config.slurm.job_name,
partition=config.slurm.partition,
qos=config.slurm.qos,
num_nodes=config.slurm.num_nodes,
gpus_per_node=config.slurm.gpus_per_node,
cpus_per_task=config.slurm.cpus_per_task,
cpus_per_gpu=config.slurm.cpus_per_gpu,
mem=config.slurm.mem,
wall_time=config.slurm.wall_time,
signal_time_seconds=config.slurm.signal_time_seconds,
max_failures=config.slurm.max_failures,
Expand Down Expand Up @@ -107,10 +109,12 @@ def render_trl_container_slurm_script(
# SLURM parameters
job_name=config.slurm.job_name,
partition=config.slurm.partition,
qos=config.slurm.qos,
num_nodes=config.slurm.num_nodes,
gpus_per_node=config.slurm.gpus_per_node,
cpus_per_task=config.slurm.cpus_per_task,
cpus_per_gpu=config.slurm.cpus_per_gpu,
mem=config.slurm.mem,
wall_time=config.slurm.wall_time,
signal_time_seconds=config.slurm.signal_time_seconds,
max_failures=config.slurm.max_failures,
Expand Down Expand Up @@ -155,10 +159,12 @@ def render_llamafactory_slurm_script(
# SLURM parameters
job_name=config.slurm.job_name,
partition=config.slurm.partition,
qos=config.slurm.qos,
num_nodes=config.slurm.num_nodes,
gpus_per_node=config.slurm.gpus_per_node,
cpus_per_task=config.slurm.cpus_per_task,
cpus_per_gpu=config.slurm.cpus_per_gpu,
mem=config.slurm.mem,
wall_time=config.slurm.wall_time,
signal_time_seconds=config.slurm.signal_time_seconds,
max_failures=config.slurm.max_failures,
Expand Down
71 changes: 71 additions & 0 deletions tests/test_slurm_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Automated tests for SLURM template rendering."""

import pytest

from post_training.config import PostTrainingConfig
from post_training.slurm.launcher import (
render_llamafactory_slurm_script,
render_trl_container_slurm_script,
)


@pytest.fixture
def config():
    """Baseline PostTrainingConfig with container fields populated for rendering."""
    base = PostTrainingConfig()
    container = base.container
    container.image = "/shared/containers/llamafactory.sif"
    container.bind_mounts = ["/scratch:/scratch"]
    container.env_file = "/shared/env/cluster.env"
    return base


def test_llamafactory_qos_mem_rendered(tmp_path, config):
    """Setting qos/mem on the SLURM config yields matching #SBATCH directives."""
    out_dir = tmp_path / "outputs" / "my-run"
    out_dir.mkdir(parents=True)
    config.slurm.qos = "boost_qos_dbg"
    config.slurm.mem = "64G"

    script_path = render_llamafactory_slurm_script(config, out_dir)
    rendered = script_path.read_text()

    # Both optional directives must be emitted verbatim.
    for directive in ("#SBATCH --qos=boost_qos_dbg", "#SBATCH --mem=64G"):
        assert directive in rendered


def test_llamafactory_qos_mem_absent_when_none(tmp_path, config):
    """With qos/mem left unset, neither directive appears in the script."""
    out_dir = tmp_path / "outputs" / "my-run"
    out_dir.mkdir(parents=True)

    rendered = render_llamafactory_slurm_script(config, out_dir).read_text()

    for flag in ("--qos", "--mem"):
        assert flag not in rendered


# ---------------------------------------------------------------------------
# TRL container template
# ---------------------------------------------------------------------------


def test_trl_container_qos_mem_rendered(tmp_path, config):
    """Setting qos/mem on the SLURM config yields matching #SBATCH directives."""
    out_dir = tmp_path / "outputs" / "my-run"
    out_dir.mkdir(parents=True)
    config.slurm.qos = "boost_qos_dbg"
    config.slurm.mem = "64G"

    script_path = render_trl_container_slurm_script(config, out_dir, "configs/trl/sft.yaml")
    rendered = script_path.read_text()

    # Both optional directives must be emitted verbatim.
    for directive in ("#SBATCH --qos=boost_qos_dbg", "#SBATCH --mem=64G"):
        assert directive in rendered


def test_trl_container_qos_mem_absent_when_none(tmp_path, config):
    """With qos/mem left unset, neither directive appears in the script."""
    out_dir = tmp_path / "outputs" / "my-run"
    out_dir.mkdir(parents=True)

    rendered = render_trl_container_slurm_script(config, out_dir, "configs/trl/sft.yaml").read_text()

    for flag in ("--qos", "--mem"):
        assert flag not in rendered
Loading