From 51499e56bf153f35e108cdf23b4983be2abee1db Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 9 Jan 2026 08:27:36 -0800 Subject: [PATCH 1/3] feat: add het-job support for ray slurm Signed-off-by: Hemil Desai --- nemo_run/run/ray/slurm.py | 90 +++- nemo_run/run/ray/templates/ray.sub.j2 | 21 +- .../run/torchx_backend/schedulers/slurm.py | 11 + .../artifacts/expected_ray_het_cluster.sub | 483 ++++++++++++++++++ test/run/ray/test_slurm_ray_request.py | 234 ++++++++- .../torchx_backend/schedulers/test_slurm.py | 197 +++++++ 6 files changed, 1014 insertions(+), 22 deletions(-) create mode 100644 test/core/execution/artifacts/expected_ray_het_cluster.sub diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index 5fb60019..72d79fe6 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -169,9 +169,67 @@ def materialize(self) -> str: parameters.update(self.executor.additional_parameters) sbatch_flags = [] - assert not self.executor.heterogeneous, "heterogeneous is not supported for ray clusters" - for k in sorted(parameters): - sbatch_flags.append(_as_sbatch_flag(k, parameters[k])) + if self.executor.heterogeneous: + # Validate resource_group exists + assert self.executor.resource_group, "heterogeneous requires resource_group to be set" + assert len(self.executor.resource_group) > 0, "resource_group must not be empty" + + # Validate het-group-0 has at least 1 node for Ray head + head_group = self.executor.resource_group[0] + assert head_group.nodes >= 1, "het-group-0 must have at least 1 node for Ray head" + + # Determine the final het group index (for hetjob separator placement) + final_group_index = len(self.executor.resource_group) - 1 + if self.executor.het_group_indices: + final_group_index = self.executor.het_group_indices.index( + max(self.executor.het_group_indices) + ) + + # Generate SBATCH blocks for each het group + for i, resource_req in enumerate(self.executor.resource_group): + # Skip duplicate het groups (when het_group_index is shared) + if resource_req.het_group_index is not None: + if ( + i > 0 + and resource_req.het_group_index + == self.executor.resource_group[i - 1].het_group_index + ): + continue + + # Build het-specific parameters + het_parameters = parameters.copy() + het_parameters.update( + { + "nodes": resource_req.nodes, + "ntasks_per_node": resource_req.ntasks_per_node, + } + ) + + # Update job name to include het group index + het_parameters["job_name"] = f"{job_details.job_name}-{i}" + + # Only update GPU parameters if they're explicitly set in resource_req + if resource_req.gpus_per_node is not None: + het_parameters["gpus_per_node"] = resource_req.gpus_per_node + if resource_req.gpus_per_task is not None: + het_parameters["gpus_per_task"] = resource_req.gpus_per_task + + # Update output/error paths to include het group index + het_parameters["output"] = parameters["output"].replace("%t", str(i)) + if "error" in het_parameters: + het_parameters["error"] = parameters["error"].replace("%t", str(i)) + + # Generate SBATCH flags for this het group + for k in sorted(het_parameters): + sbatch_flags.append(_as_sbatch_flag(k, het_parameters[k])) + + # Add hetjob separator (except after last group) + if i != final_group_index: + sbatch_flags.append("#SBATCH hetjob") + else: + # Non-heterogeneous: use existing logic + for k in sorted(parameters): + sbatch_flags.append(_as_sbatch_flag(k, parameters[k])) if self.executor.dependencies: slurm_deps = self.executor.parse_deps() @@ -238,6 +296,8 @@ def get_srun_flags(mounts: list[str], 
container_image: Optional[str]) -> str: "command_workdir": self.workdir, "gres_specification": get_gres_specification(), "ray_log_prefix": ray_log_prefix, + "heterogeneous": self.executor.heterogeneous, + "resource_group": self.executor.resource_group if self.executor.heterogeneous else [], } if self.command_groups: @@ -273,12 +333,24 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: os.path.join(logs_dir, f"{ray_log_prefix}overlap-{idx}.err"), ] + # Determine het group for this command (if heterogeneous) + het_group_flag = [] + if self.executor.heterogeneous and self.executor.run_as_group: + if len(self.executor.resource_group) == len(self.command_groups): + # Use resource_group mapping + req = self.executor.resource_group[idx] + het_group_num = ( + req.het_group_index if req.het_group_index is not None else idx + ) + het_group_flag = [f"--het-group={het_group_num}"] + srun_cmd = " ".join( list( map( lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg), [ "srun", + *het_group_flag, "--output", noquote(stdout_path), *stderr_flags, @@ -1187,9 +1259,9 @@ def start( if isinstance(self.executor.tunnel, SSHTunnel): # Rsync workdir honouring .gitignore self.executor.tunnel.connect() - assert self.executor.tunnel.session is not None, ( - "Tunnel session is not connected" - ) + assert ( + self.executor.tunnel.session is not None + ), "Tunnel session is not connected" rsync( self.executor.tunnel.session, workdir, @@ -1244,9 +1316,9 @@ def start( if isinstance(self.executor.tunnel, SSHTunnel): self.executor.tunnel.connect() - assert self.executor.tunnel.session is not None, ( - "Tunnel session is not connected" - ) + assert ( + self.executor.tunnel.session is not None + ), "Tunnel session is not connected" rsync( self.executor.tunnel.session, os.path.join(local_code_extraction_path, ""), diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 index 0d7e3510..8d510cb2 100644 --- a/nemo_run/run/ray/templates/ray.sub.j2 +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -170,6 +170,13 @@ head_node=${nodes_array[0]} head_node_ip=${ip_addresses_array[0]} ip_head=$head_node_ip:$PORT +{%- if heterogeneous %} + +# Extract het group hostnames for heterogeneous jobs +{% for i in range(resource_group|length) %} +het_group_host_{{i}}=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{{i}} | head -n1) +{%- endfor %} +{%- endif %} {%- if setup_lines %} {{setup_lines}} @@ -279,12 +286,12 @@ touch $LOG_DIR/ENDED exit 1 EOF ) -srun {{ common_srun_args }} --container-name=ray-head --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}head.log bash -x -c "$head_cmd" & +srun {% if heterogeneous %}--het-group=0 {% endif %}{{ common_srun_args }} --container-name=ray-head --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}head.log bash -x -c "$head_cmd" & SRUN_PIDS["ray-head"]=$! # Wait for the head node container to start and for Ray to be ready elapsed_time=0 -while ! (srun --overlap --nodes=1 --ntasks=1 -w $head_node test -f $LOG_DIR/STARTED_RAY_HEAD && srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w $head_node ray status --address $ip_head 2>/dev/null); do +while ! 
(srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap --nodes=1 --ntasks=1 -w $head_node test -f $LOG_DIR/STARTED_RAY_HEAD && srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap --container-name=ray-head --nodes=1 --ntasks=1 -w $head_node ray status --address $ip_head 2>/dev/null); do if [[ $elapsed_time -ge $RAY_HEAD_START_TIMEOUT ]]; then echo "[ERROR][$(date)] Ray head node failed to start within $RAY_HEAD_START_TIMEOUT seconds. Exiting..." touch $LOG_DIR/ENDED @@ -368,7 +375,7 @@ EOF if [[ $i -eq 0 ]]; then OVERLAP_HEAD_AND_WORKER_ARG="--overlap" fi - srun {{ common_srun_args }} ${OVERLAP_HEAD_AND_WORKER_ARG:-} --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/{{ ray_log_prefix }}worker-$i.log bash -x -c "$worker_cmd" & + srun {% if heterogeneous %}--het-group=0 {% endif %}{{ common_srun_args }} ${OVERLAP_HEAD_AND_WORKER_ARG:-} --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/{{ ray_log_prefix }}worker-$i.log bash -x -c "$worker_cmd" & SRUN_PIDS["ray-worker-$i"]=$! sleep 3 done @@ -377,7 +384,7 @@ done # Before we launch a job on this cluster we need to make sure that the bringup is complete # We do so by querying the number of worker_units in the ray cluster and asserting = NUM_ACTORS extract_worker_units() { - status_output=$(srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head) + status_output=$(srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head) if echo "$status_output" | grep -q "worker_units"; then worker_units=$(echo "$status_output" | grep "worker_units" | awk -F'[/. 
]' '{print $4}') echo $worker_units @@ -447,7 +454,7 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}" COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }} if [[ -n "$COMMAND" ]]; then - srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND" + srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND" else echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:" cat <$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh @@ -455,10 +462,10 @@ else WORKER_NUM=\${1:-} if [[ -z "\$WORKER_NUM" ]]; then # Empty means we are on the head node - srun --no-container-mount-home --gpus=0 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash + srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash else nodes_array=($nodes) - srun --no-container-mount-home {%- if gres_specification %}{{gres_specification}}{% endif %} -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash + srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home {%- if gres_specification %}{{gres_specification}}{% endif %} -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash fi EOF chmod +x $CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh diff --git a/nemo_run/run/torchx_backend/schedulers/slurm.py b/nemo_run/run/torchx_backend/schedulers/slurm.py index 358419c0..999e505a 100644 --- a/nemo_run/run/torchx_backend/schedulers/slurm.py +++ b/nemo_run/run/torchx_backend/schedulers/slurm.py @@ -113,6 +113,17 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t srun_cmd = [role.entrypoint] + role.args srun_cmds.append([" ".join(srun_cmd)]) + # For heterogeneous jobs, ensure run_as_group is set for command group mapping + if executor.heterogeneous and executor.resource_group: + executor.run_as_group = True + # Validate that command groups align with resource groups + if len(srun_cmds) != len(executor.resource_group): + log.warning( + f"Heterogeneous job has {len(executor.resource_group)} resource groups " + f"but {len(srun_cmds)} roles. Command groups should match resource groups " + f"for proper het-group mapping." 
+ ) + command = [app.roles[0].entrypoint] + app.roles[0].args # Allow selecting Ray template via environment variable ray_template_name = os.environ.get("NEMO_RUN_SLURM_RAY_TEMPLATE", "ray.sub.j2") diff --git a/test/core/execution/artifacts/expected_ray_het_cluster.sub b/test/core/execution/artifacts/expected_ray_het_cluster.sub new file mode 100644 index 00000000..9a1a0cd0 --- /dev/null +++ b/test/core/execution/artifacts/expected_ray_het_cluster.sub @@ -0,0 +1,483 @@ +#!/bin/bash +# +# Generated by NeMo Run +# + +# Parameters +#SBATCH --account=test_account +#SBATCH --gpus-per-node=8 +#SBATCH --gres=gpu:8 +#SBATCH --job-name=test_account-account.test-ray-het-cluster-0 +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=8 +#SBATCH --open-mode=append +#SBATCH --output=/tmp/test_jobs/test-ray-het-cluster/logs/sbatch_test_account-account.test-ray-het-cluster_%j.out +#SBATCH --partition=gpu +#SBATCH --time=01:00:00 +#SBATCH hetjob +#SBATCH --account=test_account +#SBATCH --gpus-per-node=0 +#SBATCH --gres=gpu:8 +#SBATCH --job-name=test_account-account.test-ray-het-cluster-1 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --open-mode=append +#SBATCH --output=/tmp/test_jobs/test-ray-het-cluster/logs/sbatch_test_account-account.test-ray-het-cluster_%j.out +#SBATCH --partition=gpu +#SBATCH --time=01:00:00 +#SBATCH hetjob +#SBATCH --account=test_account +#SBATCH --gpus-per-node=0 +#SBATCH --gres=gpu:8 +#SBATCH --job-name=test_account-account.test-ray-het-cluster-2 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --open-mode=append +#SBATCH --output=/tmp/test_jobs/test-ray-het-cluster/logs/sbatch_test_account-account.test-ray-het-cluster_%j.out +#SBATCH --partition=gpu +#SBATCH --time=01:00:00 + +set -eoux pipefail + +######################################################## +# User defined variables +######################################################## +export PYTHONUNBUFFERED=1 +export SLURM_UNBUFFEREDIO=1 + +# Ports for all nodes (should be odd numbers since we place head/worker[0] on the same node) so all workers get the odd ports, but the head will get +1 the ports +NODE_MANAGER_PORT=${NODE_MANAGER_PORT:-53001} +OBJECT_MANAGER_PORT=${OBJECT_MANAGER_PORT:-53003} +RUNTIME_ENV_AGENT_PORT=${RUNTIME_ENV_AGENT_PORT:-53005} +DASHBOARD_AGENT_GRPC_PORT=${DASHBOARD_AGENT_GRPC_PORT:-53007} +METRICS_EXPORT_PORT=${METRICS_EXPORT_PORT:-53009} + +# Ports for the head node +PORT=${PORT:-6379} +RAY_CLIENT_SERVER_PORT=${RAY_CLIENT_SERVER_PORT:-10001} +#REDIT_SHARD_PORTS=${REDIT_SHARD_PORTS:-"random"} ?? +DASHBOARD_PORT=${DASHBOARD_PORT:-8265} # Also used by debugger +DASHBOARD_AGENT_LISTEN_PORT=${DASHBOARD_AGENT_LISTEN_PORT:-52365} +RAY_DEBUGGER_ARGS= +if [ "${RAY_DEBUG:-}" = "legacy" ]; then + RAY_DEBUGGER_ARGS="--ray-debugger-external" +fi + +# After ray>=2.47, this feature is enabled by default which creates uv venvs for any py_executable starting with `uv run`. +# There is severe contention and performance issues with this enabled considering our dependencies are so large and occasionally +# need to be compiled, so NeMo RL has an implementation in nemo_rl/utils/venv.py that does it once per node as opposed to once per task. 
+export RAY_ENABLE_UV_RUN_RUNTIME_ENV=0 + +# Setting ulimit is recommended by ray best practices page +# @ https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html +# It's session based and won't affect the system outside the script +# Ensure that the soft limit isn't above the hard limit +if [[ $(ulimit -Hn) == "unlimited" ]] || [[ 65535 -lt $(ulimit -Hn) ]]; then + ulimit -Sn 65535 +elif [[ $(ulimit -Hn) != "unlimited" ]] && [[ $(ulimit -Hn) -lt 65535 ]]; then + echo "[WARNING]: Cannot increase ulimit on file descriptors to 65535 according ray recommendation: https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html. Speak to cluster admins to increase, otherwise ray may crash unexpectedly." +fi + +# On our clusters, the largest port range on an idle worker appeared between 52369-64607 +# (not including the other ports set by this script). So this range is chosen to be +# somewhere in the middle +MIN_WORKER_PORT=${MIN_WORKER_PORT:-54001} +MAX_WORKER_PORT=${MAX_WORKER_PORT:-54257} + +# Ray temp directory (inside container). Used by --temp-dir and log sync sidecar +RAY_TEMP_DIR=${RAY_TEMP_DIR:-/ray-cluster} + +# Number seconds to sync logs from /tmp/ray/session_*/logs to $LOG_DIR/ray/ +RAY_LOG_SYNC_FREQUENCY=${RAY_LOG_SYNC_FREQUENCY:-} + +# Timeout in seconds for Ray head node to start (default 10 minutes) +RAY_HEAD_START_TIMEOUT=${RAY_HEAD_START_TIMEOUT:-600} + +# Directory setup +export CLUSTER_DIR=/tmp/test_jobs/test-ray-het-cluster +mkdir -p $CLUSTER_DIR + +JOB_IDS_FILE="$CLUSTER_DIR/job_ids.json" +if [[ -f "$JOB_IDS_FILE" ]]; then + tmp="$(mktemp)" + jq --arg id "$SLURM_JOB_ID" '. + [$id]' "$JOB_IDS_FILE" > "$tmp" && mv "$tmp" "$JOB_IDS_FILE" +else + touch "$JOB_IDS_FILE" + echo "[\"$SLURM_JOB_ID\"]" > "$JOB_IDS_FILE" +fi + +mkdir -p $CLUSTER_DIR/scripts + +export LOG_DIR=/tmp/test_jobs/test-ray-het-cluster/logs +mkdir -p $LOG_DIR + +# Clean up any previous run files +rm -f $LOG_DIR/STARTED_RAY_HEAD +rm -f $LOG_DIR/ENDED + +# Defaults to placing uv cache inside the CLUSTER_DIR +# This directory is mounted into the container at /home/ray/.cache/uv so it is shared between the head and worker nodes +# UV_CACHE_DIR=/tmp/test_jobs/test-ray-het-cluster/uv_cache +# mkdir -p $UV_CACHE_DIR +######################################################## + +# Number of GPUs per node +gpus_per_node=8 +CPUS_PER_WORKER=${CPUS_PER_WORKER:-$((gpus_per_node * 16))} + +num_retries=1 + +# Track backgrounded srun client PIDs for head and workers +declare -A SRUN_PIDS + +# Verify all backgrounded srun client processes are still alive; exit fast if any died +check_srun_processes() { + for name in "${!SRUN_PIDS[@]}"; do + pid="${SRUN_PIDS[$name]}" + # Check if the process is still running + if ! kill -0 "$pid" 2>/dev/null; then + echo "[ERROR] Background srun '$name' died (pid=$pid). Could be a failure in startup or an issue with the node preventing the srun to start. Attempting to exit." 
>&2 + # Signal sidecars inside containers to terminate ASAP + touch "$LOG_DIR/ENDED" + exit 1 + fi + done +} + +# Getting the node names and IP addresses in the SLURM allocation +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) +ip_addresses_array=() + +for node in $nodes; do + # Try multiple methods to get IP address - ENHANCED VERSION v2.0 + echo "[DEBUG] Resolving hostname: $node using enhanced resolution methods" + ip_address="" + + # Method 1: Try host command + echo "[DEBUG] Method 1: host command" + ip_address=$(host $node 2>/dev/null | awk '/has address/ { print $4 }' | head -1 || true) + echo "[DEBUG] host result: '$ip_address'" + + # Method 2: If host fails, try getent + if [[ -z "$ip_address" ]]; then + echo "[DEBUG] Method 2: getent hosts" + ip_address=$(getent hosts $node 2>/dev/null | awk '{ print $1 }' | head -1 || true) + echo "[DEBUG] getent result: '$ip_address'" + fi + + # Method 3: If getent fails, try nslookup + if [[ -z "$ip_address" ]]; then + echo "[DEBUG] Method 3: nslookup" + ip_address=$(nslookup $node 2>/dev/null | awk '/^Address: / { print $2 }' | head -1 || true) + echo "[DEBUG] nslookup result: '$ip_address'" + fi + + # Method 4: If all DNS methods fail, try ping to extract IP + if [[ -z "$ip_address" ]]; then + echo "[DEBUG] Method 4: ping" + ip_address=$(ping -c 1 $node 2>/dev/null | grep "PING" | sed 's/.*(\([^)]*\)).*/\1/' || true) + echo "[DEBUG] ping result: '$ip_address'" + fi + + # If still no IP, use the hostname itself (might work if it's already an IP or resolvable) + if [[ -z "$ip_address" ]]; then + echo "[WARNING] Could not resolve IP for $node, using hostname as fallback" + ip_address=$node + fi + + echo "[INFO] Node: $node -> IP: $ip_address" + # Add the IP address to the array + ip_addresses_array+=("$ip_address") +done + +head_node=${nodes_array[0]} +head_node_ip=${ip_addresses_array[0]} + +ip_head=$head_node_ip:$PORT + +# Extract het group hostnames for heterogeneous jobs + +het_group_host_0=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_0 | head -n1) +het_group_host_1=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_1 | head -n1) +het_group_host_2=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_2 | head -n1) + +######################################################## +# Ray cluster setup +######################################################## +# First we start the head of the ray cluster on one of the physical nodes +# Set GPU/CPU resources to 0 to avoid scheduling on the head node + +head_cmd=$(cat < /dev/null 2>&1; then + for session_dir in ${RAY_TEMP_DIR}/session_[0-9]*/; do + if [[ -d "\$session_dir/logs" ]]; then + session_name=\$(basename "\$session_dir") + mkdir -p "$LOG_DIR/ray/\$session_name" + if command -v rsync > /dev/null 2>&1; then + rsync -ahP "\$session_dir/logs/" "$LOG_DIR/ray/\$session_name/logs/" 2>/dev/null || true + else + cp -r "\$session_dir/logs" "$LOG_DIR/ray/\$session_name/" + fi + fi + done + fi + if [[ -f "$LOG_DIR/ENDED" ]]; then + echo "Log sync sidecar terminating..." 
+ break + fi + done +} +log-sync-sidecar & + +# Patch nsight.py before starting Ray head +sed -i 's/context\.py_executable = " "\.join(self\.nsight_cmd) + " python"/context.py_executable = " ".join(self.nsight_cmd) + f" {context.py_executable}"/g' /opt/nemo_rl_venv/lib64/python*/site-packages/ray/_private/runtime_env/nsight.py + +cat </dev/null); do + if [[ $elapsed_time -ge $RAY_HEAD_START_TIMEOUT ]]; then + echo "[ERROR][$(date)] Ray head node failed to start within $RAY_HEAD_START_TIMEOUT seconds. Exiting..." + touch $LOG_DIR/ENDED + exit 1 + fi + echo "[INFO][$(date)] Waiting for Ray head node container to start and be ready... ($elapsed_time/$RAY_HEAD_START_TIMEOUT seconds)" + check_srun_processes + sleep 2 + elapsed_time=$((elapsed_time + 2)) +done + +NUM_ACTORS=$((gpus_per_node * SLURM_JOB_NUM_NODES)) + +# Start Ray worker nodes +# We want 1 Ray worker node per physical node +# Worker nodes are started with ray start but without the --head flag +for ((i = 1; i < SLURM_JOB_NUM_NODES; i++)); do + node_i=${nodes_array[$i]} + + worker_cmd=$(cat <$CLUSTER_DIR/ray_cluster_info.json +{ + "head_ip": "$head_node_ip", + "dashboard_port": "$DASHBOARD_PORT", + "port": "$PORT" +} +EOF +# Set up trap to clean up cluster info on job termination +cleanup_cluster_info() { + echo "[INFO] Cleaning up Ray cluster information" + rm -f $CLUSTER_DIR/ray_cluster_info.json +} + +# Register the cleanup function to run on script exit +trap cleanup_cluster_info EXIT + + +echo "[INFO] Ray cluster information saved to $CLUSTER_DIR/ray_cluster_info.json" + +######################################################## + + +# Run extra commands + + +srun --het-group=1 --output /tmp/test_jobs/test-ray-het-cluster/logs/ray-overlap-1.out --container-image=nvcr.io/nvidia/pytorch:24.01-py3 --no-container-mount-home --mpi=pmix -A=test_account -p=gpu --gres=gpu:8 --container-mounts /tmp/test_jobs/test-ray-het-cluster:/tmp/test_jobs/test-ray-het-cluster,/tmp/test_jobs/test-ray-het-cluster:/tmp/test_jobs/test-ray-het-cluster,/tmp/test_jobs/test-ray-het-cluster/logs:/tmp/test_jobs/test-ray-het-cluster/logs --container-workdir=/tmp/test_jobs/test-ray-het-cluster --wait=60 --kill-on-bad-exit=1 --overlap python /scripts/auxiliary_task.py & + +export TASK_TYPE=monitoring + +srun --het-group=2 --output /tmp/test_jobs/test-ray-het-cluster/logs/ray-overlap-2.out --container-image=nvcr.io/nvidia/pytorch:24.01-py3 --no-container-mount-home --mpi=pmix -A=test_account -p=gpu --gres=gpu:8 --container-mounts /tmp/test_jobs/test-ray-het-cluster:/tmp/test_jobs/test-ray-het-cluster,/tmp/test_jobs/test-ray-het-cluster:/tmp/test_jobs/test-ray-het-cluster,/tmp/test_jobs/test-ray-het-cluster/logs:/tmp/test_jobs/test-ray-het-cluster/logs --container-workdir=/tmp/test_jobs/test-ray-het-cluster --wait=60 --kill-on-bad-exit=1 --overlap python /scripts/monitoring.py & + +######################################################## +# We can now launch a job on this cluster +# We do so by launching a driver process on the physical node that the head node is on +# This driver process is responsible for launching a job on the Ray cluster +CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID --json | jq -r '.jobs[].current_working_directory') +# Define command to be empty by default +COMMAND="${COMMAND:-}" +COMMAND_WORKDIR=None + +if [[ -n "$COMMAND" ]]; then + srun --het-group=0 --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-job.log bash -c "$COMMAND" 
+else + echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:" + cat <$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh +# No args launches on the head node +WORKER_NUM=\${1:-} +if [[ -z "\$WORKER_NUM" ]]; then + # Empty means we are on the head node + srun --het-group=0 --no-container-mount-home --gpus=0 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash +else + nodes_array=($nodes) + srun --het-group=0 --no-container-mount-home--gres=gpu:8 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash +fi +EOF + chmod +x $CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh + echo " bash $CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh" + sleep infinity +fi \ No newline at end of file diff --git a/test/run/ray/test_slurm_ray_request.py b/test/run/ray/test_slurm_ray_request.py index 4129587b..2ed6acba 100644 --- a/test/run/ray/test_slurm_ray_request.py +++ b/test/run/ray/test_slurm_ray_request.py @@ -372,23 +372,163 @@ def test_cpus_per_gpu_warning(self): with pytest.warns(UserWarning, match="cpus_per_gpu.*requires.*gpus_per_task"): request.materialize() - def test_heterogeneous_assertion(self): - """Test materialize raises assertion for heterogeneous jobs.""" - executor = SlurmExecutor(account="test_account", heterogeneous=True) + def test_heterogeneous_basic(self): + """Test materialize generates correct SBATCH blocks for heterogeneous jobs.""" + from unittest.mock import Mock + + executor = SlurmExecutor( + account="test_account", + partition="gpu", + heterogeneous=True, + ) + executor.run_as_group = True + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=2, + ntasks_per_node=8, + gpus_per_node=8, + container_image="gpu_image", + container_mounts=["/data:/data"], + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="cpu_image", + container_mounts=["/data:/data"], + ), + ] executor.tunnel = Mock(spec=SSHTunnel) executor.tunnel.job_dir = "/tmp/test_jobs" request = SlurmRayRequest( - name="test-ray-cluster", - cluster_dir="/tmp/test_jobs/test-ray-cluster", + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", + template_name="ray.sub.j2", + executor=executor, + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Assert het job structure + assert "#SBATCH hetjob" in script + assert "het_group_host_0" in script + assert "het_group_host_1" in script + + # Assert different GPU specs per group + lines = script.split("\n") + het_job_idx = None + for i, line in enumerate(lines): + if "#SBATCH hetjob" in line: + het_job_idx = i + break + + assert het_job_idx is not None + + # Before hetjob separator should have gpus-per-node=8 + before_hetjob = "\n".join(lines[:het_job_idx]) + assert "#SBATCH --gpus-per-node=8" in before_hetjob + assert "#SBATCH --nodes=2" in before_hetjob + + # After hetjob separator should have gpus-per-node=0 + after_hetjob = "\n".join(lines[het_job_idx:]) + assert "#SBATCH --gpus-per-node=0" in after_hetjob + assert "#SBATCH --nodes=1" in after_hetjob + + def test_heterogeneous_with_command_groups(self): + """Test command groups with het jobs use correct het-group flags.""" + from 
unittest.mock import Mock + + executor = SlurmExecutor( + account="test_account", + heterogeneous=True, + ) + executor.run_as_group = True + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=8, + gpus_per_node=8, + container_image="image1", + container_mounts=["/data:/data"], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="image2", + container_mounts=["/data:/data"], + het_group_index=1, + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", template_name="ray.sub.j2", executor=executor, + command_groups=[["cmd0"], ["cmd1"]], launch_cmd=["sbatch", "--parsable"], ) - with pytest.raises(AssertionError, match="heterogeneous is not supported"): + script = request.materialize() + + # Should have het-group flags in srun commands + assert "--het-group=1" in script # command_groups[1] uses het-group=1 + + def test_heterogeneous_validation_errors(self): + """Test validation errors for invalid het job configs.""" + from unittest.mock import Mock + + # Test: missing resource_group + executor = SlurmExecutor(account="test_account", heterogeneous=True) + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-cluster", + cluster_dir="/tmp/test_jobs/test-cluster", + template_name="ray.sub.j2", + executor=executor, + launch_cmd=["sbatch"], + ) + + with pytest.raises(AssertionError, match="resource_group"): request.materialize() + # Test: het-group-0 with 0 nodes + executor2 = SlurmExecutor(account="test_account", heterogeneous=True) + executor2.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=0, # Invalid! 
+ ntasks_per_node=1, + container_image="image", + container_mounts=[], + ) + ] + executor2.tunnel = Mock(spec=SSHTunnel) + executor2.tunnel.job_dir = "/tmp/test_jobs" + + request2 = SlurmRayRequest( + name="test-cluster", + cluster_dir="/tmp/test_jobs/test-cluster", + template_name="ray.sub.j2", + executor=executor2, + launch_cmd=["sbatch"], + ) + + with pytest.raises(AssertionError, match="het-group-0 must have at least 1 node"): + request2.materialize() + def test_array_assertion(self): """Test materialize raises assertion for array jobs.""" executor = SlurmExecutor(account="test_account", array="1-10") @@ -742,3 +882,85 @@ def test_ray_enroot_template( expected_script = f.read() assert generated_script.strip() == expected_script.strip() + + @pytest.fixture + def het_ray_request_with_artifact(self) -> tuple[SlurmRayRequest, str]: + """Create a het Ray cluster request matching expected artifact.""" + executor = SlurmExecutor( + account="test_account", + partition="gpu", + time="01:00:00", + heterogeneous=True, + container_image="nvcr.io/nvidia/pytorch:24.01-py3", + container_mounts=[ + "/tmp/test_jobs/test-ray-het-cluster:/tmp/test_jobs/test-ray-het-cluster" + ], + gres="gpu:8", + ) + executor.run_as_group = True + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=2, + ntasks_per_node=8, + gpus_per_node=8, + container_image="nvcr.io/nvidia/pytorch:24.01-py3", + container_mounts=[ + "/tmp/test_jobs/test-ray-het-cluster:/tmp/test_jobs/test-ray-het-cluster" + ], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="nvcr.io/nvidia/pytorch:24.01-py3", + container_mounts=[ + "/tmp/test_jobs/test-ray-het-cluster:/tmp/test_jobs/test-ray-het-cluster" + ], + het_group_index=1, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=2, + gpus_per_node=0, + container_image="nvcr.io/nvidia/pytorch:24.01-py3", + container_mounts=[ + "/tmp/test_jobs/test-ray-het-cluster:/tmp/test_jobs/test-ray-het-cluster" + ], + het_group_index=2, + env_vars={"TASK_TYPE": "monitoring"}, + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + executor.tunnel.key = "test-cluster" + + request = SlurmRayRequest( + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", + template_name="ray.sub.j2", + executor=executor, + command_groups=[ + ["echo 'Ray cluster on het-group-0'"], # Skipped (index 0 = Ray cluster) + ["python /scripts/auxiliary_task.py"], # Runs on het-group-1 + ["python /scripts/monitoring.py"], # Runs on het-group-2 + ], + launch_cmd=["sbatch", "--parsable"], + ) + + return request, os.path.join(ARTIFACTS_DIR, "expected_ray_het_cluster.sub") + + def test_heterogeneous_artifact( + self, het_ray_request_with_artifact: tuple[SlurmRayRequest, str] + ): + """Test that het Ray cluster script matches artifact.""" + ray_request, artifact_path = het_ray_request_with_artifact + generated_script = ray_request.materialize() + + with open(artifact_path, "r") as f: + expected_script = f.read() + + assert generated_script.strip() == expected_script.strip() diff --git a/test/run/torchx_backend/schedulers/test_slurm.py b/test/run/torchx_backend/schedulers/test_slurm.py index 96b0e239..7da8a3c0 100644 --- a/test/run/torchx_backend/schedulers/test_slurm.py +++ b/test/run/torchx_backend/schedulers/test_slurm.py @@ -403,3 +403,200 @@ def test_ray_template_env_var(slurm_scheduler, slurm_executor): 
dryrun_info = slurm_scheduler._submit_dryrun(app_def, slurm_executor) assert isinstance(dryrun_info.request, SlurmRayRequest) assert dryrun_info.request.template_name == "ray_enroot.sub.j2" + + +def test_heterogeneous_ray_cluster_run_as_group(slurm_scheduler, temp_dir): + """Test that run_as_group is automatically set for heterogeneous Ray clusters.""" + from nemo_run.config import USE_WITH_RAY_CLUSTER_KEY + from nemo_run.run.ray.slurm import SlurmRayRequest + + # Create executor with heterogeneous job configuration + executor = SlurmExecutor( + account="test_account", + job_dir=temp_dir, + heterogeneous=True, + tunnel=LocalTunnel(job_dir=temp_dir), + ) + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=mock.MagicMock(), + nodes=2, + ntasks_per_node=8, + gpus_per_node=8, + container_image="nvcr.io/nvidia/pytorch:24.01-py3", + container_mounts=[], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=mock.MagicMock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="nvcr.io/nvidia/pytorch:24.01-py3", + container_mounts=[], + het_group_index=1, + ), + ] + + # Create a Ray-enabled app with 2 roles (matching resource groups) + app_def = AppDef( + name="test_ray_het_app", + roles=[ + Role(name="ray_cluster", image="", entrypoint="python", args=["train.py"]), + Role(name="auxiliary", image="", entrypoint="python", args=["monitor.py"]), + ], + metadata={USE_WITH_RAY_CLUSTER_KEY: True}, + ) + + with ( + mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"), + mock.patch.object(SlurmExecutor, "package"), + mock.patch("builtins.open", mock.mock_open()), + mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill, + ): + slurm_scheduler.tunnel = mock.MagicMock() + mock_fill.return_value = "#!/bin/bash\n# Mock script" + + # Initially run_as_group should not be set + assert not hasattr(executor, "run_as_group") or not executor.run_as_group + + dryrun_info = slurm_scheduler._submit_dryrun(app_def, executor) + + # Verify run_as_group was automatically set + assert executor.run_as_group is True + assert isinstance(dryrun_info.request, SlurmRayRequest) + assert dryrun_info.request.executor.heterogeneous is True + assert len(dryrun_info.request.command_groups) == 2 + + +def test_heterogeneous_ray_cluster_mismatched_groups_warning(slurm_scheduler, temp_dir, caplog): + """Test that a warning is logged when roles don't match resource groups.""" + from nemo_run.config import USE_WITH_RAY_CLUSTER_KEY + from nemo_run.run.ray.slurm import SlurmRayRequest + + # Create executor with 2 resource groups + executor = SlurmExecutor( + account="test_account", + job_dir=temp_dir, + heterogeneous=True, + tunnel=LocalTunnel(job_dir=temp_dir), + ) + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=mock.MagicMock(), + nodes=2, + ntasks_per_node=8, + gpus_per_node=8, + container_image="nvcr.io/nvidia/pytorch:24.01-py3", + container_mounts=[], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=mock.MagicMock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="nvcr.io/nvidia/pytorch:24.01-py3", + container_mounts=[], + het_group_index=1, + ), + ] + + # Create a Ray-enabled app with 3 roles (mismatched with 2 resource groups) + app_def = AppDef( + name="test_ray_het_app", + roles=[ + Role(name="ray_cluster", image="", entrypoint="python", args=["train.py"]), + Role(name="auxiliary", image="", entrypoint="python", args=["monitor.py"]), + Role(name="extra", image="", entrypoint="python", 
args=["extra.py"]), + ], + metadata={USE_WITH_RAY_CLUSTER_KEY: True}, + ) + + with ( + mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"), + mock.patch.object(SlurmExecutor, "package"), + mock.patch("builtins.open", mock.mock_open()), + mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill, + ): + slurm_scheduler.tunnel = mock.MagicMock() + mock_fill.return_value = "#!/bin/bash\n# Mock script" + + with caplog.at_level(logging.WARNING): + dryrun_info = slurm_scheduler._submit_dryrun(app_def, executor) + + # Verify warning was logged + assert any("resource groups" in record.message for record in caplog.records) + assert any("3 roles" in record.message for record in caplog.records) + assert any("2 resource groups" in record.message for record in caplog.records) + + # Verify request was still created + assert isinstance(dryrun_info.request, SlurmRayRequest) + assert executor.run_as_group is True + + +def test_heterogeneous_ray_cluster_no_resource_group(slurm_scheduler, temp_dir): + """Test that heterogeneous jobs without resource_group raise an AssertionError.""" + from nemo_run.config import USE_WITH_RAY_CLUSTER_KEY + + # Create executor with heterogeneous=True but no resource_group + executor = SlurmExecutor( + account="test_account", + job_dir=temp_dir, + heterogeneous=True, + tunnel=LocalTunnel(job_dir=temp_dir), + ) + # Don't set resource_group + + # Create a Ray-enabled app + app_def = AppDef( + name="test_ray_het_app", + roles=[Role(name="ray_cluster", image="", entrypoint="python", args=["train.py"])], + metadata={USE_WITH_RAY_CLUSTER_KEY: True}, + ) + + with ( + mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"), + mock.patch.object(SlurmExecutor, "package"), + mock.patch("builtins.open", mock.mock_open()), + ): + slurm_scheduler.tunnel = mock.MagicMock() + + # Should raise AssertionError because resource_group is required for het jobs + with pytest.raises(AssertionError, match="heterogeneous requires resource_group to be set"): + slurm_scheduler._submit_dryrun(app_def, executor) + + +def test_non_heterogeneous_ray_cluster(slurm_scheduler, temp_dir): + """Test that run_as_group is NOT set for non-heterogeneous clusters.""" + from nemo_run.config import USE_WITH_RAY_CLUSTER_KEY + from nemo_run.run.ray.slurm import SlurmRayRequest + + # Create executor without heterogeneous + executor = SlurmExecutor( + account="test_account", + job_dir=temp_dir, + tunnel=LocalTunnel(job_dir=temp_dir), + ) + + # Create a Ray-enabled app + app_def = AppDef( + name="test_ray_app", + roles=[Role(name="ray_cluster", image="", entrypoint="python", args=["train.py"])], + metadata={USE_WITH_RAY_CLUSTER_KEY: True}, + ) + + with ( + mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"), + mock.patch.object(SlurmExecutor, "package"), + mock.patch("builtins.open", mock.mock_open()), + mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill, + ): + slurm_scheduler.tunnel = mock.MagicMock() + mock_fill.return_value = "#!/bin/bash\n# Mock script" + + dryrun_info = slurm_scheduler._submit_dryrun(app_def, executor) + + # Verify run_as_group was NOT set + assert not hasattr(executor, "run_as_group") or not executor.run_as_group + assert isinstance(dryrun_info.request, SlurmRayRequest) From 41d58e0e44beae7e248cf239fedae95d337652df Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 9 Jan 2026 09:00:37 -0800 Subject: [PATCH 2/3] fix Signed-off-by: Hemil Desai --- nemo_run/run/ray/slurm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff 
--git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index 72d79fe6..d229261b 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -1259,9 +1259,9 @@ def start( if isinstance(self.executor.tunnel, SSHTunnel): # Rsync workdir honouring .gitignore self.executor.tunnel.connect() - assert ( - self.executor.tunnel.session is not None - ), "Tunnel session is not connected" + assert self.executor.tunnel.session is not None, ( + "Tunnel session is not connected" + ) rsync( self.executor.tunnel.session, workdir, @@ -1316,9 +1316,9 @@ def start( if isinstance(self.executor.tunnel, SSHTunnel): self.executor.tunnel.connect() - assert ( - self.executor.tunnel.session is not None - ), "Tunnel session is not connected" + assert self.executor.tunnel.session is not None, ( + "Tunnel session is not connected" + ) rsync( self.executor.tunnel.session, os.path.join(local_code_extraction_path, ""), From 6e9f6a629329302eb480bbaad67ebff5e4dbc862 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 12 Jan 2026 11:21:10 -0800 Subject: [PATCH 3/3] fix Signed-off-by: Hemil Desai --- test/run/ray/test_slurm_ray_request.py | 364 +++++++++++++++++++++++++ 1 file changed, 364 insertions(+) diff --git a/test/run/ray/test_slurm_ray_request.py b/test/run/ray/test_slurm_ray_request.py index 2ed6acba..d9a41ae7 100644 --- a/test/run/ray/test_slurm_ray_request.py +++ b/test/run/ray/test_slurm_ray_request.py @@ -964,3 +964,367 @@ def test_heterogeneous_artifact( expected_script = f.read() assert generated_script.strip() == expected_script.strip() + + def test_heterogeneous_with_het_group_indices(self): + """Test het job with explicit het_group_indices for final_group_index calculation.""" + from unittest.mock import Mock + + executor = SlurmExecutor( + account="test_account", + partition="gpu", + heterogeneous=True, + ) + executor.run_as_group = True + # Create resource groups with explicit het_group_indices + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=2, + ntasks_per_node=8, + gpus_per_node=8, + container_image="gpu_image", + container_mounts=["/data:/data"], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="cpu_image", + container_mounts=["/data:/data"], + het_group_index=2, # Non-sequential index + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", + template_name="ray.sub.j2", + executor=executor, + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Should have het job structure + assert "#SBATCH hetjob" in script + # Should have both het group hostnames + assert "het_group_host_0" in script + assert "het_group_host_1" in script + + def test_heterogeneous_duplicate_het_group_index_skipped(self): + """Test that duplicate het_group_index ResourceRequests are skipped.""" + from unittest.mock import Mock + + executor = SlurmExecutor( + account="test_account", + partition="gpu", + heterogeneous=True, + ) + executor.run_as_group = True + # Create resource groups where two share the same het_group_index + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=2, + ntasks_per_node=8, + gpus_per_node=8, + container_image="gpu_image", + container_mounts=["/data:/data"], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), 
+ nodes=2, + ntasks_per_node=8, + gpus_per_node=8, + container_image="gpu_image2", + container_mounts=["/data:/data"], + het_group_index=0, # Same as previous - should be skipped + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="cpu_image", + container_mounts=["/data:/data"], + het_group_index=1, + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", + template_name="ray.sub.j2", + executor=executor, + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Should only have one #SBATCH hetjob separator (2 het groups, not 3) + assert script.count("#SBATCH hetjob") == 1 + # Should have het group hostnames for each unique het group + assert "het_group_host_0" in script + assert "het_group_host_1" in script + + def test_heterogeneous_with_gpus_per_task(self): + """Test het job with gpus_per_task set in ResourceRequest.""" + from unittest.mock import Mock + + executor = SlurmExecutor( + account="test_account", + partition="gpu", + heterogeneous=True, + ) + executor.run_as_group = True + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=8, + gpus_per_node=8, + gpus_per_task=1, # Explicit gpus_per_task + container_image="gpu_image", + container_mounts=["/data:/data"], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="cpu_image", + container_mounts=["/data:/data"], + het_group_index=1, + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", + template_name="ray.sub.j2", + executor=executor, + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # First het group should have gpus-per-task + lines = script.split("\n") + het_job_idx = None + for i, line in enumerate(lines): + if "#SBATCH hetjob" in line: + het_job_idx = i + break + + before_hetjob = "\n".join(lines[:het_job_idx]) + assert "#SBATCH --gpus-per-task=1" in before_hetjob + + def test_heterogeneous_with_separate_stderr(self): + """Test het job with stderr_to_stdout=False generates error paths.""" + from unittest.mock import Mock + + executor = SlurmExecutor( + account="test_account", + partition="gpu", + heterogeneous=True, + ) + executor.stderr_to_stdout = False # Separate stderr + executor.run_as_group = True + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=2, + ntasks_per_node=8, + gpus_per_node=8, + container_image="gpu_image", + container_mounts=["/data:/data"], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="cpu_image", + container_mounts=["/data:/data"], + het_group_index=1, + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", + template_name="ray.sub.j2", + executor=executor, + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Should have separate error output paths for each het group + assert "#SBATCH --error=" in script + 
+ def test_heterogeneous_command_groups_without_het_group_index(self): + """Test het command groups fallback to idx when het_group_index is None.""" + from unittest.mock import Mock + + executor = SlurmExecutor( + account="test_account", + heterogeneous=True, + ) + executor.run_as_group = True + # Resource groups WITHOUT het_group_index set + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=8, + gpus_per_node=8, + container_image="image1", + container_mounts=["/data:/data"], + # het_group_index not set - should fall back to idx + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="image2", + container_mounts=["/data:/data"], + # het_group_index not set - should fall back to idx + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", + template_name="ray.sub.j2", + executor=executor, + command_groups=[["cmd0"], ["cmd1"]], + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Should have het-group flags using idx fallback + assert "--het-group=1" in script # command_groups[1] uses het-group=1 (idx fallback) + + def test_heterogeneous_without_run_as_group(self): + """Test het job without run_as_group does not add het-group flags to commands.""" + from unittest.mock import Mock + + executor = SlurmExecutor( + account="test_account", + heterogeneous=True, + ) + # run_as_group NOT set + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=8, + gpus_per_node=8, + container_image="image1", + container_mounts=["/data:/data"], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + gpus_per_node=0, + container_image="image2", + container_mounts=["/data:/data"], + het_group_index=1, + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", + template_name="ray.sub.j2", + executor=executor, + command_groups=[["cmd0"], ["cmd1"]], + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # SBATCH het job structure should still exist + assert "#SBATCH hetjob" in script + # But command groups should NOT have --het-group flags (run_as_group not set) + # The overlap srun commands should not have --het-group=1 + # Find the overlap srun command + lines = script.split("\n") + overlap_srun_lines = [line for line in lines if "overlap" in line and "srun" in line] + for line in overlap_srun_lines: + # These should NOT have --het-group since run_as_group is not set + if "cmd1" in line: + assert "--het-group=1" not in line + + def test_heterogeneous_mismatched_command_groups_length(self): + """Test het job when command_groups length doesn't match resource_group length.""" + from unittest.mock import Mock + + executor = SlurmExecutor( + account="test_account", + heterogeneous=True, + ) + executor.run_as_group = True + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=8, + gpus_per_node=8, + container_image="image1", + container_mounts=["/data:/data"], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + 
gpus_per_node=0, + container_image="image2", + container_mounts=["/data:/data"], + het_group_index=1, + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + # 3 command groups but only 2 resource groups - mismatched + request = SlurmRayRequest( + name="test-ray-het-cluster", + cluster_dir="/tmp/test_jobs/test-ray-het-cluster", + template_name="ray.sub.j2", + executor=executor, + command_groups=[["cmd0"], ["cmd1"], ["cmd2"]], + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Should still generate script but WITHOUT het-group flags + # (because lengths don't match) + assert "#SBATCH hetjob" in script + # Overlap commands should NOT have --het-group flags + assert "--het-group=1" not in script + assert "--het-group=2" not in script
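
Usage sketch: a minimal, hedged example of how the heterogeneous Ray cluster support added in this series could be wired up, mirroring the fixtures in test/run/ray/test_slurm_ray_request.py. The import paths for SlurmExecutor, GitArchivePackager and LocalTunnel are assumed from the repo layout; the account, partition, image, mount paths and auxiliary command are placeholders.

from nemo_run.core.execution.slurm import SlurmExecutor      # assumed import path
from nemo_run.core.packaging.git import GitArchivePackager   # assumed import path
from nemo_run.core.tunnel.client import LocalTunnel          # assumed import path
from nemo_run.run.ray.slurm import SlurmRayRequest

executor = SlurmExecutor(
    account="my_account",        # placeholder
    partition="gpu",             # placeholder
    time="01:00:00",
    heterogeneous=True,
    container_image="nvcr.io/nvidia/pytorch:24.01-py3",
    container_mounts=["/lustre/jobs/ray-het:/lustre/jobs/ray-het"],
    tunnel=LocalTunnel(job_dir="/lustre/jobs"),
)
executor.run_as_group = True
executor.resource_group = [
    # het-group-0 hosts the Ray head and workers and must have >= 1 node
    SlurmExecutor.ResourceRequest(
        packager=GitArchivePackager(),
        nodes=2,
        ntasks_per_node=8,
        gpus_per_node=8,
        container_image="nvcr.io/nvidia/pytorch:24.01-py3",
        container_mounts=["/lustre/jobs/ray-het:/lustre/jobs/ray-het"],
        het_group_index=0,
    ),
    # het-group-1 is a CPU-only group for an auxiliary side task
    SlurmExecutor.ResourceRequest(
        packager=GitArchivePackager(),
        nodes=1,
        ntasks_per_node=1,
        gpus_per_node=0,
        container_image="nvcr.io/nvidia/pytorch:24.01-py3",
        container_mounts=["/lustre/jobs/ray-het:/lustre/jobs/ray-het"],
        het_group_index=1,
    ),
]

request = SlurmRayRequest(
    name="ray-het-cluster",
    cluster_dir="/lustre/jobs/ray-het",
    template_name="ray.sub.j2",
    executor=executor,
    command_groups=[
        ["echo 'Ray cluster on het-group-0'"],  # index 0 maps to the Ray cluster itself
        ["python auxiliary_task.py"],           # launched via srun --het-group=1
    ],
    launch_cmd=["sbatch", "--parsable"],
)

# materialize() renders the sbatch script: one #SBATCH block per het group,
# separated by "#SBATCH hetjob", with the overlap sruns pinned via --het-group=N.
print(request.materialize())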