diff --git a/src/post_training/slurm/job_llamafactory.sh.jinja b/src/post_training/slurm/job_llamafactory.sh.jinja
--- a/src/post_training/slurm/job_llamafactory.sh.jinja
+++ b/src/post_training/slurm/job_llamafactory.sh.jinja
@@ -9,6 +9,9 @@
 #SBATCH --nodes={{ num_nodes }}
 #SBATCH --ntasks-per-node=1
 #SBATCH --gres=gpu:{{ gpus_per_node }}
+{# Slurm rejects --cpus-per-task combined with --cpus-per-gpu, so emit at most one of them; the per-GPU setting wins when it is set. -#}
 {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%}
 #SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
+{% elif cpus_per_task is integer and cpus_per_task > 0 -%}
+#SBATCH --cpus-per-task={{ cpus_per_task }}
 {% endif -%}
diff --git a/src/post_training/slurm/job_trl_container.sh.jinja b/src/post_training/slurm/job_trl_container.sh.jinja
--- a/src/post_training/slurm/job_trl_container.sh.jinja
+++ b/src/post_training/slurm/job_trl_container.sh.jinja
@@ -9,6 +9,9 @@
 #SBATCH --nodes={{ num_nodes }}
 #SBATCH --ntasks-per-node=1
 #SBATCH --gres=gpu:{{ gpus_per_node }}
+{# Slurm rejects --cpus-per-task combined with --cpus-per-gpu, so emit at most one of them; the per-GPU setting wins when it is set. -#}
 {% if cpus_per_gpu is integer and cpus_per_gpu > 0 -%}
 #SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
+{% elif cpus_per_task is integer and cpus_per_task > 0 -%}
+#SBATCH --cpus-per-task={{ cpus_per_task }}
 {% endif -%}
diff --git a/src/post_training/slurm/launcher.py b/src/post_training/slurm/launcher.py
index d23e15a..d16b4a2 100644
--- a/src/post_training/slurm/launcher.py
+++ b/src/post_training/slurm/launcher.py
@@ -109,6 +109,7 @@ def render_trl_container_slurm_script(
         partition=config.slurm.partition,
         num_nodes=config.slurm.num_nodes,
         gpus_per_node=config.slurm.gpus_per_node,
+        cpus_per_task=config.slurm.cpus_per_task,
         cpus_per_gpu=config.slurm.cpus_per_gpu,
         wall_time=config.slurm.wall_time,
         signal_time_seconds=config.slurm.signal_time_seconds,
@@ -156,6 +157,7 @@ def render_llamafactory_slurm_script(
         partition=config.slurm.partition,
         num_nodes=config.slurm.num_nodes,
         gpus_per_node=config.slurm.gpus_per_node,
+        cpus_per_task=config.slurm.cpus_per_task,
         cpus_per_gpu=config.slurm.cpus_per_gpu,
         wall_time=config.slurm.wall_time,
         signal_time_seconds=config.slurm.signal_time_seconds,