Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added examples/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion examples/run_pretrain.sh
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ fi
# Export a PYTHONPATH that contains PRIMUS_PATH, the active interpreter's
# site-packages directory, and whatever PYTHONPATH was already set.
# NOTE(review): this listing looks like a rendered diff whose +/- markers were
# stripped, so the two `export` lines below are presumably the removed and the
# added version of a single line -- confirm against the actual script.
setup_pythonpath() {
local site_packages
# Ask the `python` on PATH for its 'purelib' install directory via sysconfig.
site_packages=$(python -c "import sysconfig; print(sysconfig.get_paths()['purelib'])")
# This form puts site-packages ahead of PRIMUS_PATH and carries a literal `$`
# path entry (the `:$:` segment).
export PYTHONPATH="${site_packages}:${PRIMUS_PATH}:$:${PYTHONPATH}"
# This form puts PRIMUS_PATH ahead of site-packages. As written it reassigns
# PYTHONPATH after the export above, so this is the ordering that takes effect.
export PYTHONPATH="${PRIMUS_PATH}:${site_packages}:${PYTHONPATH}"
}

setup_pythonpath
Expand Down
Empty file added examples/scripts/__init__.py
Empty file.
12 changes: 9 additions & 3 deletions primus/modules/trainer/megatron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,7 +1957,9 @@ def training_log(
# Note(wenx): If we want to collect rocm-smi memory information for the first two iterations,
# place the collection before the timer to minimize its impact on latency measurements for iterations ≥ 3.
if args.log_throughput:
if args.use_rocm_mem_info or iteration in args.use_rocm_mem_info_iters:
if args.use_rocm_mem_info or (
args.use_rocm_mem_info_iters is not None and iteration in args.use_rocm_mem_info_iters
):
rocm_total_mem, rocm_used_mem, rocm_free_mem = get_rocm_smi_mem_info(
self.module_local_rank
)
Expand Down Expand Up @@ -2015,7 +2017,9 @@ def training_log(
log_string += f"{hip_free_mem/1024/1024/1024:.2f}GiB/"
log_string += f"{hip_total_mem/1024/1024/1024:.2f}GiB/{hip_mem_usage*100:.2f}% |"

if args.use_rocm_mem_info or iteration in args.use_rocm_mem_info_iters:
if args.use_rocm_mem_info or (
args.use_rocm_mem_info_iters is not None and iteration in args.use_rocm_mem_info_iters
):
rocm_mem_usage = rocm_used_mem / rocm_total_mem

# get the max rocm_mem_usage
Expand Down Expand Up @@ -2052,7 +2056,9 @@ def training_log(
f"{statistics.mean(self.recent_token_throughputs):.1f} |"
)
if args.log_timers_to_tensorboard:
if args.use_rocm_mem_info or iteration in args.use_rocm_mem_info_iters:
if args.use_rocm_mem_info or (
args.use_rocm_mem_info_iters is not None and iteration in args.use_rocm_mem_info_iters
):
mem_collector = "rocm"
used_mem, free_mem, total_mem, mem_usage = (
rocm_used_mem,
Expand Down