From a1f1853129cb2244c771085d3229b1fec9db713d Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Fri, 17 Apr 2026 22:23:05 +0200 Subject: [PATCH] fix: pass cpus-per-task into container job script --- src/post_training/slurm/job_llamafactory.sh.jinja | 1 + src/post_training/slurm/job_trl_container.sh.jinja | 1 + src/post_training/slurm/launcher.py | 2 ++ 3 files changed, 4 insertions(+) diff --git a/src/post_training/slurm/job_llamafactory.sh.jinja b/src/post_training/slurm/job_llamafactory.sh.jinja index 832bf86..a532b6e 100644 --- a/src/post_training/slurm/job_llamafactory.sh.jinja +++ b/src/post_training/slurm/job_llamafactory.sh.jinja @@ -9,6 +9,7 @@ #SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} +#SBATCH --cpus-per-task={{ cpus_per_task }} {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%} #SBATCH --cpus-per-gpu={{ cpus_per_gpu }} {% endif -%} diff --git a/src/post_training/slurm/job_trl_container.sh.jinja b/src/post_training/slurm/job_trl_container.sh.jinja index 7192667..d9cbb89 100644 --- a/src/post_training/slurm/job_trl_container.sh.jinja +++ b/src/post_training/slurm/job_trl_container.sh.jinja @@ -9,6 +9,7 @@ #SBATCH --nodes={{ num_nodes }} #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:{{ gpus_per_node }} +#SBATCH --cpus-per-task={{ cpus_per_task }} {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%} #SBATCH --cpus-per-gpu={{ cpus_per_gpu }} {% endif -%} diff --git a/src/post_training/slurm/launcher.py b/src/post_training/slurm/launcher.py index d23e15a..d16b4a2 100644 --- a/src/post_training/slurm/launcher.py +++ b/src/post_training/slurm/launcher.py @@ -109,6 +109,7 @@ def render_trl_container_slurm_script( partition=config.slurm.partition, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, + cpus_per_task=config.slurm.cpus_per_task, cpus_per_gpu=config.slurm.cpus_per_gpu, wall_time=config.slurm.wall_time, signal_time_seconds=config.slurm.signal_time_seconds, @@ -156,6 +157,7 @@ def render_llamafactory_slurm_script( partition=config.slurm.partition, num_nodes=config.slurm.num_nodes, gpus_per_node=config.slurm.gpus_per_node, + cpus_per_task=config.slurm.cpus_per_task, cpus_per_gpu=config.slurm.cpus_per_gpu, wall_time=config.slurm.wall_time, signal_time_seconds=config.slurm.signal_time_seconds,