diff --git a/src/post_training/slurm/job_llamafactory.sh.jinja b/src/post_training/slurm/job_llamafactory.sh.jinja index a532b6e..e214165 100644 --- a/src/post_training/slurm/job_llamafactory.sh.jinja +++ b/src/post_training/slurm/job_llamafactory.sh.jinja @@ -70,7 +70,7 @@ srun --export=ALL --wait=60 --kill-on-bad-exit=1 \ # Prevent host Python/PATH from interfering with container export PATH=\"/usr/local/bin:/usr/bin:/bin\" - export PYTHONPATH=\"\" + export PYTHONPATH=\"{{ repo_dir }}/src\" export PYTHONNOUSERSITE=1 export NODE_RANK=\"\$SLURM_NODEID\" diff --git a/src/post_training/slurm/job_trl_container.sh.jinja b/src/post_training/slurm/job_trl_container.sh.jinja index 0dda5ef..b085291 100644 --- a/src/post_training/slurm/job_trl_container.sh.jinja +++ b/src/post_training/slurm/job_trl_container.sh.jinja @@ -35,7 +35,8 @@ export WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE )) # NCCL tuning for multi-node stability export NCCL_IB_TIMEOUT=120 export NCCL_DEBUG=INFO -export CUDA_DEVICE_MAX_CONNECTIONS=1 +# CUDA_DEVICE_MAX_CONNECTIONS=1 is a Megatron-LM tensor-parallel flag that +# serializes CUDA streams and hurts ZeRO-2 overlap_comm — do not set it. 
export OMP_NUM_THREADS=1 echo "==========================================" @@ -71,7 +72,7 @@ srun --export=ALL --wait=60 --kill-on-bad-exit=1 \ # Prevent host Python/PATH from interfering with container export PATH=\"/usr/local/bin:/usr/bin:/bin\" - export PYTHONPATH=\"\" + export PYTHONPATH=\"{{ repo_dir }}/src\" export PYTHONNOUSERSITE=1 export NODE_RANK=\"\$SLURM_NODEID\" @@ -87,6 +88,8 @@ srun --export=ALL --wait=60 --kill-on-bad-exit=1 \ cd {{ repo_dir }} + export WANDB_DIR=\"{{ run_dir }}\" + accelerate launch \ --num_machines \$NNODES \ --num_processes $WORLD_SIZE \ diff --git a/src/post_training/slurm/launcher.py b/src/post_training/slurm/launcher.py index d16b4a2..17e989a 100644 --- a/src/post_training/slurm/launcher.py +++ b/src/post_training/slurm/launcher.py @@ -114,7 +114,7 @@ def render_trl_container_slurm_script( wall_time=config.slurm.wall_time, signal_time_seconds=config.slurm.signal_time_seconds, max_failures=config.slurm.max_failures, - run_dir=str(run_dir), + run_dir=str(run_dir.resolve()), config_path=config_path, # Accelerate flags mixed_precision=config.accelerate.mixed_precision, @@ -162,7 +162,7 @@ def render_llamafactory_slurm_script( wall_time=config.slurm.wall_time, signal_time_seconds=config.slurm.signal_time_seconds, max_failures=config.slurm.max_failures, - run_dir=str(run_dir), + run_dir=str(run_dir.resolve()), # Container container_image=config.container.image, bind_mounts=config.container.bind_mounts,