From 7dbd2421d91ac40b987149f5fea5aa1a23d9b177 Mon Sep 17 00:00:00 2001 From: guangphu Date: Thu, 18 Dec 2025 08:52:47 +0000 Subject: [PATCH 1/6] feat: Add MLflow artifact upload for traces and logs - Add mlflow_artifacts.py with functions to collect and upload trace/log files - Add upload_mlflow_artifacts() wrapper in global_vars.py - Integrate artifact upload in trainer.py before MLflow run ends - Add mlflow_upload_traces and mlflow_upload_logs config options - Add unique timestamp-based output directories for multi-node consistency - Pass MLflow environment variables through Docker container --- examples/run_local_pretrain.sh | 11 + examples/run_pretrain.sh | 21 +- examples/run_slurm_pretrain.sh | 19 +- .../backends/megatron/training/global_vars.py | 55 ++++ .../megatron/training/mlflow_artifacts.py | 246 ++++++++++++++++++ .../megatron/primus_megatron_module.yaml | 2 + primus/modules/trainer/megatron/trainer.py | 14 + 7 files changed, 365 insertions(+), 3 deletions(-) create mode 100644 primus/backends/megatron/training/mlflow_artifacts.py diff --git a/examples/run_local_pretrain.sh b/examples/run_local_pretrain.sh index 3e4ea341b..48f612779 100755 --- a/examples/run_local_pretrain.sh +++ b/examples/run_local_pretrain.sh @@ -93,6 +93,11 @@ ENV_ARGS+=("--env" "HF_TOKEN") ENV_ARGS+=("--env" "WANDB_API_KEY") ENV_ARGS+=("--env" "ENABLE_NUMA_BINDING") ENV_ARGS+=("--env" "HSA_KERNARG_POOL_SIZE") +# MLflow environment variables +ENV_ARGS+=("--env" "DATABRICKS_TOKEN") +ENV_ARGS+=("--env" "DATABRICKS_HOST") +ENV_ARGS+=("--env" "MLFLOW_TRACKING_URI") +ENV_ARGS+=("--env" "MLFLOW_REGISTRY_URI") echo "ENV_ARGS: ${ENV_ARGS[*]}" HOSTNAME=$(hostname) @@ -158,6 +163,12 @@ docker_podman_proxy run --rm \ --env GPUS_PER_NODE \ --env DATA_PATH \ --env TRAIN_LOG \ + --env PRIMUS_WORKSPACE \ + --env PRIMUS_EXP_NAME \ + --env TIMESTAMP \ + --env LOG_DIR \ + --env PRIMUS_TEAM \ + --env PRIMUS_USER \ --env HSA_NO_SCRATCH_RECLAIM \ --env NVTE_CK_USES_BWD_V3 \ --env GPU_MAX_HW_QUEUES \ 
diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index 9053b43ab..a936df288 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -123,11 +123,28 @@ fi # export AITER_JIT_DIR="${TMP_BUILD_DIR}/${CACHE_TAG}_aiter_cache" -TRAIN_LOG=${TRAIN_LOG:-"output/log_mp_pretrain_$(basename "$EXP" .yaml).txt"} +# Extract model name from EXP config file path (e.g., deepseek_v2_lite-pretrain.yaml -> deepseek_v2_lite-pretrain) +MODEL_NAME=$(basename "${EXP}" .yaml) + +# Only generate new timestamp/paths if not already set by run_slurm_pretrain.sh +# This ensures: 1) single-node gets fresh timestamp, 2) multi-node shares same directory +if [ -z "${PRIMUS_EXP_NAME}" ]; then + TIMESTAMP=$(date +%Y%m%d_%H%M%S) + export PRIMUS_WORKSPACE=${PRIMUS_WORKSPACE:-"./output"} + export PRIMUS_EXP_NAME="${MODEL_NAME}_${TIMESTAMP}" + export LOG_DIR="${PRIMUS_WORKSPACE}/${PRIMUS_EXP_NAME}" +fi +# Clear work_group and user_name to simplify path: workspace/exp_name +export PRIMUS_TEAM="" +export PRIMUS_USER="" + +mkdir -p "$LOG_DIR" +TRAIN_LOG="${LOG_DIR}/log_mp_pretrain.txt" LOG_INFO_RANK0 "==========Training info==========" LOG_INFO_RANK0 "EXP: $EXP" -LOG_INFO_RANK0 "EXP: $BACKEND" +LOG_INFO_RANK0 "BACKEND: $BACKEND" +LOG_INFO_RANK0 "OUTPUT_DIR: ${LOG_DIR}" LOG_INFO_RANK0 "TRAIN_LOG: $TRAIN_LOG" LOG_INFO_RANK0 "PRIMUS_PATH: $PRIMUS_PATH" LOG_INFO_RANK0 "DATA_PATH: $DATA_PATH" diff --git a/examples/run_slurm_pretrain.sh b/examples/run_slurm_pretrain.sh index 04da35a4d..7e6523239 100755 --- a/examples/run_slurm_pretrain.sh +++ b/examples/run_slurm_pretrain.sh @@ -34,7 +34,22 @@ export NNODES=${NNODES:-1} SCRIPT_DIR=$(dirname "$(realpath "${BASH_SOURCE[0]}")") -export LOG_DIR=${LOG_DIR:-"./output"} +# -------------------- Unique Output Directory Per Run -------------------- +# Extract model name from EXP config file path (e.g., deepseek_v2_lite-pretrain.yaml -> deepseek_v2_lite-pretrain) +MODEL_NAME=$(basename "${EXP:-unknown}" .yaml) +# Export TIMESTAMP so all nodes 
use the same value (prevents multi-node race condition) +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +export TIMESTAMP + +# Set PRIMUS environment variables for output paths +BASE_LOG_DIR=${LOG_DIR:-"./output"} +export PRIMUS_WORKSPACE="${BASE_LOG_DIR}" +export PRIMUS_EXP_NAME="${MODEL_NAME}_${TIMESTAMP}" +export LOG_DIR="${PRIMUS_WORKSPACE}/${PRIMUS_EXP_NAME}" +# Clear work_group and user_name to simplify path: workspace/exp_name +export PRIMUS_TEAM="" +export PRIMUS_USER="" + LOG_FILE="${LOG_DIR}/log_slurm_pretrain.txt" mkdir -p "$LOG_DIR" @@ -52,6 +67,8 @@ srun -N "${NNODES}" \ echo \"SLURM_GPUS_ON_NODE: \${SLURM_GPUS_ON_NODE}\" echo \"\" fi + # Log TIMESTAMP on each node to verify consistency across nodes + echo \"[Node \$SLURM_NODEID] TIMESTAMP=\${TIMESTAMP}\" export MASTER_ADDR=\${node_array[0]} export MASTER_PORT=\${MASTER_PORT} export NNODES=\${SLURM_NNODES} diff --git a/primus/backends/megatron/training/global_vars.py b/primus/backends/megatron/training/global_vars.py index b23016d46..11c34d461 100644 --- a/primus/backends/megatron/training/global_vars.py +++ b/primus/backends/megatron/training/global_vars.py @@ -8,8 +8,11 @@ from primus.modules.module_utils import debug_rank_0 +from .mlflow_artifacts import upload_artifacts_to_mlflow + _GLOBAL_ARGS = None _GLOBAL_MLFLOW_WRITER = None +_GLOBAL_EXP_ROOT_PATH = None def set_args(args): @@ -23,6 +26,17 @@ def get_args(): return _GLOBAL_ARGS +def set_exp_root_path(exp_root_path): + """Set the experiment root path for artifact logging.""" + global _GLOBAL_EXP_ROOT_PATH + _GLOBAL_EXP_ROOT_PATH = exp_root_path + + +def get_exp_root_path(): + """Return experiment root path. Can be None.""" + return _GLOBAL_EXP_ROOT_PATH + + def get_mlflow_writer(): """Return mlflow writer. 
It can be None so no need to check if it is initialized.""" @@ -62,14 +76,51 @@ def _set_mlflow_writer(args): _GLOBAL_MLFLOW_WRITER = mlflow +def upload_mlflow_artifacts( + upload_traces: bool = True, + upload_logs: bool = True, +): + """ + Upload trace files and log files to MLflow as artifacts. + + This should be called before ending the MLflow run to ensure all + artifacts are uploaded. Only the rank that initialized MLflow + (typically rank world_size - 1) should call this. + + Args: + upload_traces: Whether to upload profiler trace files + upload_logs: Whether to upload training log files + + Returns: + Dictionary with counts of uploaded files, or None if MLflow is not enabled + """ + mlflow_writer = get_mlflow_writer() + if mlflow_writer is None: + return None + + args = get_args() + exp_root_path = get_exp_root_path() + tensorboard_dir = getattr(args, "tensorboard_dir", None) + + return upload_artifacts_to_mlflow( + mlflow_writer=mlflow_writer, + tensorboard_dir=tensorboard_dir, + exp_root_path=exp_root_path, + upload_traces=upload_traces, + upload_logs=upload_logs, + ) + + def unset_global_variables(): """Unset global vars.""" global _GLOBAL_ARGS global _GLOBAL_MLFLOW_WRITER + global _GLOBAL_EXP_ROOT_PATH _GLOBAL_ARGS = None _GLOBAL_MLFLOW_WRITER = None + _GLOBAL_EXP_ROOT_PATH = None def _ensure_var_is_initialized(var, name): @@ -84,4 +135,8 @@ def _ensure_var_is_not_initialized(var, name): def destroy_global_vars(): global _GLOBAL_ARGS + global _GLOBAL_MLFLOW_WRITER + global _GLOBAL_EXP_ROOT_PATH _GLOBAL_ARGS = None + _GLOBAL_MLFLOW_WRITER = None + _GLOBAL_EXP_ROOT_PATH = None diff --git a/primus/backends/megatron/training/mlflow_artifacts.py b/primus/backends/megatron/training/mlflow_artifacts.py new file mode 100644 index 000000000..67caa0e62 --- /dev/null +++ b/primus/backends/megatron/training/mlflow_artifacts.py @@ -0,0 +1,246 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro 
Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +MLflow Artifact Logging Utilities + +This module provides functions to upload trace files and log files to MLflow +when MLflow tracking is enabled. + +Features: +- Upload profiler trace files from all profiled ranks (including multi-node) +- Upload log files from all levels and all ranks +- Supports both local and distributed training scenarios +""" + +import glob +import os +from typing import Optional + +from primus.modules.module_utils import log_rank_0, warning_rank_0 + + +def _get_all_trace_files(tensorboard_dir: str) -> list: + """ + Find all profiler trace files in the tensorboard directory. + + Trace files are typically named like: + - *.pt.trace.json + - *.pt.trace.json.gz + + Args: + tensorboard_dir: Path to the tensorboard directory containing trace files + + Returns: + List of paths to trace files + """ + if not tensorboard_dir or not os.path.exists(tensorboard_dir): + return [] + + trace_files = [] + # Look for PyTorch profiler trace files (both compressed and uncompressed) + patterns = ["*.pt.trace.json", "*.pt.trace.json.gz"] + for pattern in patterns: + trace_files.extend(glob.glob(os.path.join(tensorboard_dir, pattern))) + trace_files.extend(glob.glob(os.path.join(tensorboard_dir, "**", pattern), recursive=True)) + + # Remove duplicates while preserving order + seen = set() + unique_files = [] + for f in trace_files: + if f not in seen: + seen.add(f) + unique_files.append(f) + + return unique_files + + +def _get_all_log_files(exp_root_path: str) -> list: + """ + Find all log files in the experiment logs directory. 
+ + Log files are organized as: + - {exp_root_path}/logs/master/master-*.log + - {exp_root_path}/logs/{module_name}/rank-{rank}/*.log + + Args: + exp_root_path: Root path of the experiment + + Returns: + List of paths to log files + """ + if not exp_root_path: + return [] + + logs_dir = os.path.join(exp_root_path, "logs") + if not os.path.exists(logs_dir): + return [] + + log_files = [] + # Find all .log files recursively + log_files.extend(glob.glob(os.path.join(logs_dir, "**", "*.log"), recursive=True)) + + return log_files + + +def upload_trace_files_to_mlflow( + mlflow_writer, + tensorboard_dir: str, + artifact_path: str = "traces", +) -> int: + """ + Upload all profiler trace files to MLflow as artifacts. + + This function collects trace files from the tensorboard directory and + uploads them to MLflow. In distributed settings, only rank 0 (or the + last rank where MLflow writer is initialized) should call this. + + Args: + mlflow_writer: The MLflow module instance (from get_mlflow_writer()) + tensorboard_dir: Path to the tensorboard directory containing trace files + artifact_path: MLflow artifact subdirectory for trace files + + Returns: + Number of trace files uploaded + """ + if mlflow_writer is None: + return 0 + + log_rank_0(f"[MLflow] Searching for trace files in: {tensorboard_dir}") + trace_files = _get_all_trace_files(tensorboard_dir) + if len(trace_files) > 5: + log_rank_0(f"[MLflow] Found {len(trace_files)} trace files: {trace_files[:5]}...") + else: + log_rank_0(f"[MLflow] Found {len(trace_files)} trace files: {trace_files}") + + if not trace_files: + log_rank_0("[MLflow] No trace files found to upload") + return 0 + + uploaded_count = 0 + for trace_file in trace_files: + try: + # Get relative path from tensorboard_dir for artifact organization + rel_path = os.path.relpath(trace_file, tensorboard_dir) + # Determine artifact subdirectory based on file location + artifact_subpath = ( + os.path.join(artifact_path, os.path.dirname(rel_path)) + if 
os.path.dirname(rel_path) + else artifact_path + ) + + mlflow_writer.log_artifact(trace_file, artifact_path=artifact_subpath) + uploaded_count += 1 + log_rank_0(f"[MLflow] Uploaded trace file: {os.path.basename(trace_file)}") + except Exception as e: + warning_rank_0(f"[MLflow] Failed to upload trace file {trace_file}: {e}") + + log_rank_0(f"[MLflow] Uploaded {uploaded_count} trace files to '{artifact_path}'") + return uploaded_count + + +def upload_log_files_to_mlflow( + mlflow_writer, + exp_root_path: str, + artifact_path: str = "logs", +) -> int: + """ + Upload all log files to MLflow as artifacts. + + This function collects log files from all ranks and all log levels + and uploads them to MLflow. The directory structure is preserved + in the artifact path. + + Args: + mlflow_writer: The MLflow module instance (from get_mlflow_writer()) + exp_root_path: Root path of the experiment + artifact_path: MLflow artifact subdirectory for log files + + Returns: + Number of log files uploaded + """ + if mlflow_writer is None: + return 0 + + log_files = _get_all_log_files(exp_root_path) + + if not log_files: + log_rank_0("[MLflow] No log files found to upload") + return 0 + + logs_base_dir = os.path.join(exp_root_path, "logs") + uploaded_count = 0 + + for log_file in log_files: + try: + # Preserve directory structure relative to logs base directory + rel_path = os.path.relpath(log_file, logs_base_dir) + artifact_subpath = ( + os.path.join(artifact_path, os.path.dirname(rel_path)) + if os.path.dirname(rel_path) + else artifact_path + ) + + mlflow_writer.log_artifact(log_file, artifact_path=artifact_subpath) + uploaded_count += 1 + except Exception as e: + warning_rank_0(f"[MLflow] Failed to upload log file {log_file}: {e}") + + log_rank_0(f"[MLflow] Uploaded {uploaded_count} log files to '{artifact_path}'") + return uploaded_count + + +def upload_artifacts_to_mlflow( + mlflow_writer, + tensorboard_dir: Optional[str] = None, + exp_root_path: Optional[str] = None, + 
upload_traces: bool = True, + upload_logs: bool = True, +) -> dict: + """ + Upload all artifacts (trace files and log files) to MLflow. + + This is the main entry point for uploading artifacts to MLflow. + It handles both trace files from profiling and log files from training. + + Args: + mlflow_writer: The MLflow module instance (from get_mlflow_writer()) + tensorboard_dir: Path to the tensorboard directory containing trace files + exp_root_path: Root path of the experiment for log files + upload_traces: Whether to upload trace files + upload_logs: Whether to upload log files + + Returns: + Dictionary with counts of uploaded files: + { + "traces": int, + "logs": int + } + """ + if mlflow_writer is None: + log_rank_0("[MLflow] MLflow writer not available, skipping artifact upload") + return {"traces": 0, "logs": 0} + + log_rank_0("[MLflow] Starting artifact upload to MLflow...") + log_rank_0(f"[MLflow] tensorboard_dir: {tensorboard_dir}") + log_rank_0(f"[MLflow] exp_root_path: {exp_root_path}") + log_rank_0(f"[MLflow] upload_traces: {upload_traces}, upload_logs: {upload_logs}") + + result = {"traces": 0, "logs": 0} + + if upload_traces and tensorboard_dir: + result["traces"] = upload_trace_files_to_mlflow( + mlflow_writer, tensorboard_dir, artifact_path="traces" + ) + + if upload_logs and exp_root_path: + result["logs"] = upload_log_files_to_mlflow(mlflow_writer, exp_root_path, artifact_path="logs") + + log_rank_0( + f"[MLflow] Artifact upload complete: " f"{result['traces']} trace files, {result['logs']} log files" + ) + + return result diff --git a/primus/configs/modules/megatron/primus_megatron_module.yaml b/primus/configs/modules/megatron/primus_megatron_module.yaml index 0ec3a22b0..6d8e4a6bf 100644 --- a/primus/configs/modules/megatron/primus_megatron_module.yaml +++ b/primus/configs/modules/megatron/primus_megatron_module.yaml @@ -5,6 +5,8 @@ disable_wandb: true disable_mlflow: true mlflow_run_name: null mlflow_experiment_name: null +mlflow_upload_traces: true #
Upload profiler trace files to MLflow +mlflow_upload_logs: true # Upload training log files to MLflow disable_compile_dependencies: true # NOTE: # - If `use_rocm_mem_info = True`, ROCm memory information will be collected diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index 9758929da..a56a3bb29 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -144,7 +144,9 @@ from primus.backends.megatron.model_provider import primus_model_provider from primus.backends.megatron.training.global_vars import ( get_mlflow_writer, + set_exp_root_path, set_primus_global_variables, + upload_mlflow_artifacts, ) from primus.backends.megatron.training.tokenizer.tokenizer import build_tokenizer from primus.core.utils import checker, file_utils @@ -1243,6 +1245,8 @@ def initialize_megatron( set_global_variables(args, build_tokenizer=False) log_rank_0(f"-set_primus_global_variables...") set_primus_global_variables(args) + # Set exp_root_path for MLflow artifact logging + set_exp_root_path(self.exp_root_path) args = get_args() # set tokenizer @@ -1611,6 +1615,11 @@ def run(self, *args, **kwargs): mlflow_writer = get_mlflow_writer() if mlflow_writer: + # Upload artifacts before ending the run + upload_mlflow_artifacts( + upload_traces=getattr(args, "mlflow_upload_traces", True), + upload_logs=getattr(args, "mlflow_upload_logs", True), + ) mlflow_writer.end_run() one_logger and one_logger.log_metrics({"app_finish_time": one_logger_utils.get_timestamp_in_ms()}) @@ -2055,6 +2064,11 @@ def get_e2e_base_metrics(): wandb_writer.finish() mlflow_writer = get_mlflow_writer() if mlflow_writer: + # Upload artifacts before ending the run + upload_mlflow_artifacts( + upload_traces=getattr(args, "mlflow_upload_traces", True), + upload_logs=getattr(args, "mlflow_upload_logs", True), + ) mlflow_writer.end_run() ft_integration.shutdown() sys.exit(exit_code) From 13dfa81a9291d5034e1baf9aa74f1749165adf89 
Mon Sep 17 00:00:00 2001 From: guangphu Date: Thu, 18 Dec 2025 10:28:08 +0000 Subject: [PATCH 2/6] docs: Clarify MLflow upload defaults are opt-out when MLflow enabled --- primus/configs/modules/megatron/primus_megatron_module.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/primus/configs/modules/megatron/primus_megatron_module.yaml b/primus/configs/modules/megatron/primus_megatron_module.yaml index 6d8e4a6bf..74f46f257 100644 --- a/primus/configs/modules/megatron/primus_megatron_module.yaml +++ b/primus/configs/modules/megatron/primus_megatron_module.yaml @@ -5,6 +5,8 @@ disable_wandb: true disable_mlflow: true mlflow_run_name: null mlflow_experiment_name: null +# NOTE: When disable_mlflow=false, traces and logs are uploaded by default. +# Set these to false if you only want metrics/params logged to MLflow. mlflow_upload_traces: true # Upload profiler trace files to MLflow mlflow_upload_logs: true # Upload training log files to MLflow disable_compile_dependencies: true From 1f2e136ecc8da96303d172cae9ce8c1153b328cd Mon Sep 17 00:00:00 2001 From: GP Huang Date: Thu, 18 Dec 2025 12:36:54 +0200 Subject: [PATCH 3/6] Update primus/modules/trainer/megatron/trainer.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- primus/modules/trainer/megatron/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index a56a3bb29..b2c212cc5 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -1245,7 +1245,7 @@ def initialize_megatron( set_global_variables(args, build_tokenizer=False) log_rank_0(f"-set_primus_global_variables...") set_primus_global_variables(args) - # Set exp_root_path for MLflow artifact logging + # Set exp_root_path for MLflow artifact upload (needed before training starts) set_exp_root_path(self.exp_root_path) args = get_args() From 
d30b9202bf97dc3b1d693748d128089adc93b066 Mon Sep 17 00:00:00 2001 From: GP Huang Date: Thu, 18 Dec 2025 12:37:23 +0200 Subject: [PATCH 4/6] Update examples/run_pretrain.sh Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/run_pretrain.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index a936df288..ebde568d7 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -126,8 +126,8 @@ fi # Extract model name from EXP config file path (e.g., deepseek_v2_lite-pretrain.yaml -> deepseek_v2_lite-pretrain) MODEL_NAME=$(basename "${EXP}" .yaml) -# Only generate new timestamp/paths if not already set by run_slurm_pretrain.sh -# This ensures: 1) single-node gets fresh timestamp, 2) multi-node shares same directory +# Only generate new timestamp/paths if not already set by run_slurm_pretrain.sh. +# This ensures single-node runs get a fresh timestamp, while multi-node runs share the same directory. 
if [ -z "${PRIMUS_EXP_NAME}" ]; then TIMESTAMP=$(date +%Y%m%d_%H%M%S) export PRIMUS_WORKSPACE=${PRIMUS_WORKSPACE:-"./output"} From b2da61b84356ad03ed3adfcc23685483eae8c7d2 Mon Sep 17 00:00:00 2001 From: GP Huang Date: Thu, 18 Dec 2025 12:44:58 +0200 Subject: [PATCH 5/6] Update primus/backends/megatron/training/mlflow_artifacts.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- primus/backends/megatron/training/mlflow_artifacts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/primus/backends/megatron/training/mlflow_artifacts.py b/primus/backends/megatron/training/mlflow_artifacts.py index 67caa0e62..c844036db 100644 --- a/primus/backends/megatron/training/mlflow_artifacts.py +++ b/primus/backends/megatron/training/mlflow_artifacts.py @@ -240,7 +240,7 @@ def upload_artifacts_to_mlflow( result["logs"] = upload_log_files_to_mlflow(mlflow_writer, exp_root_path, artifact_path="logs") log_rank_0( - f"[MLflow] Artifact upload complete: " f"{result['traces']} trace files, {result['logs']} log files" + f"[MLflow] Artifact upload complete: {result['traces']} trace files, {result['logs']} log files" ) return result From 283a1f4740aef8d5ecc2c627118734bfb10c098e Mon Sep 17 00:00:00 2001 From: guangphu Date: Thu, 18 Dec 2025 15:14:41 +0000 Subject: [PATCH 6/6] fix: Escape glob paths to handle [] characters in experiment names The experiment name contains square brackets like [deepseek_v2_lite-pretrain_...]-rank[0] which are interpreted as glob pattern character classes, causing glob.glob to return empty results even though files exist. Fixed by using glob.escape() on directory paths before using them with glob.glob(). 
--- primus/backends/megatron/training/mlflow_artifacts.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/primus/backends/megatron/training/mlflow_artifacts.py b/primus/backends/megatron/training/mlflow_artifacts.py index c844036db..f271dc639 100644 --- a/primus/backends/megatron/training/mlflow_artifacts.py +++ b/primus/backends/megatron/training/mlflow_artifacts.py @@ -43,9 +43,11 @@ def _get_all_trace_files(tensorboard_dir: str) -> list: trace_files = [] # Look for PyTorch profiler trace files (both compressed and uncompressed) patterns = ["*.pt.trace.json", "*.pt.trace.json.gz"] + # Escape directory path to handle special characters like [] in experiment names + escaped_dir = glob.escape(tensorboard_dir) for pattern in patterns: - trace_files.extend(glob.glob(os.path.join(tensorboard_dir, pattern))) - trace_files.extend(glob.glob(os.path.join(tensorboard_dir, "**", pattern), recursive=True)) + trace_files.extend(glob.glob(os.path.join(escaped_dir, pattern))) + trace_files.extend(glob.glob(os.path.join(escaped_dir, "**", pattern), recursive=True)) # Remove duplicates while preserving order seen = set() @@ -80,8 +82,8 @@ def _get_all_log_files(exp_root_path: str) -> list: return [] log_files = [] - # Find all .log files recursively - log_files.extend(glob.glob(os.path.join(logs_dir, "**", "*.log"), recursive=True)) + # Find all .log files recursively (escape path to handle special characters) + log_files.extend(glob.glob(os.path.join(glob.escape(logs_dir), "**", "*.log"), recursive=True)) return log_files