Skip to content

Commit 58002f0

Browse files
committed
fix: hard-sync kernel_registry to real containers
1 parent 11f42db commit 58002f0

2 files changed

Lines changed: 50 additions & 21 deletions

File tree

changes/2179.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Sync the agent's kernel registry with the actual containers through a periodic loop.

src/ai/backend/agent/agent.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import zlib
1717
from abc import ABCMeta, abstractmethod
1818
from collections import defaultdict
19+
from collections.abc import Container as ContainerT
1920
from decimal import Decimal
2021
from io import SEEK_END, BytesIO
2122
from pathlib import Path
@@ -1076,27 +1077,10 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
10761077
ev.done_future.set_exception(e)
10771078
await self.produce_error_event()
10781079
finally:
1079-
if ev.kernel_id in self.restarting_kernels:
1080-
# Don't forget as we are restarting it.
1081-
kernel_obj = self.kernel_registry.get(ev.kernel_id, None)
1082-
else:
1083-
# Forget as we are done with this kernel.
1084-
kernel_obj = self.kernel_registry.pop(ev.kernel_id, None)
1080+
kernel_obj = self.kernel_registry.get(ev.kernel_id, None)
10851081
try:
10861082
if kernel_obj is not None:
1087-
# Restore used ports to the port pool.
1088-
port_range = self.local_config["container"]["port-range"]
1089-
# Exclude out-of-range ports, because when the agent restarts
1090-
# with a different port range, existing containers' host ports
1091-
# may not belong to the new port range.
1092-
if host_ports := kernel_obj.get("host_ports"):
1093-
restored_ports = [
1094-
*filter(
1095-
lambda p: port_range[0] <= p <= port_range[1],
1096-
host_ports,
1097-
)
1098-
]
1099-
self.port_pool.update(restored_ports)
1083+
await self._restore_port_pool(kernel_obj)
11001084
await kernel_obj.close()
11011085
finally:
11021086
self.terminating_kernels.discard(ev.kernel_id)
@@ -1116,6 +1100,20 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
11161100
if ev.done_future is not None and not ev.done_future.done():
11171101
ev.done_future.set_result(None)
11181102

1103+
async def _restore_port_pool(self, kernel_obj: AbstractKernel) -> None:
# Return the host ports recorded on `kernel_obj` to `self.port_pool`.
# The allowed range is read from self.local_config["container"]["port-range"]
# and is treated as inclusive on both ends; ports outside it are dropped.
# Only acts when `kernel_obj.get("host_ports")` is truthy.
# NOTE(review): nothing is awaited in this body even though it is declared
# `async`; callers nevertheless have to `await` it.
1104+
port_range = self.local_config["container"]["port-range"]
1105+
# Exclude out-of-range ports, because when the agent restarts
1106+
# with a different port range, existing containers' host ports
1107+
# may not belong to the new port range.
1108+
if host_ports := kernel_obj.get("host_ports"):
1109+
restored_ports = [
1110+
*filter(
1111+
lambda p: port_range[0] <= p <= port_range[1],
1112+
host_ports,
1113+
)
1114+
]
1115+
# Hand the in-range ports back to the pool in a single update() call.
self.port_pool.update(restored_ports)
1116+
11191117
async def process_lifecycle_events(self) -> None:
11201118
async def lifecycle_task_exception_handler(
11211119
exc_type: Type[Exception],
@@ -1260,6 +1258,8 @@ async def sync_container_lifecycles(self, interval: float) -> None:
12601258
for cases when we miss the container lifecycle events from the underlying implementation APIs
12611259
due to the agent restarts or crashes.
12621260
"""
1261+
all_detected_kernels: set[KernelId] = set()
1262+
12631263
known_kernels: Dict[KernelId, ContainerId] = {}
12641264
alive_kernels: Dict[KernelId, ContainerId] = {}
12651265
kernel_session_map: Dict[KernelId, SessionId] = {}
@@ -1270,6 +1270,7 @@ async def sync_container_lifecycles(self, interval: float) -> None:
12701270
try:
12711271
# Check if: there are dead containers
12721272
for kernel_id, container in await self.enumerate_containers(DEAD_STATUS_SET):
1273+
all_detected_kernels.add(kernel_id)
12731274
if (
12741275
kernel_id in self.restarting_kernels
12751276
or kernel_id in self.terminating_kernels
@@ -1289,6 +1290,7 @@ async def sync_container_lifecycles(self, interval: float) -> None:
12891290
KernelLifecycleEventReason.SELF_TERMINATED,
12901291
)
12911292
for kernel_id, container in await self.enumerate_containers(ACTIVE_STATUS_SET):
1293+
all_detected_kernels.add(kernel_id)
12921294
alive_kernels[kernel_id] = container.id
12931295
session_id = SessionId(UUID(container.labels["ai.backend.session-id"]))
12941296
kernel_session_map[kernel_id] = session_id
@@ -1323,13 +1325,41 @@ async def sync_container_lifecycles(self, interval: float) -> None:
13231325
KernelLifecycleEventReason.TERMINATED_UNKNOWN_CONTAINER,
13241326
)
13251327
finally:
1328+
await self.prune_kernel_registry(all_detected_kernels)
13261329
# Enqueue the events.
13271330
for kernel_id, ev in terminated_kernels.items():
13281331
await self.container_lifecycle_queue.put(ev)
13291332

13301333
# Set container count
13311334
await self.set_container_count(len(own_kernels.keys()))
13321335

1336+
async def prune_kernel_registry(
1337+
self, detected_kernels: ContainerT[KernelId], *, ensure_cleaned: bool = True
1338+
) -> None:
# Reviewer notes:
# - `detected_kernels` is only used for membership tests (`not in`), which is
#   why it is typed as an abstract Container.
# - Every key of `self.kernel_registry` that is absent from `detected_kernels`
#   is deleted. The caller must therefore pass the COMPLETE set of kernels that
#   still have a container; an incomplete set deregisters kernels that were
#   merely not enumerated.
# - `ensure_cleaned` gates the per-kernel cleanup steps below.
# - NOTE(review): indentation is not preserved in this view, so the nesting of
#   the statements below is inferred — confirm against the real file.
1339+
"""
1340+
Deregister containerless kernels from `kernel_registry`
1341+
since `_handle_clean_event()` does not deregister them.
1342+
"""
1343+
any_container_pruned = False
1344+
# Iterate over a snapshot of the keys because entries are deleted from the
# registry inside the loop.
for kernel_id in [*self.kernel_registry.keys()]:
1345+
if kernel_id not in detected_kernels:
1346+
if ensure_cleaned:
1347+
# Don't need to process this through event task
1348+
# since there is no communication with any container here.
1349+
kernel_obj = self.kernel_registry[kernel_id]
1350+
kernel_obj.stats_enabled = False
1351+
if kernel_obj.runner is not None:
1352+
await kernel_obj.runner.close()
1353+
if kernel_obj.clean_event is not None and not kernel_obj.clean_event.done():
1354+
# Resolve the still-pending clean_event; presumably this releases
# whoever is waiting on it — confirm against the waiters.
kernel_obj.clean_event.set_result(None)
1355+
await self._restore_port_pool(kernel_obj)
1356+
await kernel_obj.close()
1357+
# NOTE(review): presumably the deregistration below runs whether or not
# `ensure_cleaned` is set (nesting is not visible here) — confirm.
del self.kernel_registry[kernel_id]
1358+
self.terminating_kernels.discard(kernel_id)
1359+
any_container_pruned = True
1360+
# Call reconstruct_resource_usage() only when at least one registry entry
# was removed.
if any_container_pruned:
1361+
await self.reconstruct_resource_usage()
1362+
13331363
async def set_container_count(self, container_count: int) -> None:
13341364
await redis_helper.execute(
13351365
self.redis_stat_pool, lambda r: r.set(f"container_count.{self.id}", container_count)
@@ -2025,8 +2055,6 @@ async def create_kernel(
20252055
" unregistered.",
20262056
kernel_id,
20272057
)
2028-
async with self.registry_lock:
2029-
del self.kernel_registry[kernel_id]
20302058
raise
20312059
async with self.registry_lock:
20322060
self.kernel_registry[ctx.kernel_id].data.update(container_data)

0 commit comments

Comments
 (0)