diff --git a/src/slurm_plugin/clustermgtd.py b/src/slurm_plugin/clustermgtd.py
index c612dea1..d4a698f3 100644
--- a/src/slurm_plugin/clustermgtd.py
+++ b/src/slurm_plugin/clustermgtd.py
@@ -151,6 +151,8 @@ class ClustermgtdConfig:
         "terminate_down_nodes": True,
         "orphaned_instance_timeout": 300,
         "ec2_instance_missing_max_count": 0,
+        "hold_drain_nodes_timeout": 30,
+        "hold_drain_nodes_reasons": ["Prolog error"],
         # Health check configs
         "disable_ec2_health_check": False,
         "disable_scheduled_event_health_check": False,
@@ -293,6 +295,17 @@ def _get_terminate_config(self, config):
         self.terminate_drain_nodes = config.getboolean(
             "clustermgtd", "terminate_drain_nodes", fallback=self.DEFAULTS.get("terminate_drain_nodes")
         )
+        self.hold_drain_nodes_timeout = config.getint(
+            "clustermgtd", "hold_drain_nodes_timeout", fallback=self.DEFAULTS.get("hold_drain_nodes_timeout")
+        )
+        # Parse comma-separated list of reasons; drop empty tokens so a trailing or
+        # doubled comma cannot produce "" (which would substring-match every reason)
+        hold_drain_nodes_reasons_str = config.get(
+            "clustermgtd", "hold_drain_nodes_reasons", fallback=None
+        )
+        if hold_drain_nodes_reasons_str:
+            self.hold_drain_nodes_reasons = [r.strip() for r in hold_drain_nodes_reasons_str.split(",") if r.strip()]
+        else:
+            self.hold_drain_nodes_reasons = self.DEFAULTS.get("hold_drain_nodes_reasons")
         self.terminate_down_nodes = config.getboolean(
             "clustermgtd", "terminate_down_nodes", fallback=self.DEFAULTS.get("terminate_down_nodes")
         )
@@ -388,6 +401,7 @@ def __init__(self, config):
         This state is required because we need to ignore static nodes that might have long bootstrap time
         """
         self._insufficient_capacity_compute_resources = {}
+        self._held_compute_resources = {}
         self._static_nodes_in_replacement = set()
         self._partitions_protected_failure_count_map = {}
         self._nodes_without_backing_instance_count_map = {}
@@ -783,6 +797,11 @@ def _find_unhealthy_slurm_nodes(self, slurm_nodes):
                 # do not consider as unhealthy the nodes reserved for capacity blocks
                 continue

+            # Track when the node was first found unhealthy, only if drain reason matches configured reasons
+            if node.name not in self._held_compute_resources:
+                if node.reason and any(reason in node.reason for reason in self._config.hold_drain_nodes_reasons):
+                    self._held_compute_resources[node.name] = self._current_time
+
             all_unhealthy_nodes.append(node)

             if isinstance(node, StaticNode):
@@ -798,6 +817,14 @@ def _find_unhealthy_slurm_nodes(self, slurm_nodes):
             self._config.ec2_instance_missing_max_count,
             self._nodes_without_backing_instance_count_map,
         )
+
+        # Clean up nodes that are no longer unhealthy from _held_compute_resources
+        unhealthy_node_names = {node.name for node in all_unhealthy_nodes}
+        self._held_compute_resources = {
+            name: timestamp for name, timestamp in self._held_compute_resources.items()
+            if name in unhealthy_node_names
+        }
+
         return (
             unhealthy_dynamic_nodes,
             unhealthy_static_nodes,
@@ -822,18 +849,40 @@ def _handle_unhealthy_dynamic_nodes(self, unhealthy_dynamic_nodes):
         """
         Maintain any unhealthy dynamic node.

-        Terminate instances backing dynamic nodes.
+        Terminate instances backing dynamic nodes only after hold_drain_nodes_timeout.
         Setting node to down will let slurm requeue jobs allocated to node.
         Setting node to power_down will terminate backing instance and reset dynamic node for future use.
         """
-        instances_to_terminate = [node.instance.id for node in unhealthy_dynamic_nodes if node.instance]
+        # Filter to only nodes that have exceeded the hold timeout (config is in minutes)
+        timeout_seconds = self._config.hold_drain_nodes_timeout * 60
+        nodes_to_terminate = []
+        nodes_being_held = []
+        for node in unhealthy_dynamic_nodes:
+            if node.name not in self._held_compute_resources:
+                nodes_to_terminate.append(node)
+            elif time_is_up(self._held_compute_resources[node.name], self._current_time, timeout_seconds):
+                nodes_to_terminate.append(node)
+                self._held_compute_resources.pop(node.name, None)
+            else:
+                elapsed = (self._current_time - self._held_compute_resources[node.name]).total_seconds()
+                remaining = int(timeout_seconds - elapsed)
+                nodes_being_held.append(f"{node.name}({remaining}s left)")
+
+        if nodes_being_held:
+            log.info(
+                "Holding termination for unhealthy dynamic nodes (timeout: %sm): %s",
+                self._config.hold_drain_nodes_timeout,
+                nodes_being_held,
+            )
+
+        instances_to_terminate = [node.instance.id for node in nodes_to_terminate if node.instance]
         if instances_to_terminate:
             log.info("Terminating instances that are backing unhealthy dynamic nodes")
             self._instance_manager.delete_instances(
                 instances_to_terminate, terminate_batch_size=self._config.terminate_max_batch_size
             )
         log.info("Setting unhealthy dynamic nodes to down and power_down.")
-        set_nodes_power_down([node.name for node in unhealthy_dynamic_nodes], reason="Scheduler health check failed")
+        set_nodes_power_down([node.name for node in nodes_to_terminate], reason="Scheduler health check failed")

     @log_exception(log, "maintaining powering down nodes", raise_on_error=False)
     def _handle_powering_down_nodes(self, slurm_nodes):
@@ -880,7 +929,32 @@ def _handle_unhealthy_static_nodes(self, unhealthy_static_nodes):
         except Exception as e:
             log.error("Encountered exception when retrieving console output from unhealthy static nodes: %s", e)

-        node_list = [node.name for node in unhealthy_static_nodes]
+        # Config is in minutes, convert to seconds
+        timeout_seconds = self._config.hold_drain_nodes_timeout * 60
+        nodes_to_terminate = []
+        nodes_being_held = []
+        for node in unhealthy_static_nodes:
+            if node.name not in self._held_compute_resources:
+                nodes_to_terminate.append(node)
+            elif time_is_up(self._held_compute_resources[node.name], self._current_time, timeout_seconds):
+                nodes_to_terminate.append(node)
+                self._held_compute_resources.pop(node.name, None)
+            else:
+                elapsed = (self._current_time - self._held_compute_resources[node.name]).total_seconds()
+                remaining = int(timeout_seconds - elapsed)
+                nodes_being_held.append(f"{node.name}({remaining}s left)")
+
+        if nodes_being_held:
+            log.info(
+                "Holding termination for unhealthy static nodes (timeout: %sm): %s",
+                self._config.hold_drain_nodes_timeout,
+                nodes_being_held,
+            )
+
+        if not nodes_to_terminate:
+            return
+
+        node_list = [node.name for node in nodes_to_terminate]
         # Set nodes into down state so jobs can be requeued immediately
         try:
             log.info("Setting unhealthy static nodes to DOWN")
@@ -888,7 +962,7 @@ def _handle_unhealthy_static_nodes(self, unhealthy_static_nodes):
         except Exception as e:
             log.error("Encountered exception when setting unhealthy static nodes into down state: %s", e)

-        instances_to_terminate = [node.instance.id for node in unhealthy_static_nodes if node.instance]
+        instances_to_terminate = [node.instance.id for node in nodes_to_terminate if node.instance]

         if instances_to_terminate:
             log.info("Terminating instances backing unhealthy static nodes")