From b76b7866fde80d068b4eee1954ca8cd9df4bfad9 Mon Sep 17 00:00:00 2001 From: wenxie-amd Date: Tue, 6 Jan 2026 09:08:48 +0000 Subject: [PATCH] fix training issue on NVIDIA GPUs --- examples/__init__.py | 0 examples/run_pretrain.sh | 2 +- examples/scripts/__init__.py | 0 primus/modules/trainer/megatron/trainer.py | 12 +++++++++--- 4 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 examples/__init__.py create mode 100644 examples/scripts/__init__.py diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index 197e0b084..48fb47b3c 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -506,7 +506,7 @@ fi setup_pythonpath() { local site_packages site_packages=$(python -c "import sysconfig; print(sysconfig.get_paths()['purelib'])") - export PYTHONPATH="${site_packages}:${PRIMUS_PATH}:$:${PYTHONPATH}" + export PYTHONPATH="${PRIMUS_PATH}:${site_packages}:${PYTHONPATH}" } setup_pythonpath diff --git a/examples/scripts/__init__.py b/examples/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index ff33990f1..aa7eb9e8d 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -1957,7 +1957,9 @@ def training_log( # Note(wenx): If we want to collect rocm-smi memory information for the first two iterations, # place the collection before the timer to minimize its impact on latency measurements for iterations ≥ 3. if args.log_throughput: - if args.use_rocm_mem_info or iteration in args.use_rocm_mem_info_iters: + if args.use_rocm_mem_info or ( + args.use_rocm_mem_info_iters is not None and iteration in args.use_rocm_mem_info_iters + ): rocm_total_mem, rocm_used_mem, rocm_free_mem = get_rocm_smi_mem_info( self.module_local_rank ) @@ -2015,7 +2017,9 @@ def training_log( log_string += f"{hip_free_mem/1024/1024/1024:.2f}GiB/" log_string += f"{hip_total_mem/1024/1024/1024:.2f}GiB/{hip_mem_usage*100:.2f}% |" - if args.use_rocm_mem_info or iteration in args.use_rocm_mem_info_iters: + if args.use_rocm_mem_info or ( + args.use_rocm_mem_info_iters is not None and iteration in args.use_rocm_mem_info_iters + ): rocm_mem_usage = rocm_used_mem / rocm_total_mem # get the max rocm_mem_usage @@ -2052,7 +2056,9 @@ def training_log( f"{statistics.mean(self.recent_token_throughputs):.1f} |" ) if args.log_timers_to_tensorboard: - if args.use_rocm_mem_info or iteration in args.use_rocm_mem_info_iters: + if args.use_rocm_mem_info or ( + args.use_rocm_mem_info_iters is not None and iteration in args.use_rocm_mem_info_iters + ): mem_collector = "rocm" used_mem, free_mem, total_mem, mem_usage = ( rocm_used_mem,