diff --git a/docs/ref/extensions/sandbox/tensorlake/sandbox.md b/docs/ref/extensions/sandbox/tensorlake/sandbox.md new file mode 100644 index 0000000000..1be9fceca6 --- /dev/null +++ b/docs/ref/extensions/sandbox/tensorlake/sandbox.md @@ -0,0 +1,3 @@ +# `Sandbox` + +::: agents.extensions.sandbox.tensorlake.sandbox diff --git a/docs/sandbox/clients.md b/docs/sandbox/clients.md index bd21da63d3..be4794454b 100644 --- a/docs/sandbox/clients.md +++ b/docs/sandbox/clients.md @@ -96,6 +96,7 @@ For provider-specific setup notes and links for the checked-in extension example | `E2BSandboxClient` | `openai-agents[e2b]` | [E2B runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/e2b_runner.py) | | `ModalSandboxClient` | `openai-agents[modal]` | [Modal runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/modal_runner.py) | | `RunloopSandboxClient` | `openai-agents[runloop]` | [Runloop runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/runloop/runner.py) | +| `TensorlakeSandboxClient` | `openai-agents[tensorlake]` | [Tensorlake runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/tensorlake_runner.py) | | `VercelSandboxClient` | `openai-agents[vercel]` | [Vercel runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/vercel_runner.py) | @@ -113,6 +114,7 @@ Hosted sandbox clients expose provider-specific mount strategies. Choose the bac | `DaytonaSandboxClient` | Supports rclone-backed cloud storage mounts with `DaytonaCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. | | `E2BSandboxClient` | Supports rclone-backed cloud storage mounts with `E2BCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. 
| | `RunloopSandboxClient` | Supports rclone-backed cloud storage mounts with `RunloopCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. | +| `TensorlakeSandboxClient` | No hosted-specific mount strategy is currently exposed. Use manifest files, repos, or other workspace inputs instead. The default manifest root is `DEFAULT_TENSORLAKE_WORKSPACE_ROOT` (`/home/tl-user/workspace`), which is writable by the default image's non-root user and persisted across FILESYSTEM checkpoints; override it only when targeting a custom image. Tensorlake's native sandbox checkpoint API is available via `workspace_persistence="snapshot"`; prefer this over external bucket mounts for between-run persistence. | | `VercelSandboxClient` | No hosted-specific mount strategy is currently exposed. Use manifest files, repos, or other workspace inputs instead. | @@ -130,6 +132,7 @@ The table below summarizes which remote storage entries each backend can mount d | `DaytonaSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | | `E2BSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | | `RunloopSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | +| `TensorlakeSandboxClient` | - | - | - | - | - | - | | `VercelSandboxClient` | - | - | - | - | - | - | diff --git a/examples/sandbox/extensions/README.md b/examples/sandbox/extensions/README.md index 837d9dfa28..61e5817901 100644 --- a/examples/sandbox/extensions/README.md +++ b/examples/sandbox/extensions/README.md @@ -243,6 +243,45 @@ export DAYTONA_API_KEY=... uv run python examples/sandbox/extensions/daytona/daytona_runner.py --stream ``` +## Tensorlake + +### Setup + +Install the repo extra: + +```bash +uv sync --extra tensorlake +``` + +Sign up at [cloud.tensorlake.ai](https://cloud.tensorlake.ai/) (or run `tl login`) +and export the required environment variables: + +```bash +export OPENAI_API_KEY=... +export TENSORLAKE_API_KEY=... 
+``` + +### Run + +```bash +uv run python examples/sandbox/extensions/tensorlake_runner.py --stream +``` + +Useful flags: + +- `--image <name>` to pin a specific Tensorlake registered image. +- `--timeout-secs 600` to set the sandbox lifetime, in seconds (default: 1800). +- `--workspace-persistence snapshot` to use Tensorlake's native snapshot checkpoints for persistence (default: `tar`). +- `--env KEY=VAL` to inject an environment variable into the sandbox. Repeatable. +- `--secret NAME` to inject a Tensorlake-managed secret into the sandbox by name. Repeatable. +- `--cpus <count>` to set the sandbox CPU allocation. +- `--memory-mb <mb>` to set the sandbox memory allocation, in megabytes. + Must be between 1024 and 8192 MB **per CPU core**, so scale this up + alongside `--cpus` (for example, `--cpus 2` requires at least + `--memory-mb 2048`). +- `--disk-mb <mb>` to set the sandbox disk allocation, in megabytes. + Must be between 10240 and 102400 MiB (10–100 GiB). + ## Runloop ### Setup diff --git a/examples/sandbox/extensions/tensorlake_runner.py b/examples/sandbox/extensions/tensorlake_runner.py new file mode 100644 index 0000000000..e1909bfd33 --- /dev/null +++ b/examples/sandbox/extensions/tensorlake_runner.py @@ -0,0 +1,315 @@ +""" +Minimal Tensorlake-backed sandbox example for manual validation. + +This mirrors the other cloud extension examples: it creates a tiny workspace, +verifies stop/resume persistence, then asks a sandboxed agent to inspect the +workspace through one shell tool. 
+""" + +from __future__ import annotations + +import argparse +import asyncio +import io +import os +import sys +import tempfile +from pathlib import Path +from typing import Literal, cast + +from openai.types.responses import ResponseTextDeltaEvent + +from agents import ModelSettings, Runner +from agents.run import RunConfig +from agents.sandbox import LocalSnapshotSpec, Manifest, SandboxAgent, SandboxRunConfig +from agents.sandbox.entries import File +from agents.sandbox.session import BaseSandboxSession + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + +try: + from agents.extensions.sandbox import ( + DEFAULT_TENSORLAKE_WORKSPACE_ROOT, + TensorlakeSandboxClient, + TensorlakeSandboxClientOptions, + ) +except Exception as exc: # pragma: no cover - import path depends on optional extras + raise SystemExit( + "Tensorlake sandbox examples require the optional repo extra.\n" + "Install it with: uv sync --extra tensorlake" + ) from exc + + +DEFAULT_QUESTION = "Summarize this cloud sandbox workspace in 2 sentences." +SNAPSHOT_CHECK_PATH = Path("snapshot-check.txt") +SNAPSHOT_CHECK_CONTENT = "tensorlake snapshot round-trip ok\n" +LIVE_RESUME_CHECK_PATH = Path("live-resume-check.txt") +LIVE_RESUME_CHECK_CONTENT = "tensorlake live resume ok\n" + + +def _build_manifest() -> Manifest: + files = { + "README.md": ( + "# Tensorlake Demo Workspace\n\n" + "This workspace exists to validate the Tensorlake sandbox backend manually.\n" + ), + "handoff.md": ( + "# Handoff\n\n" + "- Customer: Northwind Traders.\n" + "- Goal: validate Tensorlake sandbox exec and persistence flows.\n" + "- Current status: non-PTY backend slice is wired and under test.\n" + ), + "todo.md": ( + "# Todo\n\n" + "1. Inspect the workspace files.\n" + "2. 
Summarize the current status in two sentences.\n" + ), + } + return Manifest( + root=DEFAULT_TENSORLAKE_WORKSPACE_ROOT, + entries={path: File(content=contents.encode("utf-8")) for path, contents in files.items()}, + ) + + +async def _read_text(session: BaseSandboxSession, path: Path) -> str: + data = await session.read(path) + text = cast(str | bytes, data.read()) + if isinstance(text, bytes): + return text.decode("utf-8") + return text + + +def _require_env(name: str) -> None: + if os.environ.get(name): + return + raise SystemExit(f"{name} must be set before running this example.") + + +def _parse_env_pair(raw: str) -> tuple[str, str]: + if "=" not in raw: + raise argparse.ArgumentTypeError(f"--env value must be KEY=VAL (got {raw!r}).") + key, value = raw.split("=", 1) + if not key: + raise argparse.ArgumentTypeError(f"--env key must be non-empty (got {raw!r}).") + return key, value + + +async def _verify_stop_resume( + *, + manifest: Manifest, + options: TensorlakeSandboxClientOptions, +) -> None: + # Verification sandboxes should always terminate on shutdown so the example does not + # leak suspended sandboxes; pause-on-exit is exercised by the main agent run instead. 
+ options = options.model_copy(update={"pause_on_exit": False}) + client = TensorlakeSandboxClient() + with tempfile.TemporaryDirectory(prefix="tensorlake-snapshot-example-") as snapshot_dir: + sandbox = await client.create( + manifest=manifest, + snapshot=LocalSnapshotSpec(base_path=Path(snapshot_dir)), + options=options, + ) + + try: + await sandbox.start() + await sandbox.write( + SNAPSHOT_CHECK_PATH, + io.BytesIO(SNAPSHOT_CHECK_CONTENT.encode("utf-8")), + ) + await sandbox.stop() + finally: + await sandbox.shutdown() + + resumed_sandbox = await client.resume(sandbox.state) + try: + await resumed_sandbox.start() + restored_text = await _read_text(resumed_sandbox, SNAPSHOT_CHECK_PATH) + if restored_text != SNAPSHOT_CHECK_CONTENT: + raise RuntimeError( + f"Snapshot resume verification failed for {options.workspace_persistence!r}: " + f"expected {SNAPSHOT_CHECK_CONTENT!r}, got {restored_text!r}" + ) + finally: + await resumed_sandbox.aclose() + + print(f"snapshot round-trip ok ({options.workspace_persistence})") + + +async def _verify_resume_running_sandbox( + *, + manifest: Manifest, + options: TensorlakeSandboxClientOptions, +) -> None: + # Force terminate-on-shutdown for verification so we don't leave suspended sandboxes behind. 
+ options = options.model_copy(update={"pause_on_exit": False}) + client = TensorlakeSandboxClient() + sandbox = await client.create(manifest=manifest, options=options) + + try: + await sandbox.start() + await sandbox.write( + LIVE_RESUME_CHECK_PATH, + io.BytesIO(LIVE_RESUME_CHECK_CONTENT.encode("utf-8")), + ) + serialized = client.serialize_session_state(sandbox.state) + resumed_sandbox = await client.resume(client.deserialize_session_state(serialized)) + try: + restored_text = await _read_text(resumed_sandbox, LIVE_RESUME_CHECK_PATH) + if restored_text != LIVE_RESUME_CHECK_CONTENT: + raise RuntimeError( + "Running sandbox resume verification failed: " + f"expected {LIVE_RESUME_CHECK_CONTENT!r}, got {restored_text!r}" + ) + finally: + await resumed_sandbox.aclose() + finally: + await sandbox.shutdown() + + print(f"running sandbox resume ok ({options.workspace_persistence})") + + +async def main( + *, + model: str, + question: str, + options: TensorlakeSandboxClientOptions, + stream: bool, +) -> None: + _require_env("OPENAI_API_KEY") + _require_env("TENSORLAKE_API_KEY") + + manifest = _build_manifest() + + await _verify_stop_resume(manifest=manifest, options=options) + await _verify_resume_running_sandbox(manifest=manifest, options=options) + + agent = SandboxAgent( + name="Tensorlake Sandbox Assistant", + model=model, + instructions=( + "Answer questions about the sandbox workspace. Inspect the files before answering " + "and keep the response concise. " + "Do not invent files or statuses that are not present in the workspace. Cite the " + "file names you inspected." 
+ ), + default_manifest=manifest, + capabilities=[WorkspaceShellCapability()], + model_settings=ModelSettings(tool_choice="required"), + ) + + run_config = RunConfig( + sandbox=SandboxRunConfig( + client=TensorlakeSandboxClient(), + options=options, + ), + workflow_name="Tensorlake sandbox example", + ) + + if not stream: + result = await Runner.run(agent, question, run_config=run_config) + print(result.final_output) + return + + stream_result = Runner.run_streamed(agent, question, run_config=run_config) + saw_text_delta = False + async for event in stream_result.stream_events(): + if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): + if not saw_text_delta: + print("assistant> ", end="", flush=True) + saw_text_delta = True + print(event.data.delta, end="", flush=True) + + if saw_text_delta: + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.5", help="Model name to use.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + parser.add_argument( + "--image", + default=None, + help="Optional Tensorlake registered image name. Falls back to the SDK default.", + ) + parser.add_argument( + "--timeout-secs", + type=int, + default=1800, + help=( + "Optional Tensorlake sandbox lifetime in seconds. Must be strictly greater " + "than `checkpoint_timeout_s` (default 300) when " + "--workspace-persistence=snapshot." + ), + ) + parser.add_argument( + "--workspace-persistence", + choices=("tar", "snapshot"), + default="tar", + help="Workspace persistence mode to verify before the agent run.", + ) + parser.add_argument( + "--env", + action="append", + default=None, + type=_parse_env_pair, + metavar="KEY=VAL", + help="Environment variable to inject into the sandbox. 
Repeatable.", + ) + parser.add_argument( + "--secret", + action="append", + default=None, + metavar="NAME", + help="Tensorlake-managed secret name to inject into the sandbox. Repeatable.", + ) + parser.add_argument( + "--pause-on-exit", + action="store_true", + default=False, + help="Pause the sandbox on shutdown instead of terminating it.", + ) + parser.add_argument( + "--cpus", + type=float, + default=None, + help="Optional CPU allocation for the sandbox.", + ) + parser.add_argument( + "--memory-mb", + type=int, + default=None, + help="Optional memory allocation for the sandbox, in megabytes.", + ) + parser.add_argument( + "--disk-mb", + type=int, + default=None, + help="Optional disk allocation for the sandbox, in megabytes.", + ) + parser.add_argument("--stream", action="store_true", default=False, help="Stream the response.") + args = parser.parse_args() + + options = TensorlakeSandboxClientOptions( + image=args.image, + timeout_secs=args.timeout_secs, + workspace_persistence=cast(Literal["tar", "snapshot"], args.workspace_persistence), + envs=dict(args.env) if args.env else None, + secret_names=tuple(args.secret or ()), + pause_on_exit=args.pause_on_exit, + cpus=args.cpus, + memory_mb=args.memory_mb, + disk_mb=args.disk_mb, + ) + + asyncio.run( + main( + model=args.model, + question=args.question, + options=options, + stream=args.stream, + ) + ) diff --git a/pyproject.toml b/pyproject.toml index 4d0122049f..8a61456401 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ cloudflare = ["aiohttp>=3.12,<4"] e2b = ["e2b==2.20.0", "e2b-code-interpreter==2.4.1"] modal = ["modal==1.3.5"] runloop = ["runloop_api_client>=1.16.0,<2.0.0"] +tensorlake = ["tensorlake>=0.5.11"] vercel = ["vercel>=0.5.6,<0.6"] s3 = ["boto3>=1.34"] temporal = [ @@ -156,6 +157,10 @@ ignore_missing_imports = true module = ["runloop_api_client", "runloop_api_client.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["tensorlake", "tensorlake.*"] 
+ignore_missing_imports = true + [[tool.mypy.overrides]] module = ["blaxel", "blaxel.*"] ignore_missing_imports = true diff --git a/src/agents/extensions/sandbox/__init__.py b/src/agents/extensions/sandbox/__init__.py index d7b082ba1f..fa2a66dab3 100644 --- a/src/agents/extensions/sandbox/__init__.py +++ b/src/agents/extensions/sandbox/__init__.py @@ -97,6 +97,20 @@ except Exception: # pragma: no cover _HAS_RUNLOOP = False +try: + from .tensorlake import ( + DEFAULT_TENSORLAKE_WORKSPACE_ROOT as DEFAULT_TENSORLAKE_WORKSPACE_ROOT, + TensorlakeSandboxClient as TensorlakeSandboxClient, + TensorlakeSandboxClientOptions as TensorlakeSandboxClientOptions, + TensorlakeSandboxSession as TensorlakeSandboxSession, + TensorlakeSandboxSessionState as TensorlakeSandboxSessionState, + TensorlakeSandboxTimeouts as TensorlakeSandboxTimeouts, + ) + + _HAS_TENSORLAKE = True +except Exception: # pragma: no cover + _HAS_TENSORLAKE = False + try: from .vercel import ( VercelSandboxClient as VercelSandboxClient, @@ -177,6 +191,18 @@ ] ) +if _HAS_TENSORLAKE: + __all__.extend( + [ + "DEFAULT_TENSORLAKE_WORKSPACE_ROOT", + "TensorlakeSandboxClient", + "TensorlakeSandboxClientOptions", + "TensorlakeSandboxSession", + "TensorlakeSandboxSessionState", + "TensorlakeSandboxTimeouts", + ] + ) + if _HAS_VERCEL: __all__.extend( [ diff --git a/src/agents/extensions/sandbox/tensorlake/__init__.py b/src/agents/extensions/sandbox/tensorlake/__init__.py new file mode 100644 index 0000000000..dea284e1bb --- /dev/null +++ b/src/agents/extensions/sandbox/tensorlake/__init__.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from .sandbox import ( + DEFAULT_TENSORLAKE_WORKSPACE_ROOT, + TensorlakeSandboxClient, + TensorlakeSandboxClientOptions, + TensorlakeSandboxSession, + TensorlakeSandboxSessionState, + TensorlakeSandboxTimeouts, +) + +__all__ = [ + "DEFAULT_TENSORLAKE_WORKSPACE_ROOT", + "TensorlakeSandboxClient", + "TensorlakeSandboxClientOptions", + "TensorlakeSandboxSession", + 
"TensorlakeSandboxSessionState", + "TensorlakeSandboxTimeouts", +] diff --git a/src/agents/extensions/sandbox/tensorlake/sandbox.py b/src/agents/extensions/sandbox/tensorlake/sandbox.py new file mode 100644 index 0000000000..b9c694cb76 --- /dev/null +++ b/src/agents/extensions/sandbox/tensorlake/sandbox.py @@ -0,0 +1,1508 @@ +""" +Tensorlake sandbox (https://tensorlake.ai) implementation. + +Set `TENSORLAKE_API_KEY` (or run `tl login`) to authenticate. + +This module provides a Tensorlake-backed sandbox client/session implementation backed by +`tensorlake.sandbox.AsyncSandbox`. + +Note: The `tensorlake` dependency is optional (installed via the `tensorlake` extra). The +SDK is imported at module load and importing this module without the extra installed +raises `ImportError` with installation guidance. +""" + +from __future__ import annotations + +import asyncio +import io +import json +import logging +import math +import uuid +from contextlib import suppress +from dataclasses import dataclass, fields, replace +from pathlib import Path +from typing import Any, Literal, cast +from urllib.parse import urlsplit + +from pydantic import BaseModel, Field + +try: + from tensorlake.sandbox import ( + AsyncSandbox, + CheckpointType, + RemoteAPIError, + SandboxError, + SandboxStatus, + ) +except ImportError as exc: # pragma: no cover - exercised via unit tests with fakes + raise ImportError( + "TensorlakeSandboxClient requires the optional `tensorlake` dependency.\n" + 'Install it with `pip install "openai-agents[tensorlake]"`.' 
+ ) from exc + +from ....sandbox.errors import ( + ExecNonZeroError, + ExecTimeoutError, + ExecTransportError, + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceReadNotFoundError, + WorkspaceStartError, + WorkspaceWriteTypeError, +) +from ....sandbox.manifest import Manifest +from ....sandbox.session import SandboxSession, SandboxSessionState +from ....sandbox.session.base_sandbox_session import BaseSandboxSession +from ....sandbox.session.dependencies import Dependencies +from ....sandbox.session.manager import Instrumentation +from ....sandbox.session.mount_lifecycle import with_ephemeral_mounts_removed +from ....sandbox.session.runtime_helpers import RESOLVE_WORKSPACE_PATH_HELPER, RuntimeHelperScript +from ....sandbox.session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions +from ....sandbox.session.tar_workspace import shell_tar_exclude_args +from ....sandbox.snapshot import SnapshotBase, SnapshotSpec, resolve_snapshot +from ....sandbox.types import ExecResult, ExposedPortEndpoint, User +from ....sandbox.util.retry import ( + TRANSIENT_HTTP_STATUS_CODES, + exception_chain_has_status_code, + retry_async, +) +from ....sandbox.util.tar_utils import UnsafeTarMemberError, validate_tar_bytes +from ....sandbox.workspace_paths import posix_path_for_error, sandbox_path_str + +logger = logging.getLogger(__name__) + +WorkspacePersistenceMode = Literal["tar", "snapshot"] +CheckpointMode = Literal["filesystem", "memory"] +CheckpointWaitUntil = Literal["local_ready", "completed"] + +_WORKSPACE_PERSISTENCE_TAR: WorkspacePersistenceMode = "tar" +_WORKSPACE_PERSISTENCE_SNAPSHOT: WorkspacePersistenceMode = "snapshot" + +# Default manifest root for the Tensorlake provider. The default image runs as the +# non-root `tl-user`, so `/workspace` (the cross-provider default) is not writable; +# tmpfs paths like `/tmp/*` are writable but excluded from FILESYSTEM checkpoints. 
+# `/home/tl-user/workspace` is both `tl-user`-writable and persisted across snapshots. +DEFAULT_TENSORLAKE_WORKSPACE_ROOT = "/home/tl-user/workspace" +_DEFAULT_MANIFEST_ROOT = cast(str, Manifest.model_fields["root"].default) +_GENERATED_TENSORLAKE_NAME_PREFIX = "openai-agents-" + +# Magic prefix for Tensorlake checkpoint references that are not tar bytes. +_TENSORLAKE_SNAPSHOT_MAGIC = b"TENSORLAKE_SANDBOX_SNAPSHOT_V1\n" + +_DEFAULT_EXPOSED_PORT_HOST_TEMPLATE = "{port}-{sandbox}.sandbox.tensorlake.ai" + +# Hostnames that indicate a local proxy where port-prefixed subdomain routing does not apply +# (the SDK uses a `Host` header instead). +_LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"}) + + +def _unwrap_traced_bytes(payload: Any) -> bytes: + """Unwrap a `Traced[bytes]` returned by `AsyncSandbox.read_file` into raw bytes. + + The SDK wraps the value on `.value` and exposes a W3C trace id on `.trace_id`; detect + via `trace_id` so a `Traced[None]` still unwraps correctly. + """ + if hasattr(payload, "trace_id") and not isinstance(payload, bytes | bytearray): + payload = payload.value + if isinstance(payload, bytes | bytearray): + return bytes(payload) + return str(payload).encode("utf-8", errors="replace") + + +def _encode_tensorlake_snapshot_ref(*, snapshot_id: str) -> bytes: + body = json.dumps({"snapshot_id": snapshot_id}, separators=(",", ":"), sort_keys=True).encode( + "utf-8" + ) + return _TENSORLAKE_SNAPSHOT_MAGIC + body + + +def _decode_tensorlake_snapshot_ref(raw: bytes) -> str | None: + if not raw.startswith(_TENSORLAKE_SNAPSHOT_MAGIC): + return None + body = raw[len(_TENSORLAKE_SNAPSHOT_MAGIC) :] + try: + payload = json.loads(body.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + return None + snapshot_id = payload.get("snapshot_id") if isinstance(payload, dict) else None + return snapshot_id if isinstance(snapshot_id, str) and snapshot_id else None + + +async def _restore_tensorlake_snapshot_reference_id(snapshot: 
SnapshotBase) -> str | None: + """Best-effort extraction of the Tensorlake snapshot id from a persisted snapshot. + + Returns ``None`` when the persisted payload is not a Tensorlake checkpoint reference + or the snapshot store cannot be reached. `client.resume()` runs before session + dependencies are wired, so e.g. `RemoteSnapshot` would raise; callers fall back to + the slower `hydrate_workspace` path in those cases. + """ + + try: + if not await snapshot.restorable(): + return None + restored = await snapshot.restore() + try: + raw = restored.read() + finally: + with suppress(Exception): + restored.close() + except Exception: + return None + + if isinstance(raw, str): + raw = raw.encode("utf-8") + if not isinstance(raw, bytes | bytearray): + return None + return _decode_tensorlake_snapshot_ref(bytes(raw)) + + +class TensorlakeSandboxTimeouts(BaseModel): + """Timeout configuration for Tensorlake operations. + + Attributes: + exec_timeout_unbounded_s: Safety cap, in seconds, applied to `exec(...)` calls that + were invoked with `timeout=None`. Prevents an "unbounded" command from holding + the sandbox indefinitely. Defaults to 24 hours. + fast_op_s: Per-operation timeout, in seconds, for short backend operations such as + `mkdir`, `delete_file`, and exposed-port updates. Defaults to 30 seconds. + snapshot_tar_s: Per-operation timeout, in seconds, for tar-based workspace + persist and hydrate operations. Defaults to 300 seconds. + """ + + # Caller-supplied timeout=None should mean "no timeout" without bypassing the safety net. + exec_timeout_unbounded_s: float = Field(default=24 * 60 * 60, ge=1) # 24 hours + fast_op_s: float = Field(default=30, ge=1) + snapshot_tar_s: float = Field(default=300, ge=1) + + +class TensorlakeSandboxClientOptions(BaseSandboxClientOptions): + """Client options for the Tensorlake sandbox backend. + + Attributes: + image: Optional Tensorlake registered image name. Falls back to the SDK default + image when None. 
The default image runs as the non-root `tl-user` account, so + workspace paths must be writable by that user; see + [`DEFAULT_TENSORLAKE_WORKSPACE_ROOT`][agents.extensions.sandbox.tensorlake.sandbox.DEFAULT_TENSORLAKE_WORKSPACE_ROOT]. + cpus: Optional CPU allocation for the sandbox. Uses the Tensorlake SDK default + when None. + memory_mb: Optional memory allocation for the sandbox, in megabytes. + disk_mb: Optional disk allocation for the sandbox, in megabytes. + timeout_secs: Optional sandbox lifetime, in seconds. The Tensorlake backend + automatically terminates the sandbox after this period. + name: Optional friendly name for the sandbox; surfaces in Tensorlake dashboards + and is used as part of the fallback exposed-port hostname. If omitted while + `pause_on_exit=True`, a stable name is generated because Tensorlake only + supports suspend/resume for named sandboxes. + secret_names: Names of Tensorlake-managed secrets to inject into the sandbox + environment. + envs: Additional environment variables to inject on every command. Tensorlake + does not accept envs at sandbox-create time; they are passed on every + `AsyncSandbox.run(...)` call instead, merged with manifest-supplied envs. + allow_internet_access: Whether the sandbox is allowed to make outbound internet + connections. Defaults to True. + allow_out: Hostname allow-list for outbound traffic. When set, only the listed + hostnames are reachable from the sandbox. + deny_out: Hostname deny-list for outbound traffic. + exposed_ports: TCP ports inside the sandbox to expose through Tensorlake's + per-sandbox URL. + allow_unauthenticated_port_access: When True, exposed ports skip Tensorlake's + built-in auth check. Defaults to False. + pause_on_exit: When True, run the sandbox as a named Tensorlake sandbox and + suspend it on shutdown so it can later be resumed via `client.resume(state)`. + When False (default), the sandbox is terminated on shutdown. 
+ workspace_persistence: How to persist the workspace between runs. `"tar"` + (default) captures the manifest root as a tar archive; `"snapshot"` uses + Tensorlake's native sandbox checkpoint API and stores only a snapshot id. + Snapshot mode falls back to tar when path-level skips are required. + checkpoint_mode: For `workspace_persistence="snapshot"`, either `"filesystem"` + (default; persists across hosts) or `"memory"` (faster, host-local). + checkpoint_timeout_s: Timeout, in seconds, for a single native checkpoint + operation. Defaults to 300. Must be strictly less than `timeout_secs` + when `workspace_persistence="snapshot"`, so the sandbox lives long enough + for the snapshot to settle before the Tensorlake backend auto-terminates + it; otherwise the snapshot can be orphaned mid-poll. + checkpoint_wait_until: How long to wait for the native checkpoint before + returning a snapshot id. `"local_ready"` (default, matches Tensorlake's SDK + default) returns as soon as the snapshot is locally resumable — fast and + sufficient for `AsyncSandbox.create(snapshot_id=...)` restore on the same + backend. `"completed"` additionally blocks until the snapshot is uploaded + to durable remote storage; use this only when you need a durable + `snapshot_uri` (e.g. for cross-host restore after the source host is gone). + timeouts: Optional `TensorlakeSandboxTimeouts` override (or dict of the same + shape) controlling fine-grained per-operation timeouts. + entrypoint: Optional command override for the sandbox image entrypoint. + startup_timeout: Optional seconds to wait for the sandbox to become ready after + create. + proxy_url: Optional override for the Tensorlake sandbox proxy URL (e.g., for + self-hosted or dev deployments). When set, the exposed-port host is resolved + from `AsyncSandbox.info().sandbox_url` instead of the public template. + api_url: Optional override for the Tensorlake control-plane API URL. + namespace: Optional Tensorlake namespace selector. 
+ organization_id: Optional Tensorlake organization id. + project_id: Optional Tensorlake project id. + routing_hint: Optional routing hint passed to `AsyncSandbox.connect(...)` when + resuming an existing sandbox. Not used at create time. + """ + + type: Literal["tensorlake"] = "tensorlake" + image: str | None = None + cpus: float | None = None + memory_mb: int | None = None + timeout_secs: int | None = None + name: str | None = None + secret_names: tuple[str, ...] = () + envs: dict[str, str] | None = None + allow_internet_access: bool = True + allow_out: tuple[str, ...] = () + deny_out: tuple[str, ...] = () + exposed_ports: tuple[int, ...] = () + allow_unauthenticated_port_access: bool = False + pause_on_exit: bool = False + workspace_persistence: WorkspacePersistenceMode = _WORKSPACE_PERSISTENCE_TAR + checkpoint_mode: CheckpointMode = "filesystem" + checkpoint_timeout_s: float = 300.0 + timeouts: TensorlakeSandboxTimeouts | dict[str, object] | None = None + disk_mb: int | None = None + entrypoint: tuple[str, ...] = () + startup_timeout: float | None = None + proxy_url: str | None = None + api_url: str | None = None + namespace: str | None = None + organization_id: str | None = None + project_id: str | None = None + routing_hint: str | None = None + checkpoint_wait_until: CheckpointWaitUntil = "local_ready" + + def __init__( + self, + image: str | None = None, + cpus: float | None = None, + memory_mb: int | None = None, + timeout_secs: int | None = None, + name: str | None = None, + secret_names: tuple[str, ...] = (), + envs: dict[str, str] | None = None, + allow_internet_access: bool = True, + allow_out: tuple[str, ...] = (), + deny_out: tuple[str, ...] = (), + exposed_ports: tuple[int, ...] 
= (), + allow_unauthenticated_port_access: bool = False, + pause_on_exit: bool = False, + workspace_persistence: WorkspacePersistenceMode = _WORKSPACE_PERSISTENCE_TAR, + checkpoint_mode: CheckpointMode = "filesystem", + checkpoint_timeout_s: float = 300.0, + timeouts: TensorlakeSandboxTimeouts | dict[str, object] | None = None, + disk_mb: int | None = None, + entrypoint: tuple[str, ...] = (), + startup_timeout: float | None = None, + proxy_url: str | None = None, + api_url: str | None = None, + namespace: str | None = None, + organization_id: str | None = None, + project_id: str | None = None, + routing_hint: str | None = None, + checkpoint_wait_until: CheckpointWaitUntil = "local_ready", + *, + type: Literal["tensorlake"] = "tensorlake", + ) -> None: + super().__init__( + type=type, + image=image, + cpus=cpus, + memory_mb=memory_mb, + timeout_secs=timeout_secs, + name=name, + secret_names=secret_names, + envs=envs, + allow_internet_access=allow_internet_access, + allow_out=allow_out, + deny_out=deny_out, + exposed_ports=exposed_ports, + allow_unauthenticated_port_access=allow_unauthenticated_port_access, + pause_on_exit=pause_on_exit, + workspace_persistence=workspace_persistence, + checkpoint_mode=checkpoint_mode, + checkpoint_timeout_s=checkpoint_timeout_s, + timeouts=timeouts, + disk_mb=disk_mb, + entrypoint=entrypoint, + startup_timeout=startup_timeout, + proxy_url=proxy_url, + api_url=api_url, + namespace=namespace, + organization_id=organization_id, + project_id=project_id, + routing_hint=routing_hint, + checkpoint_wait_until=checkpoint_wait_until, + ) + + +class TensorlakeSandboxSessionState(SandboxSessionState): + """Serializable state for a Tensorlake-backed session. + + Captured at `create(...)` time and round-trippable via + `client.serialize_session_state` / `client.deserialize_session_state`. 
Mirrors the + knobs from + [`TensorlakeSandboxClientOptions`][agents.extensions.sandbox.tensorlake.sandbox.TensorlakeSandboxClientOptions] + needed to reconnect to the same sandbox (via `AsyncSandbox.connect`) or, if it has + expired, to recreate it from a stored snapshot id or hydrate it from a tar archive. + + Attributes: + sandbox_id: The Tensorlake-assigned identifier of the underlying sandbox. + base_envs: Caller-supplied environment variables to merge with manifest envs on + every command. + """ + + type: Literal["tensorlake"] = "tensorlake" + sandbox_id: str + name: str | None = None + image: str | None = None + cpus: float | None = None + memory_mb: int | None = None + timeout_secs: int | None = None + secret_names: tuple[str, ...] = () + base_envs: dict[str, str] = Field(default_factory=dict) + allow_internet_access: bool = True + allow_out: tuple[str, ...] = () + deny_out: tuple[str, ...] = () + allow_unauthenticated_port_access: bool = False + pause_on_exit: bool = False + workspace_persistence: WorkspacePersistenceMode = _WORKSPACE_PERSISTENCE_TAR + checkpoint_mode: CheckpointMode = "filesystem" + checkpoint_timeout_s: float = 300.0 + timeouts: TensorlakeSandboxTimeouts = Field(default_factory=TensorlakeSandboxTimeouts) + disk_mb: int | None = None + entrypoint: tuple[str, ...] = () + startup_timeout: float | None = None + proxy_url: str | None = None + api_url: str | None = None + namespace: str | None = None + organization_id: str | None = None + project_id: str | None = None + routing_hint: str | None = None + checkpoint_wait_until: CheckpointWaitUntil = "local_ready" + + @classmethod + def from_options( + cls, + options: TensorlakeSandboxClientOptions, + *, + session_id: uuid.UUID, + manifest: Manifest, + snapshot: SnapshotBase, + sandbox_id: str, + name: str | None, + timeouts: TensorlakeSandboxTimeouts, + ) -> TensorlakeSandboxSessionState: + """Build a session state from the create-time options plus derived values. 
+ + Carry-over fields are pulled by name from `options`; the explicit keyword + arguments override fields that have no `options` counterpart (`session_id`, + `manifest`, `snapshot`, `sandbox_id`, `base_envs`) or that the caller derives + separately (`name`, `timeouts`). + """ + # `name` is resolved by `_resolve_lifecycle_sandbox_name`, and `timeouts` is + # the validated `TensorlakeSandboxTimeouts` (options.timeouts may be a dict). + # `envs` is renamed to `base_envs` and copied so the state owns the dict. + overrides = {"type", "name", "timeouts"} + carry = { + f: getattr(options, f) + for f in cls.model_fields + if f in type(options).model_fields and f not in overrides + } + return cls( + **carry, + session_id=session_id, + manifest=manifest, + snapshot=snapshot, + sandbox_id=sandbox_id, + name=name, + base_envs=dict(options.envs or {}), + timeouts=timeouts, + ) + + +def _resolve_lifecycle_sandbox_name( + *, + name: str | None, + pause_on_exit: bool, + session_id: uuid.UUID, +) -> str | None: + if name is not None and name.strip(): + return name + if pause_on_exit: + return f"{_GENERATED_TENSORLAKE_NAME_PREFIX}{session_id.hex}" + return name + + +# Scalar create-kwargs are emitted when their value is not None; tuple create-kwargs are +# emitted as a list when non-empty. `routing_hint` is accepted by `AsyncSandbox.connect` +# but not `AsyncSandbox.create`. +_CREATE_SCALAR_FIELDS: tuple[str, ...] = ( + "image", + "cpus", + "memory_mb", + "disk_mb", + "timeout_secs", + "name", + "startup_timeout", + "proxy_url", + "api_url", + "namespace", + "organization_id", + "project_id", +) +_CREATE_LIST_FIELDS: tuple[str, ...] = ( + "secret_names", + "allow_out", + "deny_out", + "entrypoint", +) +_CONNECT_FIELDS: tuple[str, ...] 
= ( + "proxy_url", + "api_url", + "namespace", + "organization_id", + "project_id", + "routing_hint", +) + + +@dataclass(frozen=True, kw_only=True, slots=True) +class _TensorlakeLifecycleConfig: + """Normalized lifecycle config shared by `AsyncSandbox.create` and `connect`. + + Private to this module. Built once from either `TensorlakeSandboxClientOptions` + (at create time) or `TensorlakeSandboxSessionState` (on resume/restore) so the + kwargs derivation lives in one place — adding a new Tensorlake option then + becomes a single change here plus the public option/state classes. + """ + + image: str | None = None + cpus: float | None = None + memory_mb: int | None = None + disk_mb: int | None = None + timeout_secs: int | None = None + name: str | None = None + secret_names: tuple[str, ...] = () + allow_internet_access: bool = True + allow_out: tuple[str, ...] = () + deny_out: tuple[str, ...] = () + entrypoint: tuple[str, ...] = () + startup_timeout: float | None = None + proxy_url: str | None = None + api_url: str | None = None + namespace: str | None = None + organization_id: str | None = None + project_id: str | None = None + routing_hint: str | None = None + + @classmethod + def from_options( + cls, + options: TensorlakeSandboxClientOptions, + *, + name: str | None, + ) -> _TensorlakeLifecycleConfig: + # `name` is the *resolved* sandbox name from `_resolve_lifecycle_sandbox_name`, + # which differs from the raw `options.name` (e.g. when `pause_on_exit` forces a + # generated name), so override that single field after pulling the rest. 
+ attrs = {f.name: getattr(options, f.name) for f in fields(cls)} + attrs["name"] = name + return cls(**attrs) + + @classmethod + def from_state( + cls, + state: TensorlakeSandboxSessionState, + ) -> _TensorlakeLifecycleConfig: + return cls(**{f.name: getattr(state, f.name) for f in fields(cls)}) + + +# Tensorlake memory snapshots restore image, resources, entrypoint, and secrets from +# the snapshot itself; passing them at restore time is rejected by the backend. The +# docs say: "Image, resources (CPUs, memory), entrypoint, and secrets come from the +# snapshot and cannot be changed at restore time." +# See https://docs.tensorlake.ai/sandboxes/snapshots. +_MEMORY_SNAPSHOT_RESTORE_EXCLUDED_SCALARS: frozenset[str] = frozenset( + {"image", "cpus", "memory_mb", "disk_mb"} +) +_MEMORY_SNAPSHOT_RESTORE_EXCLUDED_LISTS: frozenset[str] = frozenset({"entrypoint", "secret_names"}) + + +def _create_kwargs( + cfg: _TensorlakeLifecycleConfig, + *, + snapshot_id: str | None = None, + memory_snapshot: bool = False, +) -> dict[str, object]: + """Derive the kwargs accepted by `AsyncSandbox.create(...)` from a lifecycle config. + + Only includes optional fields when they are set so the SDK can apply its own defaults. + Tensorlake does not accept environment variables at sandbox-create time; envs are passed + on each `sandbox.run(...)` call instead. + + When restoring from a memory snapshot, image/resources/entrypoint/secrets must be + omitted because the backend always sources them from the snapshot. 
async def _resolve_sandbox_id(sandbox: Any) -> str | None:
    """Return the sandbox's id, asking the SDK for `info()` once if it is not set yet.

    Reading `sandbox.sandbox_id` may raise `SandboxError` while the id is still
    unpopulated. In that case (or when the value is not a non-empty string) `info()`
    is awaited once, with `SandboxError` suppressed, and the id is read a second time.

    Returns:
        The id as a non-empty string, or `None` if both reads fail to produce one.
    """

    def _read_id() -> str | None:
        # One guarded read of the property; anything but a non-empty str is a miss.
        try:
            candidate = sandbox.sandbox_id
        except SandboxError:
            return None
        if isinstance(candidate, str) and candidate:
            return candidate
        return None

    found = _read_id()
    if found is not None:
        return found
    # First read missed: let the SDK refresh its cached info, then try once more.
    with suppress(SandboxError):
        await sandbox.info()
    return _read_id()
async def _run_mkdir(
    self,
    argv: list[str],
    *,
    user: str | User | None = None,
) -> Any:
    """Run `mkdir argv` via the SDK with resolved envs and the fast-op timeout.

    Args:
        argv: Arguments passed to `mkdir` (flags, `--`, and the target path).
        user: When given, the command is wrapped as `sudo -u <user> -- mkdir ...` so
            the directory is created as that sandbox-local user. A `User` contributes
            its `name`; a string is used as-is.

    Returns:
        Whatever `self._sandbox.run(...)` returns. The caller is responsible for
        inspecting `exit_code` and for wrapping SDK exceptions in its own error type.
    """
    envs = await self._resolved_envs()
    if user is None:
        command = "mkdir"
        args = argv
    else:
        user_name = user.name if isinstance(user, User) else user
        command = "sudo"
        args = ["-u", user_name, "--", "mkdir", *argv]
    # The SDK takes whole seconds. Round up with a 1s floor (as `_exec_internal`
    # does) so a fractional budget is never truncated to `timeout=0`.
    timeout = max(1, math.ceil(self.state.timeouts.fast_op_s))
    return await self._sandbox.run(
        command,
        args,
        env=envs or None,
        timeout=timeout,
    )
async def _shutdown_backend(self) -> None:
    """Suspend or terminate the remote sandbox, depending on `pause_on_exit`.

    Does nothing when no sandbox handle is attached. On success the lifecycle is
    marked finalized and the local handle is released. A failed suspend falls back
    to terminate. Any final failure is logged, not raised, and the handle is left
    attached so a later `client.delete()` can retry.
    """
    handle = self._sandbox
    if handle is None:
        return

    if not self.state.pause_on_exit:
        try:
            await handle.terminate()
            self._backend_lifecycle_finalized = True
            # `terminate()` already closed the local client; only drop the reference.
            self._sandbox = None
        except Exception as exc:
            logger.warning(
                "Failed to terminate Tensorlake sandbox on shutdown.",
                extra={"sandbox_id": self.state.sandbox_id},
                exc_info=exc,
            )
            # Handle stays attached for a retry from `client.delete()`.
        return

    try:
        await handle.suspend()
        self._backend_lifecycle_finalized = True
        # `suspend()` leaves the local client open, so release it explicitly.
        self._close_sandbox_handle()
    except Exception as exc:
        logger.warning(
            "Failed to suspend Tensorlake sandbox on shutdown; falling back to terminate.",
            extra={"sandbox_id": self.state.sandbox_id},
            exc_info=exc,
        )
        try:
            await handle.terminate()
            self._backend_lifecycle_finalized = True
            self._sandbox = None
        except Exception as term_exc:
            logger.warning(
                "Failed to terminate Tensorlake sandbox after suspend fallback failure.",
                extra={"sandbox_id": self.state.sandbox_id},
                exc_info=term_exc,
            )
            # Handle stays attached for a retry from `client.delete()`.
async def running(self) -> bool:
    """Report whether the workspace is ready and the backend says the sandbox is running.

    Returns `False` when the workspace root is not marked ready, when no sandbox
    handle is attached, or when the `status()` call raises.
    """
    handle = self._sandbox if self.state.workspace_root_ready else None
    if handle is None:
        return False
    try:
        current = await handle.status()
    except Exception:
        # A failed status probe is reported as "not running" rather than raised.
        return False
    else:
        return bool(current == SandboxStatus.RUNNING)
async def _get_proxy_hostname(self) -> str | None:
    """Return the sandbox's proxy hostname from `info().sandbox_url`, with caching.

    The answer is cached once resolved. For a custom control plane (`proxy_url` or
    `api_url` set) a missing URL triggers one `status()` + `info()` refresh. If the
    hostname is still unknown, the lookup stays marked unresolved so a later call
    retries, and a warning is logged.

    Returns:
        The non-loopback hostname parsed from the sandbox URL, or `None`.
    """
    if self._proxy_hostname_resolved:
        return self._cached_proxy_hostname

    def _url_of(details: Any) -> Any:
        # `info()` may have failed (None) or returned an object without the attribute.
        return None if details is None else getattr(details, "sandbox_url", None)

    needs_real_host = not (self.state.proxy_url is None and self.state.api_url is None)

    try:
        details = await self._sandbox.info()
    except Exception:
        details = None
    url = _url_of(details)

    if needs_real_host and not url:
        # Some create paths cache minimal info without `sandbox_url`; `status()`
        # performs a fresh lifecycle read before `info()` is asked again.
        with suppress(Exception):
            await self._sandbox.status()
            details = await self._sandbox.info()
            url = _url_of(details)

    resolved: str | None = None
    if isinstance(url, str) and url:
        candidate = urlsplit(url).hostname
        if candidate and candidate not in _LOOPBACK_HOSTS:
            resolved = candidate

    self._cached_proxy_hostname = resolved
    # Only a found hostname, or a miss on the default control plane, is final.
    # A miss on a custom control plane stays unresolved so the next call retries.
    self._proxy_hostname_resolved = resolved is not None or not needs_real_host
    if needs_real_host and resolved is None:
        logger.warning(
            "Could not resolve Tensorlake sandbox URL from info(); falling back to the "
            "public exposed-port template, which will not route correctly for this "
            "custom proxy_url/api_url deployment. Will retry on the next lookup.",
            extra={
                "sandbox_id": self.state.sandbox_id,
                "proxy_url": self.state.proxy_url,
                "api_url": self.state.api_url,
            },
        )
    return resolved
async def _persist_workspace_via_checkpoint(self) -> io.IOBase:
    """Persist using Tensorlake's native sandbox checkpoint API.

    Falls back to tar when the backend declines or when path-level skips are required —
    Tensorlake checkpoints capture the whole sandbox and have no path-level excludes.

    Returns:
        A `BytesIO` holding the encoded snapshot reference on the checkpoint path, or
        whatever `_persist_workspace_via_tar()` returns on either fallback.

    Raises:
        WorkspaceArchiveReadError: If the SDK `checkpoint(...)` call raises, or if its
            result does not carry a non-empty string `snapshot_id`.
    """

    root = self._workspace_root_path()
    error_root = posix_path_for_error(root)

    # First fallback: the session itself reports that a native snapshot is unsuitable.
    if self._native_snapshot_requires_tar_fallback():
        return await self._persist_workspace_via_tar()

    skip = self._persist_workspace_skip_relpaths()
    mount_targets = self.state.manifest.ephemeral_mount_targets()
    mount_skip_rel_paths: set[Path] = set()
    for _, mount_path in mount_targets:
        try:
            mount_skip_rel_paths.add(mount_path.relative_to(root))
        except ValueError:
            # Mount target lies outside the workspace root; it cannot be a skip path.
            continue
    # Second fallback: some skip path is not an ephemeral mount target, so it needs a
    # path-level exclude, which only the tar path can provide.
    if skip - mount_skip_rel_paths:
        return await self._persist_workspace_via_tar()

    checkpoint_type = (
        CheckpointType.MEMORY
        if self.state.checkpoint_mode == "memory"
        else CheckpointType.FILESYSTEM
    )

    # Rely on the SDK's own `timeout` so the backend tears down the operation;
    # an outer `asyncio.wait_for` would only cancel the local awaiter. The
    # `wait_until` knob defaults to `"local_ready"` (Tensorlake SDK default) — that
    # is sufficient for `AsyncSandbox.create(snapshot_id=...)` restore and avoids
    # blocking on remote-storage upload. Set `checkpoint_wait_until="completed"`
    # only when a durable `snapshot_uri` is required.
    try:
        snapshot = await self._sandbox.checkpoint(
            checkpoint_type=checkpoint_type,
            timeout=int(self.state.checkpoint_timeout_s),
            wait_until=self.state.checkpoint_wait_until,
        )
    except Exception as exc:
        raise WorkspaceArchiveReadError(
            path=error_root,
            context={"reason": "tensorlake_checkpoint_failed"},
            cause=exc,
        ) from exc

    # The archive payload is only a reference to the snapshot, so a usable id is required.
    snapshot_id = getattr(snapshot, "snapshot_id", None)
    if not isinstance(snapshot_id, str) or not snapshot_id:
        raise WorkspaceArchiveReadError(
            path=error_root,
            context={
                "reason": "tensorlake_checkpoint_unexpected_return",
                "type": type(snapshot).__name__,
            },
        )
    return io.BytesIO(_encode_tensorlake_snapshot_ref(snapshot_id=snapshot_id))
@retry_async(
    retry_if=lambda exc, *_args, **_kwargs: exception_chain_has_status_code(
        exc, TRANSIENT_HTTP_STATUS_CODES
    )
)
async def _run_persist_workspace_command(self, tar_argv: list[str], archive_path: str) -> bytes:
    """Run `tar` in the sandbox and return the bytes of the archive it wrote.

    Args:
        tar_argv: Arguments passed to `tar`.
        archive_path: Sandbox path of the archive to read back after `tar` succeeds.

    Returns:
        The archive contents read from `archive_path`.

    Raises:
        ExecNonZeroError: If `tar` exits with a non-zero status.
    """
    envs = await self._resolved_envs()
    outcome = await self._sandbox.run(
        "tar",
        tar_argv,
        env=envs or None,
        timeout=int(self.state.timeouts.snapshot_tar_s),
    )
    status = int(getattr(outcome, "exit_code", 0) or 0)
    if status == 0:
        archive = await self._sandbox.read_file(archive_path)
        return _unwrap_traced_bytes(archive)

    def _stream(attr: str) -> bytes:
        # Missing or empty output becomes b""; undecodable text is replaced, not raised.
        return str(getattr(outcome, attr, "") or "").encode("utf-8", errors="replace")

    raise ExecNonZeroError(
        ExecResult(stdout=_stream("stdout"), stderr=_stream("stderr"), exit_code=status),
        command=("tar", *tar_argv),
        context={"backend": "tensorlake", "sandbox_id": self.state.sandbox_id},
    )
async def _hydrate_workspace_internal(self, raw: bytes) -> None:
    """Hydrate from `raw`, which is either a snapshot reference or a tar archive.

    If `raw` decodes as a Tensorlake snapshot reference, the sandbox is restored from
    that checkpoint; otherwise `raw` is treated as tar bytes.
    """
    ref = _decode_tensorlake_snapshot_ref(raw)
    if ref is None:
        await self._hydrate_workspace_via_tar(raw)
    else:
        await self._restore_from_checkpoint(ref)
async def _hydrate_workspace_via_tar(self, raw: bytes) -> None:
    """Validate a tar archive and extract it into the workspace root.

    The archive is checked locally, written to a temporary `/tmp` path in the
    sandbox, extracted with `tar xf ... -C <root>`, and the temporary file is removed
    best-effort whether or not extraction succeeded.

    Args:
        raw: Tar archive bytes.

    Raises:
        WorkspaceArchiveWriteError: If the archive fails validation, if preparing the
            workspace, uploading, or running `tar` raises, or if `tar` exits non-zero.
    """
    root = self._workspace_root_path()
    error_root = posix_path_for_error(root)

    try:
        validate_tar_bytes(raw, allow_external_symlink_targets=False)
    except UnsafeTarMemberError as exc:
        raise WorkspaceArchiveWriteError(
            path=error_root,
            context={
                "reason": "unsafe_or_invalid_tar",
                "member": exc.member,
                "detail": str(exc),
            },
            cause=exc,
        ) from exc

    archive_path = f"/tmp/openai-agents-hydrate-{self.state.session_id.hex}.tar"

    try:
        await self._prepare_backend_workspace()
        await self._sandbox.write_file(archive_path, raw)
        envs = await self._resolved_envs()
        result = await self._sandbox.run(
            "tar",
            ["xf", archive_path, "-C", root.as_posix()],
            env=envs or None,
            # Whole seconds, rounded up with a 1s floor (as `_exec_internal` does) so
            # a fractional budget is never truncated to `timeout=0`.
            timeout=max(1, math.ceil(self.state.timeouts.snapshot_tar_s)),
        )
    except Exception as exc:
        # Covers `WorkspaceStartError` from workspace preparation too: every failure
        # in this phase is reported the same way.
        raise WorkspaceArchiveWriteError(path=error_root, cause=exc) from exc
    finally:
        await self._remove_tmp_archive(archive_path)

    exit_code = int(getattr(result, "exit_code", 0) or 0)
    if exit_code != 0:
        raise WorkspaceArchiveWriteError(
            path=error_root,
            context={
                "reason": "hydrate_nonzero_exit",
                "exit_code": exit_code,
                "stderr": str(getattr(result, "stderr", "") or ""),
            },
        )
    self.state.workspace_root_ready = True
async def create(
    self,
    *,
    snapshot: SnapshotSpec | SnapshotBase | None = None,
    manifest: Manifest | None = None,
    options: TensorlakeSandboxClientOptions,
) -> SandboxSession:
    """Create a Tensorlake sandbox and wrap it in a session.

    Args:
        snapshot: Snapshot spec or instance, resolved against the new session id.
        manifest: Workspace manifest. When omitted, or when it still carries the
            cross-provider default root, the Tensorlake default workspace root is used.
        options: Create-time options for the sandbox.

    Returns:
        The wrapped session for the newly created sandbox.

    Raises:
        ValueError: If `workspace_persistence` is not a supported mode, or if
            `timeout_secs` is positive but not strictly greater than
            `checkpoint_timeout_s` in snapshot mode.
        RuntimeError: If the created sandbox does not expose a `sandbox_id`.

    Any failure after the backend sandbox has been created terminates that sandbox
    (best-effort) before the exception propagates, so a failed `create` does not
    leave a remote sandbox running.
    """
    if manifest is None:
        manifest = Manifest(root=DEFAULT_TENSORLAKE_WORKSPACE_ROOT)
    elif manifest.root == _DEFAULT_MANIFEST_ROOT:
        # The default Tensorlake image runs as `tl-user`, which cannot write to the
        # cross-provider default `/workspace`. Rewrite manifests that still carry that
        # default so common construction patterns like `Manifest(entries=...)` work
        # against this backend without callers having to know the writable path.
        manifest = manifest.model_copy(
            update={"root": DEFAULT_TENSORLAKE_WORKSPACE_ROOT}, deep=True
        )

    timeouts_in = options.timeouts
    if isinstance(timeouts_in, TensorlakeSandboxTimeouts):
        timeouts = timeouts_in
    elif timeouts_in is None:
        timeouts = TensorlakeSandboxTimeouts()
    else:
        timeouts = TensorlakeSandboxTimeouts.model_validate(timeouts_in)

    if options.workspace_persistence not in (
        _WORKSPACE_PERSISTENCE_TAR,
        _WORKSPACE_PERSISTENCE_SNAPSHOT,
    ):
        raise ValueError(
            "TensorlakeSandboxClient.create requires workspace_persistence to be one of "
            f"{_WORKSPACE_PERSISTENCE_TAR!r} or {_WORKSPACE_PERSISTENCE_SNAPSHOT!r}"
        )

    # `timeout_secs` is an *idle threshold* on sandbox-proxy traffic, not a
    # wall-clock lifetime. `checkpoint()` polling goes through Tensorlake's
    # lifecycle/control-plane client rather than the sandbox proxy, so no proxied
    # traffic flows while a checkpoint is in flight — if the idle threshold is
    # smaller than the checkpoint poll budget, the sandbox can idle-time out
    # mid-poll and orphan the snapshot. Require the idle threshold to exceed the
    # poll budget so the snapshot can settle. `timeout_secs=0` requests the plan
    # maximum (≥1h on every Tensorlake plan, far larger than any reasonable
    # `checkpoint_timeout_s`), so it is exempt alongside `None`.
    if (
        options.workspace_persistence == _WORKSPACE_PERSISTENCE_SNAPSHOT
        and options.timeout_secs is not None
        and options.timeout_secs > 0
        and options.timeout_secs <= options.checkpoint_timeout_s
    ):
        raise ValueError(
            "timeout_secs must be strictly greater than checkpoint_timeout_s when "
            "workspace_persistence='snapshot'; otherwise the sandbox can be "
            "auto-terminated during checkpoint polling, orphaning the snapshot. "
            f"Got timeout_secs={options.timeout_secs}, "
            f"checkpoint_timeout_s={options.checkpoint_timeout_s}."
        )

    session_id = uuid.uuid4()
    sandbox_name = _resolve_lifecycle_sandbox_name(
        name=options.name,
        pause_on_exit=options.pause_on_exit,
        session_id=session_id,
    )

    kwargs = _create_kwargs(_TensorlakeLifecycleConfig.from_options(options, name=sandbox_name))

    sandbox = await AsyncSandbox.create(**kwargs)
    # From here on a remote sandbox exists. Every remaining step runs under one guard
    # so that an id lookup, snapshot resolution, state validation, or port update
    # failure terminates it instead of leaking it.
    try:
        sandbox_id = await _resolve_sandbox_id(sandbox)
        if not sandbox_id:
            raise RuntimeError(
                "Tensorlake `AsyncSandbox.create` did not return a sandbox with a `sandbox_id`."
            )

        snapshot_instance = resolve_snapshot(snapshot, str(session_id))
        state = TensorlakeSandboxSessionState.from_options(
            options,
            session_id=session_id,
            manifest=manifest,
            snapshot=snapshot_instance,
            sandbox_id=sandbox_id,
            name=sandbox_name,
            timeouts=timeouts,
        )
        inner = TensorlakeSandboxSession.from_state(state, sandbox=sandbox)
        await inner._apply_exposed_ports()
    except Exception:
        with suppress(Exception):
            await sandbox.terminate()
        raise
    return self._wrap_session(inner, instrumentation=self._instrumentation)
+ inner._close_sandbox_handle() + return session + try: + await inner._sandbox.terminate() + inner._backend_lifecycle_finalized = True + # `terminate()` already closed the local Rust client; drop the reference + # directly to avoid a redundant close. + inner._sandbox = None + except Exception: + # Terminate failed; this is the final cleanup hop, so free the local handle + # even though the remote may still be running. + inner._close_sandbox_handle() + return session + + async def resume( + self, + state: SandboxSessionState, + ) -> SandboxSession: + if not isinstance(state, TensorlakeSandboxSessionState): + raise TypeError( + "TensorlakeSandboxClient.resume expects a TensorlakeSandboxSessionState" + ) + + cfg = _TensorlakeLifecycleConfig.from_state(state) + connect_kwargs = _connect_kwargs(cfg) + + sandbox: Any = None + reconnected = False + try: + sandbox = await AsyncSandbox.connect(state.sandbox_id, **connect_kwargs) + if state.pause_on_exit: + # `connect` returns a handle even for a paused/expired sandbox; `resume` is + # what actually transitions it to running. Failures must fall through so the + # outer handler recreates rather than marking a dead backend as preserved. + await sandbox.resume() + status = await sandbox.status() + if status != SandboxStatus.RUNNING: + raise RuntimeError("tensorlake sandbox is not running") + reconnected = True + except Exception: + if sandbox is not None: + if state.pause_on_exit: + # The user opted into suspend lifecycle and expects the backend to + # preserve workspace state across resume. A probe failure after a + # successful `connect()` is ambiguous — it can be a transient blip + # against a still-suspended sandbox — so terminating here would + # destroy potentially recoverable state (especially with a Noop or + # stale snapshot). Drop the local handle only; the remote either + # reconnects on a later attempt or auto-expires via `timeout_secs`. 
+ with suppress(Exception): + sandbox.close() + else: + # Without suspend lifecycle, no cross-resume state is expected. + # Terminate so the abandoned remote sandbox doesn't linger on the + # backend until its own timeout expires. + with suppress(Exception): + await sandbox.terminate() + sandbox = None + + recreate_snapshot_id: str | None = None + if sandbox is None: + if state.workspace_persistence == _WORKSPACE_PERSISTENCE_SNAPSHOT: + # Skip the throwaway empty sandbox that `hydrate_workspace` would otherwise + # terminate and replace from the same snapshot. + recreate_snapshot_id = await _restore_tensorlake_snapshot_reference_id( + state.snapshot + ) + sandbox_name = _resolve_lifecycle_sandbox_name( + name=state.name, + pause_on_exit=state.pause_on_exit, + session_id=state.session_id, + ) + recreate_cfg = replace(cfg, name=sandbox_name) + kwargs = _create_kwargs( + recreate_cfg, + snapshot_id=recreate_snapshot_id, + memory_snapshot=state.checkpoint_mode == "memory", + ) + sandbox = await AsyncSandbox.create(**kwargs) + new_id = await _resolve_sandbox_id(sandbox) + if new_id is not None: + state.sandbox_id = new_id + state.name = sandbox_name + state.workspace_root_ready = recreate_snapshot_id is not None + + inner = TensorlakeSandboxSession.from_state(state, sandbox=sandbox) + preserved = reconnected or recreate_snapshot_id is not None + inner._set_start_state_preserved(preserved, system=preserved) + try: + await inner._apply_exposed_ports() + except Exception: + if not reconnected: + with suppress(Exception): + await sandbox.terminate() + raise + return self._wrap_session(inner, instrumentation=self._instrumentation) + + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + return cast(SandboxSessionState, TensorlakeSandboxSessionState.model_validate(payload)) + + +__all__ = [ + "DEFAULT_TENSORLAKE_WORKSPACE_ROOT", + "TensorlakeSandboxClient", + "TensorlakeSandboxClientOptions", + "TensorlakeSandboxSession", + 
"TensorlakeSandboxSessionState", + "TensorlakeSandboxTimeouts", +] diff --git a/tests/extensions/sandbox/test_tensorlake.py b/tests/extensions/sandbox/test_tensorlake.py new file mode 100644 index 0000000000..b9580d9b75 --- /dev/null +++ b/tests/extensions/sandbox/test_tensorlake.py @@ -0,0 +1,2039 @@ +from __future__ import annotations + +import importlib +import io +import sys +import tarfile +import types +import uuid +from enum import Enum +from pathlib import Path +from typing import Any, cast + +import pytest + +from agents.sandbox import Manifest +from agents.sandbox.entries import File +from agents.sandbox.snapshot import LocalSnapshot, NoopSnapshot +from tests._fake_workspace_paths import resolve_fake_workspace_path + + +class _FakeCommandResult: + def __init__(self, *, stdout: str = "", stderr: str = "", exit_code: int = 0) -> None: + self.stdout = stdout + self.stderr = stderr + self.exit_code = exit_code + + +class _FakeSnapshotInfo: + def __init__(self, snapshot_id: str) -> None: + self.snapshot_id = snapshot_id + + +class _FakeCheckpointType: + # Mirror the real `CheckpointType` str-Enum shape (members expose `.value`) so the + # integration's `_resolve_checkpoint_type(...).value` path is exercised by the fake. + class _Member: + def __init__(self, value: str) -> None: + self.value = value + + FILESYSTEM = _Member("filesystem") + MEMORY = _Member("memory") + + +class _FakeSandboxStatus(str, Enum): + # Real `SandboxStatus` is a str-Enum, so members compare equal both to each other and + # to their raw string value. Mirror that here so `status == SandboxStatus.RUNNING` + # works against the fake. 
class _FakeSandbox:
    """Async fake mirroring the Tensorlake `AsyncSandbox` surface used by the integration.

    Class-level attributes act as a tiny in-memory "backend" shared by every instance:
    they record `create`/`connect` calls, hold the sandboxes and snapshots created so
    far, and carry queued failures. `reset()` restores them to empty values. Instance
    attributes record what happened to one sandbox (calls, counters, status, files).
    """

    # Keyword arguments of every `create(...)` call, in call order.
    create_calls: list[dict[str, object]] = []
    # `{"sandbox_id": ..., **kwargs}` for every `connect(...)` call, in call order.
    connect_calls: list[dict[str, object]] = []
    # Sandboxes registered by `create`, keyed by sandbox id.
    sandboxes: dict[str, _FakeSandbox] = {}
    # Snapshot id -> copy of the sandbox's `files` taken by `checkpoint`.
    snapshots: dict[str, dict[str, bytes]] = {}
    # Counter used to mint `tensorlake-sandbox-<n>` ids.
    next_sandbox_index: int = 0
    # Exceptions raised (FIFO) by upcoming `create` calls.
    create_failures: list[BaseException] = []
    # Sandbox id -> exception raised by `connect` for that id.
    connect_failures: dict[str, BaseException] = {}
    # Exceptions raised (FIFO) by upcoming `update` calls on any instance.
    update_failures: list[BaseException] = []

    def __init__(
        self,
        *,
        sandbox_id: str,
        name: str | None = None,
        status: str = "running",
        files: dict[str, bytes] | None = None,
        sandbox_url: str | None = None,
        status_refresh_sandbox_url: str | None = None,
    ) -> None:
        """Build a fake sandbox; `files` is copied and `status` is coerced to the enum."""
        self.sandbox_id = sandbox_id
        self.name = name
        self._status = _FakeSandboxStatus(status)
        self.files: dict[str, bytes] = dict(files or {})
        self.run_calls: list[dict[str, object]] = []
        self.update_calls: list[dict[str, object]] = []
        self.terminated = False
        self.terminate_count = 0
        self.closed = False
        self.close_count = 0
        self.suspended = False
        self.resumed = False
        # When set, `resume()` / `update()` raise these instead of succeeding.
        self.resume_failure: BaseException | None = None
        self.update_failure: BaseException | None = None
        # One-shot canned result returned by the next `run()` that reaches it.
        self.next_run_result: _FakeCommandResult | None = None
        self.symlinks: dict[str, str] = {}
        self.sandbox_url = sandbox_url
        # When set, `status()` overwrites `sandbox_url` with this value.
        self.status_refresh_sandbox_url = status_refresh_sandbox_url
        self.info_calls = 0
        self.status_calls = 0
        self.last_checkpoint_wait_until: str | None = None

    @classmethod
    def reset(cls) -> None:
        """Rebind every class-level registry to a fresh empty value."""
        cls.create_calls = []
        cls.connect_calls = []
        cls.sandboxes = {}
        cls.snapshots = {}
        cls.next_sandbox_index = 0
        cls.create_failures = []
        cls.connect_failures = {}
        cls.update_failures = []

    @classmethod
    async def create(cls, **kwargs: object) -> _FakeSandbox:
        """Record the call, then register and return a new running sandbox.

        The call is recorded before a queued failure is raised. When `snapshot_id`
        names a known snapshot, the new sandbox starts with a copy of its files.
        """
        cls.create_calls.append(dict(kwargs))
        if cls.create_failures:
            raise cls.create_failures.pop(0)
        cls.next_sandbox_index += 1
        sandbox_id = f"tensorlake-sandbox-{cls.next_sandbox_index}"
        files: dict[str, bytes] = {}
        snapshot_id = kwargs.get("snapshot_id")
        if isinstance(snapshot_id, str) and snapshot_id in cls.snapshots:
            files = dict(cls.snapshots[snapshot_id])
        sandbox = cls(
            sandbox_id=sandbox_id,
            name=cast(str | None, kwargs.get("name")),
            files=files,
        )
        cls.sandboxes[sandbox_id] = sandbox
        return sandbox

    @classmethod
    async def connect(cls, sandbox_id: str, **kwargs: object) -> _FakeSandbox:
        """Record the call and return the registered sandbox for `sandbox_id`.

        Raises the configured `connect_failures` entry if present, or `RuntimeError`
        when no sandbox with that id is registered.
        """
        cls.connect_calls.append({"sandbox_id": sandbox_id, **kwargs})
        if sandbox_id in cls.connect_failures:
            raise cls.connect_failures[sandbox_id]
        sandbox = cls.sandboxes.get(sandbox_id)
        if sandbox is None:
            raise RuntimeError(f"sandbox {sandbox_id} not found")
        return sandbox

    async def status(self) -> Any:
        """Count the call, optionally refresh `sandbox_url`, and return the status."""
        self.status_calls += 1
        if self.status_refresh_sandbox_url is not None:
            self.sandbox_url = self.status_refresh_sandbox_url
        return self._status

    async def info(self) -> _FakeSandboxInfo:
        """Count the call and return info carrying the current `sandbox_url`."""
        self.info_calls += 1
        return _FakeSandboxInfo(sandbox_url=self.sandbox_url)

    async def run(
        self,
        command: str,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        working_dir: str | None = None,
        timeout: float | None = None,
    ) -> _FakeTraced:
        """Record the command and emulate the handful of commands the tests rely on.

        Dispatch order matters: workspace-path resolution first, then the one-shot
        `next_run_result`, then `mkdir`, `tar cf`, `tar xf` and `test -d`; anything
        else returns a default (successful) result. `env` and `timeout` are ignored.
        """
        _ = (env, timeout)
        args = args or []
        self.run_calls.append(
            {
                "command": command,
                "args": list(args),
                "working_dir": working_dir,
            }
        )

        # Shared test helper; a non-None result short-circuits everything below.
        resolved = resolve_fake_workspace_path(
            (command, *args), symlinks=self.symlinks, home_dir="/workspace"
        )
        if resolved is not None:
            return _FakeTraced(
                _FakeCommandResult(
                    exit_code=resolved.exit_code,
                    stdout=resolved.stdout,
                    stderr=resolved.stderr,
                )
            )

        # Canned result is consumed exactly once.
        if self.next_run_result is not None:
            result = self.next_run_result
            self.next_run_result = None
            return _FakeTraced(result)

        if command == "mkdir":
            return _FakeTraced(_FakeCommandResult())

        cwd = working_dir or "/workspace"

        if command == "tar" and args and args[0] == "cf":
            # Build an in-memory tar of every file under the `-C` root, honouring
            # `--exclude=./<path>` arguments, and store it at the archive path.
            archive_path = args[1]
            assert "-C" in args
            tar_root = args[args.index("-C") + 1]
            # A trailing "." argument selects root-relative member names; otherwise
            # members keep their absolute path.
            include_dot = args[-1] == "."
            exclusions = {
                arg.removeprefix("--exclude=./") for arg in args if arg.startswith("--exclude=./")
            }
            buffer = io.BytesIO()
            with tarfile.open(fileobj=buffer, mode="w") as tar:
                for path, content in sorted(self.files.items()):
                    if not path.startswith(tar_root.rstrip("/") + "/"):
                        continue
                    rel_path = path[len(tar_root.rstrip("/")) + 1 :]
                    if any(rel_path == ex or rel_path.startswith(f"{ex}/") for ex in exclusions):
                        continue
                    info = tarfile.TarInfo(name=rel_path if include_dot else path)
                    info.size = len(content)
                    tar.addfile(info, io.BytesIO(content))
            self.files[archive_path] = buffer.getvalue()
            return _FakeTraced(_FakeCommandResult())

        if command == "tar" and args and args[0] == "xf":
            # Extract regular-file members of the stored archive under the `-C` target.
            archive_path = args[1]
            destination = args[args.index("-C") + 1]
            raw = self.files[archive_path]
            with tarfile.open(fileobj=io.BytesIO(raw), mode="r") as tar:
                for member in tar.getmembers():
                    if not member.isfile():
                        continue
                    extracted = tar.extractfile(member)
                    assert extracted is not None
                    self.files[f"{destination.rstrip('/')}/{member.name}"] = extracted.read()
            return _FakeTraced(_FakeCommandResult())

        if command == "test" and args and args[0] == "-d":
            return _FakeTraced(_FakeCommandResult(exit_code=0))

        # `cwd` is computed but deliberately unused by the fake.
        _ = cwd
        return _FakeTraced(_FakeCommandResult())

    async def read_file(self, path: str) -> _FakeTraced:
        """Return the stored bytes for `path`, or raise a 404 `_FakeRemoteAPIError`."""
        if path not in self.files:
            raise _FakeRemoteAPIError(404, f"file not found: {path}")
        return _FakeTraced(self.files[path])

    async def write_file(self, path: str, content: bytes) -> _FakeTraced:
        """Store a `bytes` copy of `content` at `path`."""
        self.files[path] = bytes(content)
        return _FakeTraced(None)

    async def delete_file(self, path: str) -> _FakeTraced:
        """Remove `path` if present; missing paths are ignored."""
        self.files.pop(path, None)
        return _FakeTraced(None)

    async def terminate(self) -> None:
        """Mark the sandbox terminated, bump the counter, and close the handle."""
        self.terminated = True
        self.terminate_count += 1
        self._status = _FakeSandboxStatus.TERMINATED
        # The real SDK closes the local Rust client inside `terminate()`; mirror that
        # so leak-coverage tests see the same call count as production.
        self.close()

    def close(self) -> None:
        """Mark the local handle closed and count the call."""
        self.closed = True
        self.close_count += 1

    async def suspend(
        self, wait: bool = True, timeout: float = 300.0, poll_interval: float = 1.0
    ) -> None:
        """Suspend the sandbox; unnamed sandboxes raise a 400 `_FakeRemoteAPIError`."""
        _ = (wait, timeout, poll_interval)
        if not self.name:
            raise _FakeRemoteAPIError(400, "only named sandboxes can be suspended")
        self.suspended = True
        self._status = _FakeSandboxStatus.SUSPENDED

    async def resume(
        self, wait: bool = True, timeout: float = 300.0, poll_interval: float = 1.0
    ) -> None:
        """Raise `resume_failure` if set; otherwise mark the sandbox running again."""
        _ = (wait, timeout, poll_interval)
        if self.resume_failure is not None:
            raise self.resume_failure
        self.resumed = True
        self._status = _FakeSandboxStatus.RUNNING

    async def update(
        self,
        name: str | None = None,
        *,
        allow_unauthenticated_access: bool | None = None,
        exposed_ports: list[int] | None = None,
    ) -> _FakeTraced:
        """Record the update, then raise a configured failure or return sandbox info.

        The per-instance `update_failure` takes precedence over the class-level
        `update_failures` queue; the call is recorded either way.
        """
        self.update_calls.append(
            {
                "name": name,
                "allow_unauthenticated_access": allow_unauthenticated_access,
                "exposed_ports": list(exposed_ports) if exposed_ports is not None else None,
            }
        )
        if self.update_failure is not None:
            raise self.update_failure
        if type(self).update_failures:
            raise type(self).update_failures.pop(0)
        return _FakeTraced(_FakeSandboxInfo(sandbox_url=self.sandbox_url))

    async def checkpoint(
        self,
        wait: bool = True,
        timeout: float = 300.0,
        poll_interval: float = 1.0,
        checkpoint_type: Any = None,
        wait_until: str = "local_ready",
    ) -> _FakeSnapshotInfo:
        """Snapshot the current files under a new `snap-<n>` id and return its info.

        Only `wait_until` is remembered; the other arguments are ignored.
        """
        _ = (wait, timeout, poll_interval, checkpoint_type)
        self.last_checkpoint_wait_until = wait_until
        snapshot_id = f"snap-{len(type(self).snapshots) + 1}"
        type(self).snapshots[snapshot_id] = dict(self.files)
        return _FakeSnapshotInfo(snapshot_id)
@pytest.mark.asyncio
async def test_resolve_sandbox_id_handles_raising_property(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # On the real `AsyncSandbox`, reading `sandbox_id` raises `SandboxError` until
    # `info()` has filled the cache. `_resolve_sandbox_id` has to absorb that error,
    # call `info()` to seed the cache, and retry the read, so no SDK exception escapes
    # from the integration's create/restore paths.
    resolve = _load_tensorlake_module(monkeypatch)._resolve_sandbox_id

    class _NeverReady:
        info_calls = 0

        @property
        def sandbox_id(self) -> str:
            raise _FakeSandboxError("sandbox_id is not yet known; call `info()` first.")

        async def info(self) -> None:
            type(self).info_calls += 1

    never_ready = _NeverReady()
    resolved = await resolve(never_ready)
    assert resolved is None
    # Even when the retry read fails too, `info()` must have been awaited only once.
    assert _NeverReady.info_calls == 1

    class _BlankId:
        sandbox_id = ""

        async def info(self) -> None:
            pass

    resolved = await resolve(_BlankId())
    assert resolved is None

    class _HasId:
        sandbox_id = "sb-123"

        async def info(self) -> None:
            pass

    resolved = await resolve(_HasId())
    assert resolved == "sb-123"

    class _IdAfterInfo:
        # Same shape as the real SDK: the property raises until `info()` has run,
        # and returns the id afterwards.
        def __init__(self) -> None:
            self._id: str | None = None

        @property
        def sandbox_id(self) -> str:
            if self._id is None:
                raise _FakeSandboxError("sandbox_id is not yet known")
            return self._id

        async def info(self) -> None:
            self._id = "sb-late"

    resolved = await resolve(_IdAfterInfo())
    assert resolved == "sb-late"
@pytest.mark.asyncio
async def test_create_generates_name_when_pause_on_exit(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # With `pause_on_exit=True` and no explicit name, `create()` must generate a
    # sandbox name and use it for the backend call, the state and the sandbox.
    tensorlake = _load_tensorlake_module(monkeypatch)
    client = tensorlake.TensorlakeSandboxClient()
    suspend_options = tensorlake.TensorlakeSandboxClientOptions(pause_on_exit=True)

    session = await client.create(manifest=Manifest(), options=suspend_options)

    name_sent_to_backend = _FakeSandbox.create_calls[0]["name"]
    assert isinstance(name_sent_to_backend, str)
    assert name_sent_to_backend.startswith("openai-agents-")

    inner = session._inner
    assert inner.state.name == name_sent_to_backend
    assert inner._sandbox.name == name_sent_to_backend
@pytest.mark.asyncio
async def test_create_rejects_snapshot_with_lifetime_le_checkpoint_timeout(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # With workspace_persistence='snapshot', a `timeout_secs` that does not exceed
    # `checkpoint_timeout_s` lets the sandbox be auto-terminated while the checkpoint
    # is still being polled, orphaning the snapshot. `create()` must refuse that
    # configuration up front.
    module = _load_tensorlake_module(monkeypatch)
    client = module.TensorlakeSandboxClient()

    async def _create(**option_kwargs: Any) -> Any:
        return await client.create(
            manifest=Manifest(),
            options=module.TensorlakeSandboxClientOptions(**option_kwargs),
        )

    with pytest.raises(ValueError, match="timeout_secs must be strictly greater"):
        await _create(
            workspace_persistence="snapshot",
            timeout_secs=300,
            checkpoint_timeout_s=300.0,
        )

    # The same timings are fine under the default tar persistence.
    await _create(timeout_secs=300, checkpoint_timeout_s=300.0)

    # Snapshot persistence is accepted once the lifetime is strictly greater.
    await _create(
        workspace_persistence="snapshot",
        timeout_secs=600,
        checkpoint_timeout_s=300.0,
    )

    # `timeout_secs=0` asks for the plan maximum (≥1h on every Tensorlake plan), which
    # always exceeds any reasonable `checkpoint_timeout_s`; the guard has to accept it
    # even though `0 <= checkpoint_timeout_s` is literally true.
    await _create(
        workspace_persistence="snapshot",
        timeout_secs=0,
        checkpoint_timeout_s=300.0,
    )
@pytest.mark.asyncio
async def test_create_preserves_explicit_manifest_root(monkeypatch: pytest.MonkeyPatch) -> None:
    # Only the default `/workspace` placeholder gets rewritten; an explicitly chosen
    # manifest root has to come through unchanged.
    tensorlake = _load_tensorlake_module(monkeypatch)
    custom_manifest = Manifest(root="/tmp/custom")
    default_options = tensorlake.TensorlakeSandboxClientOptions()

    session = await tensorlake.TensorlakeSandboxClient().create(
        manifest=custom_manifest,
        options=default_options,
    )

    stored_manifest = session._inner.state.manifest
    assert stored_manifest.root == "/tmp/custom"
@pytest.mark.asyncio
async def test_mkdir_user_param_wraps_with_sudo(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # `session.mkdir(..., user=...)` has to execute as the requested sandbox-local
    # user so the new directory gets the right ownership. Without the wrap,
    # Tensorlake would run mkdir as the default `tl-user`, which fails for
    # directories that only the requested user may create.
    tensorlake = _load_tensorlake_module(monkeypatch)
    sandbox_id = "sandbox-mkdir-user"
    fake = _FakeSandbox(sandbox_id=sandbox_id)
    session = tensorlake.TensorlakeSandboxSession.from_state(
        tensorlake.TensorlakeSandboxSessionState(
            session_id=uuid.UUID("00000000-0000-0000-0000-0000000d1ca1"),
            manifest=Manifest(),
            snapshot=NoopSnapshot(id="snap"),
            sandbox_id=sandbox_id,
        ),
        sandbox=fake,
    )

    await session.mkdir(Path("subdir"), parents=True, user="root")

    sudo_mkdir_args = [
        call["args"]
        for call in fake.run_calls
        if call["command"] == "sudo" and "mkdir" in cast(list[str], call["args"])
    ]
    assert sudo_mkdir_args, "expected mkdir to be wrapped with `sudo` when user= is set"
    assert sudo_mkdir_args[-1] == [
        "-u",
        "root",
        "--",
        "mkdir",
        "-p",
        "--",
        "/workspace/subdir",
    ]
+ module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000feed"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-provision", + ) + fake = _FakeSandbox(sandbox_id="sandbox-provision") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + await session._exec_checked_nonzero("groupadd", "researchers") + await session._exec_checked_nonzero("useradd", "-U", "-M", "alice") + await session._exec_checked_nonzero("chmod", "0600", "/tmp/secret") + + sudo_calls = [c for c in fake.run_calls if c["command"] == "sudo"] + expected_args = [ + ["-u", "root", "--", "groupadd", "researchers"], + ["-u", "root", "--", "useradd", "-U", "-M", "alice"], + ["-u", "root", "--", "chmod", "0600", "/tmp/secret"], + ] + assert [c["args"] for c in sudo_calls] == expected_args, ( + f"expected privileged commands to be wrapped with `sudo -u root --`, got {sudo_calls!r}" + ) + + +@pytest.mark.asyncio +async def test_read_missing_file_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000003"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-missing", + ) + fake = _FakeSandbox(sandbox_id="sandbox-missing") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + from agents.sandbox.errors import WorkspaceReadNotFoundError + + with pytest.raises(WorkspaceReadNotFoundError): + await session.read(Path("nope.txt")) + + +@pytest.mark.asyncio +async def test_read_non_404_remote_api_error_raises_archive_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Non-404 `RemoteAPIError` from `read_file` surfaces as `WorkspaceArchiveReadError`. + + Only the 404 path is treated as a missing-file signal; other statuses (e.g. 
403/500) + indicate a transport/auth failure and must not be reported as "not found". + """ + + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000050"), + manifest=Manifest(entries={"notes.txt": File(content=b"x")}), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-read-500", + ) + fake = _FakeSandbox(sandbox_id="sandbox-read-500") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + await session.start() + + async def _raise_500(_path: str) -> _FakeTraced: + raise _FakeRemoteAPIError(500, "internal server error") + + fake.read_file = _raise_500 # type: ignore[assignment] + + from agents.sandbox.errors import WorkspaceArchiveReadError, WorkspaceReadNotFoundError + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await session.read(Path("notes.txt")) + + assert not isinstance(exc_info.value, WorkspaceReadNotFoundError) + cause = exc_info.value.__cause__ + assert isinstance(cause, _FakeRemoteAPIError) + assert cause.status_code == 500 + + +@pytest.mark.asyncio +async def test_exposed_port_resolution_uses_sandbox_id(monkeypatch: pytest.MonkeyPatch) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000004"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-ports", + exposed_ports=(3000,), + ) + fake = _FakeSandbox(sandbox_id="sandbox-ports") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + endpoint = await session.resolve_exposed_port(3000) + assert endpoint.host == "3000-sandbox-ports.sandbox.tensorlake.ai" + assert endpoint.port == 443 + assert endpoint.tls is True + + +@pytest.mark.asyncio +async def test_exposed_port_resolution_uses_named_sandbox_when_set( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) 
+ state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000005"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-id", + name="demo", + exposed_ports=(8080,), + ) + fake = _FakeSandbox(sandbox_id="sandbox-id", name="demo") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + endpoint = await session.resolve_exposed_port(8080) + assert endpoint.host == "8080-demo.sandbox.tensorlake.ai" + + +@pytest.mark.asyncio +async def test_exposed_port_resolution_uses_backend_sandbox_url( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000000a"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-dev", + name="dev-env", + exposed_ports=(8080,), + ) + fake = _FakeSandbox( + sandbox_id="sandbox-dev", + name="dev-env", + sandbox_url="https://dev-env.sandbox.tensorlake.dev", + ) + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + endpoint = await session.resolve_exposed_port(8080) + assert endpoint.host == "8080-dev-env.sandbox.tensorlake.dev" + + +@pytest.mark.asyncio +async def test_custom_proxy_exposed_port_resolution_refreshes_minimal_cached_info( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000003a"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-dev", + name="dev-env", + exposed_ports=(8080,), + proxy_url="https://sandbox.tensorlake.dev", + ) + fake = _FakeSandbox( + sandbox_id="sandbox-dev", + name="dev-env", + sandbox_url=None, + status_refresh_sandbox_url="https://dev-env.sandbox.tensorlake.dev", + ) + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + 
endpoint = await session.resolve_exposed_port(8080) + + assert endpoint.host == "8080-dev-env.sandbox.tensorlake.dev" + assert fake.status_calls == 1 + assert fake.info_calls == 2 + + +@pytest.mark.asyncio +async def test_exposed_port_resolution_caches_proxy_hostname( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000000c"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-cache", + exposed_ports=(8080, 9090), + ) + fake = _FakeSandbox( + sandbox_id="sandbox-cache", + sandbox_url="https://sandbox-cache.sandbox.tensorlake.ai", + ) + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + await session.resolve_exposed_port(8080) + await session.resolve_exposed_port(9090) + + assert fake.info_calls == 1 + + +@pytest.mark.asyncio +async def test_custom_proxy_exposed_port_resolution_retries_after_transient_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000004a"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-dev", + name="dev-env", + exposed_ports=(8080,), + proxy_url="https://sandbox.tensorlake.dev", + ) + fake = _FakeSandbox( + sandbox_id="sandbox-dev", + name="dev-env", + sandbox_url=None, + ) + + info_failures = {"count": 1} + status_failures = {"count": 1} + original_info = fake.info + original_status = fake.status + + async def flaky_info() -> Any: + if info_failures["count"] > 0: + info_failures["count"] -= 1 + raise RuntimeError("transient info() failure") + return await original_info() + + async def flaky_status() -> Any: + if status_failures["count"] > 0: + status_failures["count"] -= 1 + raise RuntimeError("transient status() failure") + return await original_status() + 
+ fake.info = flaky_info # type: ignore[assignment] + fake.status = flaky_status # type: ignore[assignment] + + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + # First lookup hits the transient failures and falls back to the public template; + # because this is a custom-control-plane deployment, the cache must NOT latch. + first = await session.resolve_exposed_port(8080) + assert first.host == "8080-dev-env.sandbox.tensorlake.ai" + + # Recover and supply a sandbox_url so the next info() refresh returns it. + fake.sandbox_url = "https://dev-env.sandbox.tensorlake.dev" + + # The retry must reach the backend instead of short-circuiting on the cached miss. + second = await session.resolve_exposed_port(8080) + assert second.host == "8080-dev-env.sandbox.tensorlake.dev" + + +@pytest.mark.asyncio +async def test_delete_terminates_remote_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: + module = _load_tensorlake_module(monkeypatch) + client = module.TensorlakeSandboxClient() + session = await client.create( + manifest=Manifest(), + options=module.TensorlakeSandboxClientOptions(), + ) + fake = session._inner._sandbox + assert fake.terminated is False + + await client.delete(session) + + assert fake.terminate_count == 1 + + +@pytest.mark.asyncio +async def test_delete_terminates_even_when_pause_on_exit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + client = module.TensorlakeSandboxClient() + session = await client.create( + manifest=Manifest(), + options=module.TensorlakeSandboxClientOptions(pause_on_exit=True), + ) + fake = session._inner._sandbox + + await client.delete(session) + + assert fake.terminate_count == 1 + assert fake.suspended is False + + +@pytest.mark.asyncio +async def test_shutdown_then_delete_does_not_double_terminate( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + client = module.TensorlakeSandboxClient() + session = await 
client.create( + manifest=Manifest(), + options=module.TensorlakeSandboxClientOptions(), + ) + fake = session._inner._sandbox + + await session.shutdown() + await client.delete(session) + + assert fake.terminate_count == 1 + + +@pytest.mark.asyncio +async def test_shutdown_pause_then_delete_preserves_suspended_sandbox( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + client = module.TensorlakeSandboxClient() + session = await client.create( + manifest=Manifest(), + options=module.TensorlakeSandboxClientOptions(pause_on_exit=True), + ) + fake = session._inner._sandbox + + await session.shutdown() + await client.delete(session) + + assert fake.suspended is True + assert fake.terminate_count == 0 + + +@pytest.mark.asyncio +async def test_shutdown_terminates_by_default(monkeypatch: pytest.MonkeyPatch) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000006"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-shutdown", + ) + fake = _FakeSandbox(sandbox_id="sandbox-shutdown") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + await session.shutdown() + + assert fake.terminated is True + assert fake.suspended is False + + +@pytest.mark.asyncio +async def test_shutdown_suspends_when_pause_on_exit(monkeypatch: pytest.MonkeyPatch) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000007"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-pause", + name="sandbox-pause-name", + pause_on_exit=True, + ) + fake = _FakeSandbox(sandbox_id="sandbox-pause", name="sandbox-pause-name") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + await session.shutdown() + + assert fake.suspended is True + assert 
fake.terminated is False + assert session._backend_lifecycle_finalized is True + + +@pytest.mark.asyncio +async def test_shutdown_pause_closes_local_handle(monkeypatch: pytest.MonkeyPatch) -> None: + """`suspend()` does not close the SDK's Rust client; shutdown must close it manually.""" + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000020"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-pause-close", + name="sandbox-pause-close-name", + pause_on_exit=True, + ) + fake = _FakeSandbox(sandbox_id="sandbox-pause-close", name="sandbox-pause-close-name") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + await session.shutdown() + + assert fake.suspended is True + assert fake.closed is True + assert fake.close_count == 1 + assert session._sandbox is None + + +@pytest.mark.asyncio +async def test_shutdown_terminate_closes_local_handle(monkeypatch: pytest.MonkeyPatch) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000021"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-terminate-close", + ) + fake = _FakeSandbox(sandbox_id="sandbox-terminate-close") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + await session.shutdown() + + assert fake.terminated is True + assert fake.closed is True + # `AsyncSandbox.terminate()` already closes the local Rust client; the integration + # must not call `close()` a second time afterward. + assert fake.close_count == 1 + assert session._sandbox is None + + +@pytest.mark.asyncio +async def test_failed_shutdown_terminate_lets_client_delete_retry( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A transient terminate failure during shutdown must not orphan the remote sandbox. 
+ + Regression: an earlier version of the leak fix nulled `_sandbox` even on the + shutdown failure path, which made `client.delete()` short-circuit and skip the + retry — leaving the remote sandbox running until its own `timeout_secs`. + """ + module = _load_tensorlake_module(monkeypatch) + client = module.TensorlakeSandboxClient() + session = await client.create( + manifest=Manifest(), + options=module.TensorlakeSandboxClientOptions(), + ) + inner = session._inner + fake = inner._sandbox + + original_terminate = fake.terminate + call_count = {"n": 0} + + async def flaky_terminate() -> None: + call_count["n"] += 1 + if call_count["n"] == 1: + raise RuntimeError("transient backend error") + await original_terminate() + + fake.terminate = flaky_terminate + + await session.shutdown() + + # First terminate raised; the local handle must still be attached so delete can retry. + assert inner._sandbox is fake + assert inner._backend_lifecycle_finalized is False + + await client.delete(session) + + # `client.delete()` retried terminate; remote is now finalized and the handle freed. + assert call_count["n"] == 2 + assert inner._sandbox is None + assert inner._backend_lifecycle_finalized is True + + +@pytest.mark.asyncio +async def test_failed_shutdown_suspend_then_terminate_fallback_retries_via_delete( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When both suspend and the terminate fallback fail during shutdown, delete must retry.""" + module = _load_tensorlake_module(monkeypatch) + client = module.TensorlakeSandboxClient() + session = await client.create( + manifest=Manifest(), + options=module.TensorlakeSandboxClientOptions(pause_on_exit=True), + ) + inner = session._inner + fake = inner._sandbox + + # Force suspend to raise so the fallback terminate path runs. 
+ async def failing_suspend( + wait: bool = True, timeout: float = 300.0, poll_interval: float = 1.0 + ) -> None: + raise RuntimeError("suspend offline") + + original_terminate = fake.terminate + terminate_calls = {"n": 0} + + async def flaky_terminate() -> None: + terminate_calls["n"] += 1 + if terminate_calls["n"] == 1: + raise RuntimeError("fallback terminate offline") + await original_terminate() + + fake.suspend = failing_suspend + fake.terminate = flaky_terminate + + await session.shutdown() + + assert inner._sandbox is fake + assert inner._backend_lifecycle_finalized is False + + await client.delete(session) + + # Two terminate attempts total: the in-shutdown fallback (failed) and the delete retry. + assert terminate_calls["n"] == 2 + assert inner._sandbox is None + assert inner._backend_lifecycle_finalized is True + + +@pytest.mark.asyncio +async def test_delete_closes_local_handle_after_pause_on_exit_shutdown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A direct `client.delete()` after suspend must release the local handle, not just no-op.""" + module = _load_tensorlake_module(monkeypatch) + client = module.TensorlakeSandboxClient() + session = await client.create( + manifest=Manifest(), + options=module.TensorlakeSandboxClientOptions(pause_on_exit=True), + ) + inner = session._inner + fake = inner._sandbox + + # Simulate a session where shutdown completed remote suspend but somehow left the + # local Rust client attached. delete() must close it without sending terminate. 
+ inner._backend_lifecycle_finalized = True + fake.closed = False # reset the close flag set by `_close_sandbox_handle` paths + + await client.delete(session) + + assert fake.closed is True + assert fake.terminate_count == 0 + assert inner._sandbox is None + + +@pytest.mark.asyncio +async def test_delete_closes_local_handle_on_terminate_path( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + client = module.TensorlakeSandboxClient() + session = await client.create( + manifest=Manifest(), + options=module.TensorlakeSandboxClientOptions(), + ) + inner = session._inner + fake = inner._sandbox + + await client.delete(session) + + assert fake.terminate_count == 1 + assert fake.closed is True + assert inner._sandbox is None + + +@pytest.mark.asyncio +async def test_persist_workspace_via_tar_round_trips_manifest( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000008"), + manifest=Manifest(entries={"notes.txt": File(content=b"payload")}), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-tar", + ) + fake = _FakeSandbox(sandbox_id="sandbox-tar") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + await session.start() + archive = await session.persist_workspace() + raw = archive.read() + assert isinstance(raw, bytes) and raw + + # Hydrate into a new sandbox and ensure files are restored. 
+ other_state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000009"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-tar-restore", + ) + other_fake = _FakeSandbox(sandbox_id="sandbox-tar-restore") + other_session = module.TensorlakeSandboxSession.from_state(other_state, sandbox=other_fake) + await other_session.hydrate_workspace(io.BytesIO(raw)) + restored = await other_session.read(Path("notes.txt")) + assert restored.read() == b"payload" + + +@pytest.mark.asyncio +async def test_persist_workspace_via_tar_excludes_archive_when_root_is_tmp( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When manifest.root is /tmp the tar archive lives inside the workspace tree. + + The archive file must be excluded from the tar command so GNU tar does not hit + its "file is the archive" error (exit code 1). + """ + module = _load_tensorlake_module(monkeypatch) + sid = uuid.UUID("00000000-0000-0000-0000-000000000040") + state = module.TensorlakeSandboxSessionState( + session_id=sid, + manifest=Manifest(root="/tmp", entries={"data.txt": File(content=b"val")}), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-tmp-root", + ) + fake = _FakeSandbox(sandbox_id="sandbox-tmp-root") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + await session.start() + await session.persist_workspace() + + tar_calls = [c for c in fake.run_calls if c["command"] == "tar"] + assert tar_calls, "expected at least one tar call" + last_tar_args = cast(list[str], tar_calls[-1]["args"]) + expected_archive_name = f"openai-agents-{sid.hex}.tar" + assert any(expected_archive_name in arg for arg in last_tar_args if "--exclude" in arg), ( + f"archive file {expected_archive_name!r} not excluded from tar args: {last_tar_args}" + ) + + +@pytest.mark.asyncio +async def test_persist_workspace_via_tar_nonzero_raises_archive_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = 
_load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000010"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-tar-failure", + ) + fake = _FakeSandbox(sandbox_id="sandbox-tar-failure") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + from agents.sandbox.errors import WorkspaceArchiveReadError + + fake.next_run_result = _FakeCommandResult(stderr="tar failed", exit_code=2) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await session.persist_workspace() + + assert "tar failed" in str(exc_info.value.__cause__) + + +@pytest.mark.asyncio +async def test_persist_workspace_via_checkpoint_returns_snapshot_ref( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000000a"), + manifest=Manifest(entries={"notes.txt": File(content=b"snapshot-payload")}), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-checkpoint", + workspace_persistence="snapshot", + ) + fake = _FakeSandbox(sandbox_id="sandbox-checkpoint") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + await session.start() + archive = await session.persist_workspace() + raw = archive.read() + + assert raw.startswith(module._TENSORLAKE_SNAPSHOT_MAGIC) + snapshot_id = module._decode_tensorlake_snapshot_ref(raw) + assert snapshot_id == "snap-1" + assert _FakeSandbox.snapshots["snap-1"]["/workspace/notes.txt"] == b"snapshot-payload" + # Default matches the Tensorlake SDK: `local_ready` is enough to resume from the + # snapshot and avoids blocking on remote-storage upload. Callers needing a durable + # `snapshot_uri` opt in via `checkpoint_wait_until="completed"`. 
+ assert fake.last_checkpoint_wait_until == "local_ready" + + +@pytest.mark.asyncio +async def test_persist_workspace_via_checkpoint_honors_wait_until_completed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000000b"), + manifest=Manifest(entries={"notes.txt": File(content=b"snapshot-payload")}), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-checkpoint-completed", + workspace_persistence="snapshot", + checkpoint_wait_until="completed", + ) + fake = _FakeSandbox(sandbox_id="sandbox-checkpoint-completed") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + await session.start() + await session.persist_workspace() + + assert fake.last_checkpoint_wait_until == "completed" + + +@pytest.mark.asyncio +async def test_hydrate_workspace_via_checkpoint_replaces_sandbox( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + + # First, take a checkpoint via the tar-style helper so it is registered. 
+ initial = _FakeSandbox(sandbox_id="sandbox-source") + initial.files["/workspace/from-snapshot.txt"] = b"snap-data" + snap = await initial.checkpoint(checkpoint_type=_FakeCheckpointType.FILESYSTEM) + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000000b"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-pre-restore", + workspace_persistence="snapshot", + ) + pre_restore = _FakeSandbox(sandbox_id="sandbox-pre-restore") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=pre_restore) + + payload = module._encode_tensorlake_snapshot_ref(snapshot_id=snap.snapshot_id) + await session.hydrate_workspace(io.BytesIO(payload)) + + assert pre_restore.terminated is True + assert state.sandbox_id != "sandbox-pre-restore" + new_sandbox = _FakeSandbox.sandboxes[state.sandbox_id] + assert new_sandbox.files["/workspace/from-snapshot.txt"] == b"snap-data" + assert session._backend_lifecycle_finalized is False + # Regression: `delete()` must still terminate the live post-restore sandbox. + client = module.TensorlakeSandboxClient() + wrapped = client._wrap_session(session, instrumentation=None) + await client.delete(wrapped) + assert new_sandbox.terminated is True + + +@pytest.mark.asyncio +async def test_restore_from_checkpoint_raises_when_post_terminate_create_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """If `AsyncSandbox.create(snapshot_id=...)` fails after the old sandbox is terminated, + `_restore_from_checkpoint` must surface a `WorkspaceArchiveWriteError` with the create + error as the cause — and the pre-restore sandbox must still be marked terminated. 
+ """ + + module = _load_tensorlake_module(monkeypatch) + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000060"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-pre-restore-fail", + workspace_persistence="snapshot", + ) + pre_restore = _FakeSandbox(sandbox_id="sandbox-pre-restore-fail") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=pre_restore) + + create_error = RuntimeError("snapshot restore failed") + _FakeSandbox.create_failures.append(create_error) + + payload = module._encode_tensorlake_snapshot_ref(snapshot_id="snap-missing") + + from agents.sandbox.errors import WorkspaceArchiveWriteError + + with pytest.raises(WorkspaceArchiveWriteError) as exc_info: + await session.hydrate_workspace(io.BytesIO(payload)) + + assert exc_info.value.__cause__ is create_error + assert exc_info.value.context["reason"] == "tensorlake_checkpoint_restore_failed" + assert exc_info.value.context["snapshot_id"] == "snap-missing" + assert pre_restore.terminated is True + + +@pytest.mark.asyncio +async def test_resume_reconnects_running_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: + module = _load_tensorlake_module(monkeypatch) + + existing = _FakeSandbox(sandbox_id="sandbox-existing", status="running") + _FakeSandbox.sandboxes["sandbox-existing"] = existing + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000000c"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-existing", + ) + + client = module.TensorlakeSandboxClient() + session = await client.resume(state) + + assert _FakeSandbox.connect_calls == [{"sandbox_id": "sandbox-existing"}] + assert _FakeSandbox.create_calls == [] + assert session._inner.state.sandbox_id == "sandbox-existing" + + +@pytest.mark.asyncio +async def test_resume_fails_when_exposed_port_update_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = 
_load_tensorlake_module(monkeypatch) + + cause = RuntimeError("port update failed") + existing = _FakeSandbox(sandbox_id="sandbox-existing", status="running") + existing.update_failure = cause + _FakeSandbox.sandboxes["sandbox-existing"] = existing + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000003b"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-existing", + exposed_ports=(3000,), + ) + + client = module.TensorlakeSandboxClient() + + from agents.sandbox.errors import WorkspaceStartError + + with pytest.raises(WorkspaceStartError) as exc_info: + await client.resume(state) + + assert exc_info.value.__cause__ is cause + assert exc_info.value.context["reason"] == "tensorlake_exposed_ports_update_failed" + assert exc_info.value.context["sandbox_id"] == "sandbox-existing" + assert exc_info.value.context["ports"] == [3000] + assert _FakeSandbox.create_calls == [] + assert existing.terminate_count == 0 + + +@pytest.mark.asyncio +async def test_resume_cleans_up_recreated_sandbox_when_exposed_port_update_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + + cause = RuntimeError("port update failed") + _FakeSandbox.update_failures.append(cause) + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000003c"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-missing", + exposed_ports=(8080,), + ) + + client = module.TensorlakeSandboxClient() + + from agents.sandbox.errors import WorkspaceStartError + + with pytest.raises(WorkspaceStartError): + await client.resume(state) + + assert _FakeSandbox.sandboxes["tensorlake-sandbox-1"].terminate_count == 1 + + +@pytest.mark.asyncio +async def test_resume_creates_fresh_when_reconnect_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + + state = 
module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000000d"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-missing", + ) + + client = module.TensorlakeSandboxClient() + session = await client.resume(state) + + assert _FakeSandbox.connect_calls and _FakeSandbox.connect_calls[0]["sandbox_id"] == ( + "sandbox-missing" + ) + assert len(_FakeSandbox.create_calls) == 1 + new_id = session._inner.state.sandbox_id + assert new_id.startswith("tensorlake-sandbox-") + assert state.workspace_root_ready is False + + +@pytest.mark.asyncio +async def test_resume_creates_fresh_when_paused_resume_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A failed `resume()` must not be reported as a preserved running sandbox.""" + + module = _load_tensorlake_module(monkeypatch) + + existing = _FakeSandbox(sandbox_id="sandbox-paused", status="suspended") + existing.resume_failure = RuntimeError("sandbox expired") + _FakeSandbox.sandboxes["sandbox-paused"] = existing + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000000f"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-paused", + pause_on_exit=True, + ) + + client = module.TensorlakeSandboxClient() + session = await client.resume(state) + + assert len(_FakeSandbox.create_calls) == 1 + assert _FakeSandbox.create_calls[0]["name"] == ( + "openai-agents-0000000000000000000000000000000f" + ) + new_id = session._inner.state.sandbox_id + assert new_id != "sandbox-paused" + assert new_id.startswith("tensorlake-sandbox-") + assert session._inner.state.name == "openai-agents-0000000000000000000000000000000f" + assert state.workspace_root_ready is False + assert session._inner._workspace_state_preserved_on_start() is False + assert session._inner._system_state_preserved_on_start() is False + # With pause_on_exit=True the user expects backend state to survive resume; a probe + # 
failure after a successful connect() is ambiguous and may be transient against a + # still-suspended sandbox. Terminating here would destroy potentially recoverable + # workspace state (especially with a Noop or stale snapshot), so we only release the + # local handle and let the backend reclaim the sandbox via its own timeout. + assert existing.terminate_count == 0 + assert existing.close_count == 1 + + +@pytest.mark.asyncio +async def test_resume_closes_abandoned_handle_when_status_not_running( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module = _load_tensorlake_module(monkeypatch) + + existing = _FakeSandbox(sandbox_id="sandbox-dead", status="terminated") + _FakeSandbox.sandboxes["sandbox-dead"] = existing + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000011"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-dead", + ) + + client = module.TensorlakeSandboxClient() + session = await client.resume(state) + + assert len(_FakeSandbox.create_calls) == 1 + assert existing.terminate_count == 1 + assert session._inner._workspace_state_preserved_on_start() is False + + +@pytest.mark.asyncio +async def test_resume_recreates_directly_from_snapshot_when_reconnect_fails( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """Snapshot-mode resume must skip the throwaway empty-sandbox create when reconnect fails.""" + + module = _load_tensorlake_module(monkeypatch) + + snapshot = LocalSnapshot(id="snap", base_path=tmp_path) + payload = module._encode_tensorlake_snapshot_ref(snapshot_id="snap-stored") + await snapshot.persist(io.BytesIO(payload)) + _FakeSandbox.snapshots["snap-stored"] = {"/workspace/from-snapshot.txt": b"snap-data"} + + existing = _FakeSandbox(sandbox_id="sandbox-paused", status="suspended") + existing.resume_failure = RuntimeError("sandbox expired") + _FakeSandbox.sandboxes["sandbox-paused"] = existing + + state = 
module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000010"), + manifest=Manifest(), + snapshot=snapshot, + sandbox_id="sandbox-paused", + pause_on_exit=True, + workspace_persistence="snapshot", + ) + + client = module.TensorlakeSandboxClient() + session = await client.resume(state) + + assert len(_FakeSandbox.create_calls) == 1 + assert _FakeSandbox.create_calls[0].get("snapshot_id") == "snap-stored" + assert _FakeSandbox.create_calls[0]["name"] == ( + "openai-agents-00000000000000000000000000000010" + ) + new_id = session._inner.state.sandbox_id + assert new_id != "sandbox-paused" + new_sandbox = _FakeSandbox.sandboxes[new_id] + assert new_sandbox.files["/workspace/from-snapshot.txt"] == b"snap-data" + assert state.workspace_root_ready is True + assert session._inner._workspace_state_preserved_on_start() is True + assert session._inner._system_state_preserved_on_start() is True + + +def test_serialize_session_state_round_trips(monkeypatch: pytest.MonkeyPatch) -> None: + module = _load_tensorlake_module(monkeypatch) + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-00000000000e"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-serialize", + image="custom", + cpus=4.0, + memory_mb=4096, + disk_mb=20480, + timeout_secs=120, + name="serialize", + allow_internet_access=False, + allow_out=("10.0.0.0/8",), + deny_out=("example.com",), + workspace_persistence="snapshot", + checkpoint_mode="memory", + ) + client = module.TensorlakeSandboxClient() + payload = client.serialize_session_state(state) + restored = client.deserialize_session_state(payload) + + assert isinstance(restored, module.TensorlakeSandboxSessionState) + assert restored.image == "custom" + assert restored.cpus == 4.0 + assert restored.disk_mb == 20480 + assert restored.workspace_persistence == "snapshot" + assert restored.checkpoint_mode == "memory" + + +@pytest.mark.asyncio +async def 
test_restore_from_checkpoint_marks_system_state_preserved( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """After _restore_from_checkpoint, system state must be flagged as preserved. + + When resume() cannot read a RemoteSnapshot without dependencies it creates a fresh + empty sandbox and sets _start_system_state_preserved=False. hydrate_workspace() + later replaces that sandbox with a full Tensorlake checkpoint (which already contains + OS users and groups). The base start flow must not re-run groupadd/useradd against + accounts that are already present in the restored image. + """ + module = _load_tensorlake_module(monkeypatch) + + # Seed a snapshot so the checkpoint restore can find it. + initial = _FakeSandbox(sandbox_id="sandbox-snap-src") + initial.files["/workspace/data.txt"] = b"hello" + snap = await initial.checkpoint(checkpoint_type=_FakeCheckpointType.FILESYSTEM) + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000030"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-before-restore", + workspace_persistence="snapshot", + ) + fake = _FakeSandbox(sandbox_id="sandbox-before-restore") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + # Simulate the fresh-sandbox case: resume() could not read the snapshot and set + # preserved=False before handing the session to start(). + session._set_start_state_preserved(False, system=False) + + payload = module._encode_tensorlake_snapshot_ref(snapshot_id=snap.snapshot_id) + await session.hydrate_workspace(io.BytesIO(payload)) + + assert session.should_provision_manifest_accounts_on_resume() is False + + +@pytest.mark.asyncio +async def test_restore_from_memory_checkpoint_strips_snapshot_owned_kwargs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Memory-checkpoint restore must omit fields the backend sources from the snapshot. 
+ + Tensorlake docs (https://docs.tensorlake.ai/sandboxes/snapshots) say image, resources + (CPUs, memory, disk), entrypoint, and secrets come from the memory snapshot and cannot + be passed at restore time. Forwarding them from session state — which they will be set + in for any non-default user config — would make `AsyncSandbox.create(snapshot_id=...)` + reject the call. + """ + module = _load_tensorlake_module(monkeypatch) + + initial = _FakeSandbox(sandbox_id="sandbox-snap-src") + snap = await initial.checkpoint(checkpoint_type=_FakeCheckpointType.MEMORY) + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000061"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-before-restore", + image="custom-image", + cpus=4.0, + memory_mb=4096, + disk_mb=20480, + secret_names=("OPENAI_API_KEY",), + entrypoint=("/bin/bash",), + allow_internet_access=False, + timeout_secs=120, + workspace_persistence="snapshot", + checkpoint_mode="memory", + ) + fake = _FakeSandbox(sandbox_id="sandbox-before-restore") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + _FakeSandbox.create_calls.clear() + payload = module._encode_tensorlake_snapshot_ref(snapshot_id=snap.snapshot_id) + await session.hydrate_workspace(io.BytesIO(payload)) + + assert len(_FakeSandbox.create_calls) == 1 + create_kwargs = _FakeSandbox.create_calls[0] + assert create_kwargs.get("snapshot_id") == snap.snapshot_id + for forbidden in ("image", "cpus", "memory_mb", "disk_mb", "entrypoint", "secret_names"): + assert forbidden not in create_kwargs, ( + f"{forbidden} must be omitted for memory-checkpoint restore" + ) + # Fields that the docs do not restrict on memory restore must still flow through. 
+ assert create_kwargs["allow_internet_access"] is False + assert create_kwargs["timeout_secs"] == 120 + + +@pytest.mark.asyncio +async def test_restore_from_filesystem_checkpoint_keeps_resource_kwargs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Filesystem-checkpoint restore must still forward image/resources/entrypoint/secrets. + + Per Tensorlake docs, the memory-snapshot restriction does not apply to filesystem + snapshots — resources are explicitly modifiable on restore. The kwargs filter must + only kick in for `checkpoint_mode="memory"`. + """ + module = _load_tensorlake_module(monkeypatch) + + initial = _FakeSandbox(sandbox_id="sandbox-snap-src-fs") + snap = await initial.checkpoint(checkpoint_type=_FakeCheckpointType.FILESYSTEM) + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000062"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-before-restore-fs", + image="custom-image", + cpus=4.0, + memory_mb=4096, + disk_mb=20480, + secret_names=("OPENAI_API_KEY",), + entrypoint=("/bin/bash",), + workspace_persistence="snapshot", + checkpoint_mode="filesystem", + ) + fake = _FakeSandbox(sandbox_id="sandbox-before-restore-fs") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + _FakeSandbox.create_calls.clear() + payload = module._encode_tensorlake_snapshot_ref(snapshot_id=snap.snapshot_id) + await session.hydrate_workspace(io.BytesIO(payload)) + + assert len(_FakeSandbox.create_calls) == 1 + create_kwargs = _FakeSandbox.create_calls[0] + assert create_kwargs.get("snapshot_id") == snap.snapshot_id + assert create_kwargs["image"] == "custom-image" + assert create_kwargs["cpus"] == 4.0 + assert create_kwargs["memory_mb"] == 4096 + assert create_kwargs["disk_mb"] == 20480 + assert create_kwargs["secret_names"] == ["OPENAI_API_KEY"] + assert create_kwargs["entrypoint"] == ["/bin/bash"] + + +@pytest.mark.asyncio +async def 
test_resume_recreate_from_memory_checkpoint_strips_snapshot_owned_kwargs( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """The resume-recreate path must also strip memory-snapshot-owned kwargs. + + `client.resume(state)` falls back to `AsyncSandbox.create(snapshot_id=...)` when the + paused sandbox cannot be reconnected. That second create site has to apply the same + memory-snapshot filtering as `_restore_from_checkpoint`, otherwise any non-default user + config breaks the recreate flow. + """ + module = _load_tensorlake_module(monkeypatch) + + snapshot = LocalSnapshot(id="snap", base_path=tmp_path) + payload = module._encode_tensorlake_snapshot_ref(snapshot_id="snap-stored-mem") + await snapshot.persist(io.BytesIO(payload)) + _FakeSandbox.snapshots["snap-stored-mem"] = {} + + existing = _FakeSandbox(sandbox_id="sandbox-paused-mem", status="suspended") + existing.resume_failure = RuntimeError("sandbox expired") + _FakeSandbox.sandboxes["sandbox-paused-mem"] = existing + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000063"), + manifest=Manifest(), + snapshot=snapshot, + sandbox_id="sandbox-paused-mem", + image="custom-image", + cpus=2.0, + memory_mb=2048, + secret_names=("API_KEY",), + entrypoint=("/bin/sh",), + pause_on_exit=True, + workspace_persistence="snapshot", + checkpoint_mode="memory", + ) + + _FakeSandbox.create_calls.clear() + client = module.TensorlakeSandboxClient() + await client.resume(state) + + assert len(_FakeSandbox.create_calls) == 1 + create_kwargs = _FakeSandbox.create_calls[0] + assert create_kwargs.get("snapshot_id") == "snap-stored-mem" + for forbidden in ("image", "cpus", "memory_mb", "disk_mb", "entrypoint", "secret_names"): + assert forbidden not in create_kwargs + + +@pytest.mark.asyncio +async def test_after_start_reinstalls_helpers_when_sandbox_id_changes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """_after_start must reinstall runtime helpers when 
sandbox_id changed mid-start. + + checkpoint restore replaces the sandbox and sandbox_id during start(); the helper + cache key becomes stale. _after_start() detects the mismatch and re-runs + _ensure_runtime_helpers() so the new backend has the helpers installed. + """ + module = _load_tensorlake_module(monkeypatch) + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000031"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-new", + workspace_persistence="tar", + ) + fake = _FakeSandbox(sandbox_id="sandbox-new") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + # Simulate helpers installed on the pre-restore sandbox: cache key is stale. + session._runtime_helper_cache_key = "sandbox-old" + session._runtime_helpers_installed = set() + + await session._after_start() + + assert session._runtime_helper_cache_key == "sandbox-new" + assert any(c["command"] == "sh" for c in fake.run_calls) + + +@pytest.mark.asyncio +async def test_running_returns_false_when_workspace_not_ready( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """running() must return False when the workspace has not been set up yet. + + A Tensorlake sandbox can be in RUNNING state (backend alive) while the workspace + hasn't been provisioned. Callers must not treat such a session as usable. + """ + module = _load_tensorlake_module(monkeypatch) + + state = module.TensorlakeSandboxSessionState( + session_id=uuid.UUID("00000000-0000-0000-0000-000000000032"), + manifest=Manifest(), + snapshot=NoopSnapshot(id="snap"), + sandbox_id="sandbox-not-ready", + workspace_persistence="tar", + ) + # Backend is running but workspace_root_ready is False (before start()). 
+ fake = _FakeSandbox(sandbox_id="sandbox-not-ready", status="running") + session = module.TensorlakeSandboxSession.from_state(state, sandbox=fake) + + assert await session.running() is False + + state.workspace_root_ready = True + assert await session.running() is True diff --git a/tests/sandbox/test_compatibility_guards.py b/tests/sandbox/test_compatibility_guards.py index 5a11e5bf77..b055937349 100644 --- a/tests/sandbox/test_compatibility_guards.py +++ b/tests/sandbox/test_compatibility_guards.py @@ -333,6 +333,17 @@ def test_core_sandbox_public_export_surface_is_stable() -> None: "VercelSandboxSessionState", }, ), + ( + "agents.extensions.sandbox.tensorlake", + { + "DEFAULT_TENSORLAKE_WORKSPACE_ROOT", + "TensorlakeSandboxClient", + "TensorlakeSandboxClientOptions", + "TensorlakeSandboxSession", + "TensorlakeSandboxSessionState", + "TensorlakeSandboxTimeouts", + }, + ), ], ) def test_extension_sandbox_package_export_surfaces_are_stable( @@ -510,6 +521,39 @@ def test_optional_sandbox_dataclass_constructor_field_order_is_stable( "network_policy", ), ), + ( + "agents.extensions.sandbox.tensorlake", + "TensorlakeSandboxClientOptions", + ( + "image", + "cpus", + "memory_mb", + "timeout_secs", + "name", + "secret_names", + "envs", + "allow_internet_access", + "allow_out", + "deny_out", + "exposed_ports", + "allow_unauthenticated_port_access", + "pause_on_exit", + "workspace_persistence", + "checkpoint_mode", + "checkpoint_timeout_s", + "timeouts", + "disk_mb", + "entrypoint", + "startup_timeout", + "proxy_url", + "api_url", + "namespace", + "organization_id", + "project_id", + "routing_hint", + "checkpoint_wait_until", + ), + ), ], ) def test_optional_sandbox_client_options_positional_field_order_is_stable( @@ -745,6 +789,47 @@ def test_optional_sandbox_client_options_positional_field_order_is_stable( "network_policy", ), ), + ( + "agents.extensions.sandbox.tensorlake", + "TensorlakeSandboxSessionState", + ( + "type", + "session_id", + "snapshot", + "manifest", + 
"exposed_ports", + "snapshot_fingerprint", + "snapshot_fingerprint_version", + "workspace_root_ready", + "sandbox_id", + "name", + "image", + "cpus", + "memory_mb", + "timeout_secs", + "secret_names", + "base_envs", + "allow_internet_access", + "allow_out", + "deny_out", + "allow_unauthenticated_port_access", + "pause_on_exit", + "workspace_persistence", + "checkpoint_mode", + "checkpoint_timeout_s", + "timeouts", + "disk_mb", + "entrypoint", + "startup_timeout", + "proxy_url", + "api_url", + "namespace", + "organization_id", + "project_id", + "routing_hint", + "checkpoint_wait_until", + ), + ), ], ) def test_sandbox_session_state_field_order_is_stable( @@ -786,6 +871,12 @@ def test_sandbox_session_state_field_order_is_stable( ("agents.extensions.sandbox.daytona", "DaytonaSandboxClientOptions", (), "daytona"), ("agents.extensions.sandbox.runloop", "RunloopSandboxClientOptions", (), "runloop"), ("agents.extensions.sandbox.vercel", "VercelSandboxClientOptions", (), "vercel"), + ( + "agents.extensions.sandbox.tensorlake", + "TensorlakeSandboxClientOptions", + (), + "tensorlake", + ), ], ) def test_optional_sandbox_client_options_json_round_trip_preserves_type( @@ -851,6 +942,11 @@ def test_optional_sandbox_client_options_json_round_trip_preserves_type( "VercelSandboxSessionState", {"sandbox_id": "sandbox-123"}, ), + ( + "agents.extensions.sandbox.tensorlake", + "TensorlakeSandboxSessionState", + {"sandbox_id": "sandbox-123"}, + ), ], ) def test_optional_sandbox_session_state_json_round_trip_preserves_type( diff --git a/uv.lock b/uv.lock index 3e5cb31b70..1843af847b 100644 --- a/uv.lock +++ b/uv.lock @@ -9,7 +9,8 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-05-09T02:05:26Z" +exclude-newer = "2026-05-16T03:20:10.814636Z" +exclude-newer-span = "P7D" [[package]] name = "aiofiles" @@ -1349,16 +1350,16 @@ wheels = [ [[package]] name = "grpcio-status" -version = "1.67.1" +version = "1.76.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { 
name = "googleapis-common-protos" }, { name = "grpcio" }, { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/be/c7/fe0e79a80ac6346e0c6c0a24e9e3cbc3ae1c2a009acffb59eab484a6f69b/grpcio_status-1.67.1.tar.gz", hash = "sha256:2bf38395e028ceeecfd8866b081f61628114b384da7d51ae064ddc8d766a5d11", size = 13673, upload-time = "2024-10-29T06:30:21.787Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/46/e9f19d5be65e8423f886813a2a9d0056ba94757b0c5007aa59aed1a961fa/grpcio_status-1.76.0.tar.gz", hash = "sha256:25fcbfec74c15d1a1cb5da3fab8ee9672852dc16a5a9eeb5baf7d7a9952943cd", size = 13679, upload-time = "2025-10-21T16:28:52.545Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/05/18/56999a1da3577d8ccc8698a575d6638e15fe25650cc88b2ce0a087f180b9/grpcio_status-1.67.1-py3-none-any.whl", hash = "sha256:16e6c085950bdacac97c779e6a502ea671232385e6e37f258884d6883392c2bd", size = 14427, upload-time = "2024-10-29T06:27:38.228Z" }, + { url = "https://files.pythonhosted.org/packages/8c/cc/27ba60ad5a5f2067963e6a858743500df408eb5855e98be778eaef8c9b02/grpcio_status-1.76.0-py3-none-any.whl", hash = "sha256:380568794055a8efbbd8871162df92012e0228a5f6dffaf57f2a00c534103b18", size = 14425, upload-time = "2025-10-21T16:28:40.853Z" }, ] [[package]] @@ -1448,6 +1449,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[package.optional-dependencies] +http2 = [ + { name = "h2" }, +] + [[package]] name = "httpx-sse" version = "0.4.1" @@ -2501,6 +2507,9 @@ temporal = [ { name = "temporalio" }, { name = "textual" }, ] +tensorlake = [ + { name = "tensorlake" }, +] vercel = [ { name = "vercel" }, ] @@ -2576,6 +2585,7 @@ requires-dist = [ { name = "runloop-api-client", marker = "extra == 'runloop'", 
specifier = ">=1.16.0,<2.0.0" }, { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0" }, { name = "temporalio", marker = "extra == 'temporal'", specifier = "==1.26.0" }, + { name = "tensorlake", marker = "extra == 'tensorlake'", specifier = ">=0.5.11" }, { name = "textual", marker = "extra == 'temporal'", specifier = ">=8.2.3,<8.3" }, { name = "types-requests", specifier = ">=2.0,<3" }, { name = "typing-extensions", specifier = ">=4.12.2,<5" }, @@ -2584,7 +2594,7 @@ requires-dist = [ { name = "websockets", marker = "extra == 'realtime'", specifier = ">=15.0,<17" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<17" }, ] -provides-extras = ["voice", "viz", "litellm", "any-llm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr", "mongodb", "docker", "blaxel", "daytona", "cloudflare", "e2b", "modal", "runloop", "vercel", "s3", "temporal"] +provides-extras = ["voice", "viz", "litellm", "any-llm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr", "mongodb", "docker", "blaxel", "daytona", "cloudflare", "e2b", "modal", "runloop", "tensorlake", "vercel", "s3", "temporal"] [package.metadata.requires-dev] dev = [ @@ -2909,16 +2919,17 @@ wheels = [ [[package]] name = "protobuf" -version = "5.29.5" +version = "6.33.6" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/43/29/d09e70352e4e88c9c7a198d5645d7277811448d76c23b00345670f7c8a38/protobuf-5.29.5.tar.gz", hash = "sha256:bc1463bafd4b0929216c35f437a8e28731a2b7fe3d98bb77a600efced5a15c84", size = 425226, upload-time = "2025-05-28T23:51:59.82Z" } +sdist = { url = "https://files.pythonhosted.org/packages/66/70/e908e9c5e52ef7c3a6c7902c9dfbb34c7e29c25d2f81ade3856445fd5c94/protobuf-6.33.6.tar.gz", hash = "sha256:a6768d25248312c297558af96a9f9c929e8c4cee0659cb07e780731095f38135", size = 444531, upload-time = "2026-03-18T19:05:00.988Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/5f/11/6e40e9fc5bba02988a214c07cf324595789ca7820160bfd1f8be96e48539/protobuf-5.29.5-cp310-abi3-win32.whl", hash = "sha256:3f1c6468a2cfd102ff4703976138844f78ebd1fb45f49011afc5139e9e283079", size = 422963, upload-time = "2025-05-28T23:51:41.204Z" }, - { url = "https://files.pythonhosted.org/packages/81/7f/73cefb093e1a2a7c3ffd839e6f9fcafb7a427d300c7f8aef9c64405d8ac6/protobuf-5.29.5-cp310-abi3-win_amd64.whl", hash = "sha256:3f76e3a3675b4a4d867b52e4a5f5b78a2ef9565549d4037e06cf7b0942b1d3fc", size = 434818, upload-time = "2025-05-28T23:51:44.297Z" }, - { url = "https://files.pythonhosted.org/packages/dd/73/10e1661c21f139f2c6ad9b23040ff36fee624310dc28fba20d33fdae124c/protobuf-5.29.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e38c5add5a311f2a6eb0340716ef9b039c1dfa428b28f25a7838ac329204a671", size = 418091, upload-time = "2025-05-28T23:51:45.907Z" }, - { url = "https://files.pythonhosted.org/packages/6c/04/98f6f8cf5b07ab1294c13f34b4e69b3722bb609c5b701d6c169828f9f8aa/protobuf-5.29.5-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:fa18533a299d7ab6c55a238bf8629311439995f2e7eca5caaff08663606e9015", size = 319824, upload-time = "2025-05-28T23:51:47.545Z" }, - { url = "https://files.pythonhosted.org/packages/85/e4/07c80521879c2d15f321465ac24c70efe2381378c00bf5e56a0f4fbac8cd/protobuf-5.29.5-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:63848923da3325e1bf7e9003d680ce6e14b07e55d0473253a690c3a8b8fd6e61", size = 319942, upload-time = "2025-05-28T23:51:49.11Z" }, - { url = "https://files.pythonhosted.org/packages/7e/cc/7e77861000a0691aeea8f4566e5d3aa716f2b1dece4a24439437e41d3d25/protobuf-5.29.5-py3-none-any.whl", hash = "sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5", size = 172823, upload-time = "2025-05-28T23:51:58.157Z" }, + { url = "https://files.pythonhosted.org/packages/fc/9f/2f509339e89cfa6f6a4c4ff50438db9ca488dec341f7e454adad60150b00/protobuf-6.33.6-cp310-abi3-win32.whl", hash = 
"sha256:7d29d9b65f8afef196f8334e80d6bc1d5d4adedb449971fefd3723824e6e77d3", size = 425739, upload-time = "2026-03-18T19:04:48.373Z" }, + { url = "https://files.pythonhosted.org/packages/76/5d/683efcd4798e0030c1bab27374fd13a89f7c2515fb1f3123efdfaa5eab57/protobuf-6.33.6-cp310-abi3-win_amd64.whl", hash = "sha256:0cd27b587afca21b7cfa59a74dcbd48a50f0a6400cfb59391340ad729d91d326", size = 437089, upload-time = "2026-03-18T19:04:50.381Z" }, + { url = "https://files.pythonhosted.org/packages/5c/01/a3c3ed5cd186f39e7880f8303cc51385a198a81469d53d0fdecf1f64d929/protobuf-6.33.6-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:9720e6961b251bde64edfdab7d500725a2af5280f3f4c87e57c0208376aa8c3a", size = 427737, upload-time = "2026-03-18T19:04:51.866Z" }, + { url = "https://files.pythonhosted.org/packages/ee/90/b3c01fdec7d2f627b3a6884243ba328c1217ed2d978def5c12dc50d328a3/protobuf-6.33.6-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e2afbae9b8e1825e3529f88d514754e094278bb95eadc0e199751cdd9a2e82a2", size = 324610, upload-time = "2026-03-18T19:04:53.096Z" }, + { url = "https://files.pythonhosted.org/packages/9b/ca/25afc144934014700c52e05103c2421997482d561f3101ff352e1292fb81/protobuf-6.33.6-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c96c37eec15086b79762ed265d59ab204dabc53056e3443e702d2681f4b39ce3", size = 339381, upload-time = "2026-03-18T19:04:54.616Z" }, + { url = "https://files.pythonhosted.org/packages/16/92/d1e32e3e0d894fe00b15ce28ad4944ab692713f2e7f0a99787405e43533a/protobuf-6.33.6-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:e9db7e292e0ab79dd108d7f1a94fe31601ce1ee3f7b79e0692043423020b0593", size = 323436, upload-time = "2026-03-18T19:04:55.768Z" }, + { url = "https://files.pythonhosted.org/packages/c4/72/02445137af02769918a93807b2b7890047c32bfb9f90371cbc12688819eb/protobuf-6.33.6-py3-none-any.whl", hash = "sha256:77179e006c476e69bf8e8ce866640091ec42e1beb80b213c3900006ecfba6901", size = 170656, upload-time = "2026-03-18T19:04:59.826Z" }, ] [[package]] @@ 
-3966,6 +3977,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fd/9b/c50840a26af3587c0c8d9af04d9976743e22496996dc1a377efc75dcd316/temporalio-1.26.0-cp310-abi3-win_amd64.whl", hash = "sha256:1c4a0d82f0a3796cbf78864c799f8dca0b94cdaec68e7b8b224c859005686ec4", size = 14525849, upload-time = "2026-04-15T23:42:57.589Z" }, ] +[[package]] +name = "tensorlake" +version = "0.5.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "grpcio" }, + { name = "httpx", extra = ["http2"] }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "websocket-client" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/d1/180a95f3ad169b0d728f62c9ebacf731ed8418781b9e8b81c05bb9d80b98/tensorlake-0.5.12.tar.gz", hash = "sha256:cc855148ce1de5d654d0df55cec43169b6bef3e74234ca0576da49178634bb1f", size = 2285581, upload-time = "2026-05-15T21:37:17.5Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d2/24ab49b893a2597ece6262fb474862b8d11321d3b49578193854c6731790/tensorlake-0.5.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:24c3cf48c87011997b45df52799acfebd657a8d6dcc9315f477fe9d2245d0158", size = 15077600, upload-time = "2026-05-15T21:37:04.533Z" }, + { url = "https://files.pythonhosted.org/packages/f6/cc/300811aa8a89165cdfc511de940ee2660c8e69a833970afd261372b1596b/tensorlake-0.5.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdc27c3232694febabe4ba83f7cd403bf1db9b6edc71aac327ba3fe5b77d2c74", size = 15630351, upload-time = "2026-05-15T21:37:08.177Z" }, + { url = "https://files.pythonhosted.org/packages/fc/5f/fe38e288bd3995d52ec8bb2cc7e76950bff9fa1bece1f3f9126dc533db3f/tensorlake-0.5.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a76409afd29f4d9568f25c1605410d8d8528457bf1f2354a53019b30125369e5", size = 16249081, upload-time = "2026-05-15T21:37:11.578Z" }, + { url = 
"https://files.pythonhosted.org/packages/05/26/2befc804be146eae669c8db76d66ee925b56bfcab39c2553a2ec1a0ce0e9/tensorlake-0.5.12-py3-none-win_amd64.whl", hash = "sha256:2dacfb141c897d853cf7374d3851ec771a7c221c89e0ee7e3556876fdc4c484e", size = 17114345, upload-time = "2026-05-15T21:37:14.919Z" }, +] + [[package]] name = "testcontainers" version = "4.12.0" @@ -4471,6 +4501,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d8/0d1d2e9d3fabcf5d6840362adcf05f8cf3cd06a73358140c3a97189238ae/wcmatch-10.1-py3-none-any.whl", hash = "sha256:5848ace7dbb0476e5e55ab63c6bbd529745089343427caa5537f230cc01beb8a", size = 39854, upload-time = "2025-06-22T19:14:00.978Z" }, ] +[[package]] +name = "websocket-client" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/41/aa4bf9664e4cda14c3b39865b12251e8e7d239f4cd0e3cc1b6c2ccde25c1/websocket_client-1.9.0.tar.gz", hash = "sha256:9e813624b6eb619999a97dc7958469217c3176312b3a16a4bd1bc7e08a46ec98", size = 70576, upload-time = "2025-10-07T21:16:36.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" }, +] + [[package]] name = "websockets" version = "15.0.1"