Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 98 additions & 25 deletions code-interpreter/app/services/executor_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import uuid
from collections.abc import Generator, Sequence
from contextlib import contextmanager, suppress
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -41,6 +41,10 @@

logger = logging.getLogger(__name__)

# Maximum number of delete attempts _cleanup_pod makes before giving up.
POD_DELETE_RETRIES = 3
# Base pause between delete attempts; _cleanup_pod multiplies it by the attempt number.
POD_DELETE_RETRY_DELAY_SECONDS = 0.2
# Time budget handed to _wait_for_pod_deleted after each accepted delete request.
POD_DELETE_CONFIRM_TIMEOUT_SECONDS = 2.0


def _parse_exit_code(error: str) -> int | None:
"""Parse the exit code from a Kubernetes exec error channel message."""
Expand Down Expand Up @@ -75,7 +79,11 @@ def __init__(self) -> None:
except config.ConfigException:
config.load_kube_config()

self.v1 = client.CoreV1Api()
# Keep REST calls on a dedicated ApiClient. kubernetes.stream.stream mutates
# the ApiClient request path for websocket exec calls, so mixing CRUD and
# exec traffic on one client can leave later REST calls in a broken state.
self._rest_api_client = client.ApiClient()
self.v1 = client.CoreV1Api(api_client=self._rest_api_client)
self.namespace = KUBERNETES_EXECUTOR_NAMESPACE
self.image = KUBERNETES_EXECUTOR_IMAGE
self.service_account = KUBERNETES_EXECUTOR_SERVICE_ACCOUNT
Expand Down Expand Up @@ -263,20 +271,42 @@ def _wait_for_pod_ready(self, pod_name: str, timeout_sec: int = 30) -> None:
time.sleep(0.1)
raise RuntimeError(f"Pod {pod_name} did not become ready in {timeout_sec} seconds")

def _stream_pod_exec(
    self,
    pod_name: str,
    command: list[str],
    *,
    stderr: bool,
    stdin: bool,
    stdout: bool,
    tty: bool,
    preload_content: bool = False,
) -> ws_client.WSClient:
    """Start an exec call in ``pod_name`` through a throwaway ApiClient.

    A brand-new ``ApiClient`` is built for every call, so the client backing
    ``self.v1`` is never handed to ``stream.stream``.

    Args:
        pod_name: Pod (in ``self.namespace``) to exec into.
        command: Argument vector to run inside the pod.
        stderr: Forwarded to the exec call as ``stderr``.
        stdin: Forwarded to the exec call as ``stdin``.
        stdout: Forwarded to the exec call as ``stdout``.
        tty: Forwarded to the exec call as ``tty``.
        preload_content: Forwarded to ``stream.stream`` as ``_preload_content``.

    Returns:
        Whatever ``stream.stream`` returns for the exec call.
    """
    exec_api = client.CoreV1Api(api_client=client.ApiClient())
    exec_options = {
        "command": command,
        "stderr": stderr,
        "stdin": stdin,
        "stdout": stdout,
        "tty": tty,
        "_preload_content": preload_content,
    }
    return stream.stream(
        exec_api.connect_get_namespaced_pod_exec,
        pod_name,
        self.namespace,
        **exec_options,
    )

def _upload_tar_to_pod(self, pod_name: str, tar_archive: bytes) -> None:
"""Upload and extract a tar archive into the pod's workspace."""
logger.info(f"Uploading tar archive ({len(tar_archive)} bytes) to pod {pod_name}")
exec_command = ["tar", "-x", "-C", "/workspace"]
resp = stream.stream(
self.v1.connect_get_namespaced_pod_exec,
resp = self._stream_pod_exec(
pod_name,
self.namespace,
command=exec_command,
stderr=True,
stdin=True,
stdout=True,
tty=False,
_preload_content=False,
)

resp.write_stdin(tar_archive)
Expand Down Expand Up @@ -314,17 +344,17 @@ def _upload_tar_to_pod(self, pod_name: str, tar_archive: bytes) -> None:

def _kill_python_process(self, pod_name: str) -> None:
    """Kill the Python process running in the pod.

    Best effort: any failure of the exec call is logged as a warning (with
    traceback) and swallowed, so callers are never interrupted by it.

    Args:
        pod_name: Pod in which ``pkill -9 python`` is executed.
    """
    try:
        self._stream_pod_exec(
            pod_name,
            command=["pkill", "-9", "python"],
            stderr=False,
            stdin=False,
            stdout=False,
            tty=False,
            # The helper defaults to preload_content=False, which hands back
            # a live websocket client; this method discards the result, so
            # that client would be dropped without being closed. Preloading
            # restores the pre-helper behaviour (the direct stream.stream
            # call used the library default) of waiting for the exec to end.
            preload_content=True,
        )
    except Exception:
        logger.warning("Failed to kill Python process in pod %s", pod_name, exc_info=True)

@contextmanager
def _run_in_pod(
Expand Down Expand Up @@ -369,16 +399,13 @@ def _run_in_pod(
start = time.perf_counter()
exec_command = ["python", "/workspace/__main__.py"]

exec_resp = stream.stream(
self.v1.connect_get_namespaced_pod_exec,
exec_resp = self._stream_pod_exec(
pod_name,
self.namespace,
command=exec_command,
stderr=True,
stdin=True,
stdout=True,
tty=False,
_preload_content=False,
)

yield _KubeExecContext(
Expand Down Expand Up @@ -409,16 +436,13 @@ def _extract_workspace_snapshot(self, pod_name: str) -> tuple[WorkspaceEntry, ..
]

logger.info(f"Starting tar extraction from pod {pod_name}")
resp = stream.stream(
self.v1.connect_get_namespaced_pod_exec,
resp = self._stream_pod_exec(
pod_name,
self.namespace,
command=exec_command,
stderr=True,
stdin=False,
stdout=True,
tty=False,
_preload_content=False,
)

base64_data = ""
Expand Down Expand Up @@ -485,14 +509,63 @@ def _extract_workspace_snapshot(self, pod_name: str) -> tuple[WorkspaceEntry, ..
logger.error(f"Failed to extract workspace snapshot: {e}", exc_info=True)
return tuple()

def _wait_for_pod_deleted(self, pod_name: str, timeout_sec: float) -> bool:
    """Poll the API server until ``pod_name`` is gone or the timeout elapses.

    Args:
        pod_name: Name of the pod to look up in ``self.namespace``.
        timeout_sec: Maximum time to keep polling, in seconds.

    Returns:
        True as soon as a read fails with HTTP 404 (the pod no longer exists).
        False if the pod is still readable when the timeout expires, or if a
        read fails with any other API error (which is logged as a warning).
    """
    # Use the monotonic clock so a wall-clock adjustment (NTP step, manual
    # change) can neither stretch nor cut short the wait.
    deadline = time.monotonic() + timeout_sec
    while time.monotonic() < deadline:
        try:
            self.v1.read_namespaced_pod(pod_name, self.namespace)
        except ApiException as e:
            if e.status == 404:
                return True
            logger.warning(
                "Error while checking pod deletion for %s in namespace %s: %s",
                pod_name,
                self.namespace,
                e,
            )
            # A non-404 failure gives no information about the pod; stop
            # polling and let the caller decide whether to retry.
            return False
        # Read succeeded, so the pod still exists; pause before polling again.
        time.sleep(0.1)
    return False

def _cleanup_pod(self, pod_name: str) -> None:
    """Delete a pod and log any cleanup failures.

    Makes up to ``POD_DELETE_RETRIES`` delete attempts. After each delete
    request the API accepts, ``_wait_for_pod_deleted`` is given
    ``POD_DELETE_CONFIRM_TIMEOUT_SECONDS`` to confirm the pod is gone.
    ``ApiException`` failures are logged, never raised; other exception
    types are not caught here.

    Args:
        pod_name: Name of the pod to delete from ``self.namespace``.
    """
    for attempt in range(1, POD_DELETE_RETRIES + 1):
        try:
            self.v1.delete_namespaced_pod(
                name=pod_name,
                namespace=self.namespace,
                body=client.V1DeleteOptions(grace_period_seconds=0),
            )
        except ApiException as e:
            if e.status == 404:
                # Pod is already gone; nothing left to clean up.
                return
            logger.warning(
                "Failed to delete pod %s in namespace %s on attempt %s/%s: %s",
                pod_name,
                self.namespace,
                attempt,
                POD_DELETE_RETRIES,
                e,
            )
        else:
            # The delete request was accepted; confirm the pod disappeared.
            if self._wait_for_pod_deleted(pod_name, POD_DELETE_CONFIRM_TIMEOUT_SECONDS):
                return
            logger.warning(
                "Pod %s still exists after delete request on attempt %s/%s",
                pod_name,
                attempt,
                POD_DELETE_RETRIES,
            )

        # Linear backoff between attempts; no pause after the final one.
        if attempt < POD_DELETE_RETRIES:
            time.sleep(POD_DELETE_RETRY_DELAY_SECONDS * attempt)

    logger.error(
        "Failed to confirm deletion of pod %s in namespace %s after %s attempts",
        pod_name,
        self.namespace,
        POD_DELETE_RETRIES,
    )

def execute_python(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

import base64
import io
import logging
import tarfile
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from kubernetes.client.exceptions import ApiException # type: ignore[import-untyped]

from app.services.executor_base import StreamChunk, StreamEvent, StreamResult
from app.services.executor_kubernetes import KubernetesExecutor
Expand All @@ -32,7 +34,13 @@ def executor() -> KubernetesExecutor:
inst.service_account = ""
pod_mock = MagicMock()
pod_mock.status.phase = "Running"
inst.v1.read_namespaced_pod.return_value = pod_mock

def _read_namespaced_pod(*args: object, **kwargs: object) -> MagicMock:
if inst.v1.delete_namespaced_pod.called:
raise ApiException(status=404)
return pod_mock

inst.v1.read_namespaced_pod.side_effect = _read_namespaced_pod
return inst


Expand Down Expand Up @@ -327,6 +335,62 @@ def test_streaming_cleans_up_pod(executor: KubernetesExecutor) -> None:
executor.v1.delete_namespaced_pod.assert_called_once()


def test_cleanup_retries_delete_failures(
    executor: KubernetesExecutor, caplog: pytest.LogCaptureFixture
) -> None:
    """A failed delete is retried, and the retry is confirmed by polling."""
    executor.v1.delete_namespaced_pod.side_effect = [
        ApiException(status=500, reason="boom"),
        None,
    ]
    executor.v1.read_namespaced_pod.side_effect = [
        MagicMock(),
        ApiException(status=404),
    ]

    # Patch sleep so the retry backoff and the deletion poll interval do not
    # add real wall-clock delay to the test run.
    with (
        patch("app.services.executor_kubernetes.time.sleep"),
        caplog.at_level(logging.WARNING),
    ):
        executor._cleanup_pod("code-exec-test")

    assert executor.v1.delete_namespaced_pod.call_count == 2
    # One read sees the pod still present, the next one gets the 404.
    assert executor.v1.read_namespaced_pod.call_count == 2
    assert "Failed to delete pod code-exec-test" in caplog.text

def test_cleanup_logs_when_delete_never_succeeds(
    executor: KubernetesExecutor, caplog: pytest.LogCaptureFixture
) -> None:
    """Every attempt fails: each failure is logged, then the final error."""
    executor.v1.delete_namespaced_pod.side_effect = ApiException(status=500, reason="boom")

    # Patch sleep so the backoff between the three attempts is not actually
    # waited out.
    with (
        patch("app.services.executor_kubernetes.time.sleep") as mock_sleep,
        caplog.at_level(logging.WARNING),
    ):
        executor._cleanup_pod("code-exec-test")

    assert executor.v1.delete_namespaced_pod.call_count == 3
    # Backoff happens between attempts only, not after the last one.
    assert mock_sleep.call_count == 2
    assert "Failed to delete pod code-exec-test" in caplog.text
    assert "Failed to confirm deletion of pod code-exec-test" in caplog.text


def test_stream_exec_uses_fresh_api_client(executor: KubernetesExecutor) -> None:
    """Each exec call builds its own ApiClient/CoreV1Api pair for streaming."""
    with (
        patch("app.services.executor_kubernetes.client.ApiClient") as fake_api_client,
        patch("app.services.executor_kubernetes.client.CoreV1Api") as fake_core_v1,
        patch("app.services.executor_kubernetes.stream.stream") as fake_stream,
    ):
        exec_api = MagicMock()
        fake_core_v1.return_value = exec_api

        executor._stream_pod_exec(
            "code-exec-test",
            ["python", "/workspace/__main__.py"],
            stderr=True,
            stdin=True,
            stdout=True,
            tty=False,
        )

        # Exactly one new ApiClient, wrapped by the CoreV1Api used for the exec.
        fake_api_client.assert_called_once()
        fake_core_v1.assert_called_once_with(api_client=fake_api_client.return_value)

        # The exec endpoint of that API object is what gets streamed, and the
        # helper's default leaves content preloading off.
        stream_call = fake_stream.call_args
        assert stream_call.args[0] is exec_api.connect_get_namespaced_pod_exec
        assert stream_call.kwargs["_preload_content"] is False


def test_streaming_empty_output(executor: KubernetesExecutor) -> None:
events = _run_streaming(executor, FakeExecResp())

Expand Down
Loading