Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 79 additions & 5 deletions src/slurm_plugin/clustermgtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class ClustermgtdConfig:
"terminate_down_nodes": True,
"orphaned_instance_timeout": 300,
"ec2_instance_missing_max_count": 0,
"hold_drain_nodes_timeout": 30,
"hold_drain_nodes_reasons": ["Prolog error"],
# Health check configs
"disable_ec2_health_check": False,
"disable_scheduled_event_health_check": False,
Expand Down Expand Up @@ -293,6 +295,17 @@ def _get_terminate_config(self, config):
self.terminate_drain_nodes = config.getboolean(
"clustermgtd", "terminate_drain_nodes", fallback=self.DEFAULTS.get("terminate_drain_nodes")
)
self.hold_drain_nodes_timeout = config.getint(
"clustermgtd", "hold_drain_nodes_timeout", fallback=self.DEFAULTS.get("hold_drain_nodes_timeout")
)
# Parse comma-separated list of reasons
hold_drain_nodes_reasons_str = config.get(
"clustermgtd", "hold_drain_nodes_reasons", fallback=None
)
if hold_drain_nodes_reasons_str:
self.hold_drain_nodes_reasons = [r.strip() for r in hold_drain_nodes_reasons_str.split(",")]
else:
self.hold_drain_nodes_reasons = self.DEFAULTS.get("hold_drain_nodes_reasons")
self.terminate_down_nodes = config.getboolean(
"clustermgtd", "terminate_down_nodes", fallback=self.DEFAULTS.get("terminate_down_nodes")
)
Expand Down Expand Up @@ -388,6 +401,7 @@ def __init__(self, config):
This state is required because we need to ignore static nodes that might have long bootstrap time
"""
self._insufficient_capacity_compute_resources = {}
self._held_compute_resources = {}
self._static_nodes_in_replacement = set()
self._partitions_protected_failure_count_map = {}
self._nodes_without_backing_instance_count_map = {}
Expand Down Expand Up @@ -783,6 +797,11 @@ def _find_unhealthy_slurm_nodes(self, slurm_nodes):
# do not consider as unhealthy the nodes reserved for capacity blocks
continue

# Record when the node was first seen unhealthy, but only if its reason contains one of the configured hold reasons
if node.name not in self._held_compute_resources:
if node.reason and any(reason in node.reason for reason in self._config.hold_drain_nodes_reasons):
self._held_compute_resources[node.name] = self._current_time

all_unhealthy_nodes.append(node)

if isinstance(node, StaticNode):
Expand All @@ -798,6 +817,14 @@ def _find_unhealthy_slurm_nodes(self, slurm_nodes):
self._config.ec2_instance_missing_max_count,
self._nodes_without_backing_instance_count_map,
)

# Drop entries from _held_compute_resources for nodes that are no longer unhealthy
unhealthy_node_names = {node.name for node in all_unhealthy_nodes}
self._held_compute_resources = {
name: timestamp for name, timestamp in self._held_compute_resources.items()
if name in unhealthy_node_names
}

return (
unhealthy_dynamic_nodes,
unhealthy_static_nodes,
Expand All @@ -822,18 +849,40 @@ def _handle_unhealthy_dynamic_nodes(self, unhealthy_dynamic_nodes):
"""
Maintain any unhealthy dynamic node.

Terminate instances backing dynamic nodes.
Terminate instances backing dynamic nodes only after hold_drain_nodes_timeout.
Setting node to down will let slurm requeue jobs allocated to node.
Setting node to power_down will terminate backing instance and reset dynamic node for future use.
"""
instances_to_terminate = [node.instance.id for node in unhealthy_dynamic_nodes if node.instance]
# Select nodes that are not being held or whose hold timeout has expired; the rest stay held (config is in minutes)
timeout_seconds = self._config.hold_drain_nodes_timeout * 60
nodes_to_terminate = []
nodes_being_held = []
for node in unhealthy_dynamic_nodes:
if node.name not in self._held_compute_resources:
nodes_to_terminate.append(node)
elif time_is_up(self._held_compute_resources[node.name], self._current_time, timeout_seconds):
nodes_to_terminate.append(node)
self._held_compute_resources.pop(node.name, None)
else:
elapsed = (self._current_time - self._held_compute_resources[node.name]).total_seconds()
remaining = int(timeout_seconds - elapsed)
nodes_being_held.append(f"{node.name}({remaining}s left)")

if nodes_being_held:
log.info(
"Holding termination for unhealthy dynamic nodes (timeout: %sm): %s",
self._config.hold_drain_nodes_timeout,
nodes_being_held,
)

instances_to_terminate = [node.instance.id for node in nodes_to_terminate if node.instance]
if instances_to_terminate:
log.info("Terminating instances that are backing unhealthy dynamic nodes")
self._instance_manager.delete_instances(
instances_to_terminate, terminate_batch_size=self._config.terminate_max_batch_size
)
log.info("Setting unhealthy dynamic nodes to down and power_down.")
set_nodes_power_down([node.name for node in unhealthy_dynamic_nodes], reason="Scheduler health check failed")
set_nodes_power_down([node.name for node in nodes_to_terminate], reason="Scheduler health check failed")

@log_exception(log, "maintaining powering down nodes", raise_on_error=False)
def _handle_powering_down_nodes(self, slurm_nodes):
Expand Down Expand Up @@ -880,15 +929,40 @@ def _handle_unhealthy_static_nodes(self, unhealthy_static_nodes):
except Exception as e:
log.error("Encountered exception when retrieving console output from unhealthy static nodes: %s", e)

node_list = [node.name for node in unhealthy_static_nodes]
# Config is in minutes, convert to seconds
timeout_seconds = self._config.hold_drain_nodes_timeout * 60
nodes_to_terminate = []
nodes_being_held = []
for node in unhealthy_static_nodes:
if node.name not in self._held_compute_resources:
nodes_to_terminate.append(node)
elif time_is_up(self._held_compute_resources[node.name], self._current_time, timeout_seconds):
nodes_to_terminate.append(node)
self._held_compute_resources.pop(node.name, None)
else:
elapsed = (self._current_time - self._held_compute_resources[node.name]).total_seconds()
remaining = int(timeout_seconds - elapsed)
nodes_being_held.append(f"{node.name}({remaining}s left)")

if nodes_being_held:
log.info(
"Holding termination for unhealthy static nodes (timeout: %sm): %s",
self._config.hold_drain_nodes_timeout,
nodes_being_held,
)

if not nodes_to_terminate:
return

node_list = [node.name for node in nodes_to_terminate]
# Set nodes into down state so jobs can be requeued immediately
try:
log.info("Setting unhealthy static nodes to DOWN")
reset_nodes(node_list, state="down", reason="Static node maintenance: unhealthy node is being replaced")
except Exception as e:
log.error("Encountered exception when setting unhealthy static nodes into down state: %s", e)

instances_to_terminate = [node.instance.id for node in unhealthy_static_nodes if node.instance]
instances_to_terminate = [node.instance.id for node in nodes_to_terminate if node.instance]

if instances_to_terminate:
log.info("Terminating instances backing unhealthy static nodes")
Expand Down
Loading