Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/post_training/slurm/job_llamafactory.sh.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ srun --export=ALL --wait=60 --kill-on-bad-exit=1 \

# Prevent host Python/PATH from interfering with container
export PATH=\"/usr/local/bin:/usr/bin:/bin\"
export PYTHONPATH=\"\"
export PYTHONPATH=\"{{ repo_dir }}/src\"
export PYTHONNOUSERSITE=1

export NODE_RANK=\"\$SLURM_NODEID\"
Expand Down
7 changes: 5 additions & 2 deletions src/post_training/slurm/job_trl_container.sh.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ export WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE ))
# NCCL tuning for multi-node stability
export NCCL_IB_TIMEOUT=120
export NCCL_DEBUG=INFO
export CUDA_DEVICE_MAX_CONNECTIONS=1
# CUDA_DEVICE_MAX_CONNECTIONS=1 is a Megatron-LM tensor-parallel flag that
# serializes CUDA streams and hurts ZeRO-2 overlap_comm — do not set it.
export OMP_NUM_THREADS=1

echo "=========================================="
Expand Down Expand Up @@ -71,7 +72,7 @@ srun --export=ALL --wait=60 --kill-on-bad-exit=1 \

# Prevent host Python/PATH from interfering with container
export PATH=\"/usr/local/bin:/usr/bin:/bin\"
export PYTHONPATH=\"\"
export PYTHONPATH=\"{{ repo_dir }}/src\"
export PYTHONNOUSERSITE=1

export NODE_RANK=\"\$SLURM_NODEID\"
Expand All @@ -87,6 +88,8 @@ srun --export=ALL --wait=60 --kill-on-bad-exit=1 \

cd {{ repo_dir }}

export WANDB_DIR={{ run_dir }}

accelerate launch \
--num_machines \$NNODES \
--num_processes $WORLD_SIZE \
Expand Down
4 changes: 2 additions & 2 deletions src/post_training/slurm/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def render_trl_container_slurm_script(
wall_time=config.slurm.wall_time,
signal_time_seconds=config.slurm.signal_time_seconds,
max_failures=config.slurm.max_failures,
run_dir=str(run_dir),
run_dir=str(run_dir.resolve()),
config_path=config_path,
# Accelerate flags
mixed_precision=config.accelerate.mixed_precision,
Expand Down Expand Up @@ -162,7 +162,7 @@ def render_llamafactory_slurm_script(
wall_time=config.slurm.wall_time,
signal_time_seconds=config.slurm.signal_time_seconds,
max_failures=config.slurm.max_failures,
run_dir=str(run_dir),
run_dir=str(run_dir.resolve()),
# Container
container_image=config.container.image,
bind_mounts=config.container.bind_mounts,
Expand Down
Loading