Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 151 additions & 26 deletions examples/sandbox_11_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,33 @@

load_dotenv()

SNAPSHOT_EXPIRATION_MS = 86_400_000
SNAPSHOT_EXPIRATION_DRIFT_MS = 1_000


def assert_time_near(*, actual: int | None, expected: int, drift: int, label: str) -> None:
assert actual is not None, f"{label} should report an expires_at timestamp"
assert abs(actual - expected) <= drift, f"{label} should be within {drift}ms of {expected}"


def assert_snapshot_expires_at_expected_time(
*, created_at: int, expires_at: int | None, label: str
) -> None:
assert_time_near(
actual=expires_at,
expected=created_at + SNAPSHOT_EXPIRATION_MS,
drift=SNAPSHOT_EXPIRATION_DRIFT_MS,
label=label,
)


async def async_demo() -> None:
print("=" * 60)
print("ASYNC SNAPSHOT EXAMPLE")
print("=" * 60)

async_snapshot_ids: list[str] = []

# Step 1: Create a sandbox and set it up
print("\n[1] Creating initial sandbox...")
sandbox1 = await AsyncSandbox.create(timeout=120_000)
Expand All @@ -46,12 +67,20 @@ async def async_demo() -> None:
config = await sandbox1.read_file("config.json")
print(f" Created config.json: {config.decode()}")

# Step 2: Create a snapshot (this STOPS the sandbox)
print("\n[3] Creating snapshot...")
snapshot = await sandbox1.snapshot()
# Step 2: Create a snapshot with an explicit expiration (this STOPS the sandbox)
print("\n[3] Creating snapshot with expiration=86400000...")
snapshot = await sandbox1.snapshot(expiration=SNAPSHOT_EXPIRATION_MS)
async_snapshot_ids.append(snapshot.snapshot_id)
print(f" Snapshot ID: {snapshot.snapshot_id}")
print(f" Status: {snapshot.status}")
print(f" Created At: {snapshot.created_at}")
print(f" Expires At: {snapshot.expires_at}")
print(f" Sandbox status after snapshot: {sandbox1.status}")
assert_snapshot_expires_at_expected_time(
created_at=snapshot.created_at,
expires_at=snapshot.expires_at,
label="expiring async snapshot",
)

finally:
await sandbox1.client.aclose()
Expand All @@ -76,27 +105,74 @@ async def async_demo() -> None:

assert config2 == b'{"version": "1.0", "env": "async"}', "config.json mismatch!"
assert users2 == b"alice\nbob\ncharlie", "users.txt mismatch!"
print("\n✓ Async assertions passed!")
print("\n✓ Async restore assertions passed!")

# Create a second snapshot with expiration=0 to preserve it indefinitely.
print("\n[6] Creating second snapshot with expiration=0...")
persistent_snapshot = await sandbox2.snapshot(expiration=0)
async_snapshot_ids.append(persistent_snapshot.snapshot_id)
print(f" Snapshot ID: {persistent_snapshot.snapshot_id}")
print(f" Created At: {persistent_snapshot.created_at}")
print(f" Expires At: {persistent_snapshot.expires_at}")
assert persistent_snapshot.expires_at is None, (
"persistent async snapshot should not have an expires_at timestamp"
)

finally:
await sandbox2.stop()
await sandbox2.client.aclose()

# Step 5: Retrieve and delete the snapshot
print("\n[6] Retrieving and deleting snapshot...")
fetched = await AsyncSnapshot.get(snapshot_id=snapshot.snapshot_id)
try:
await fetched.delete()
print(f" Deleted snapshot, status: {fetched.status}")
finally:
await fetched.client.aclose()
# Step 5: List snapshots and confirm the new snapshots are discoverable.
print("\n[7] Listing recent snapshots...")
since = snapshot.created_at - 1
pager = AsyncSnapshot.list(limit=10, since=since)
first_page = await pager
first_page_ids = [listed.id for listed in first_page.snapshots]
print(f" First page snapshot IDs: {first_page_ids}")

paged_ids: list[list[str]] = []
found_ids: set[str] = set()
async for page in pager.iter_pages():
page_ids = [listed.id for listed in page.snapshots]
paged_ids.append(page_ids)
found_ids.update(page_ids)
if all(snapshot_id in found_ids for snapshot_id in async_snapshot_ids):
break

print(f" Visited pages: {paged_ids}")
assert all(snapshot_id in found_ids for snapshot_id in async_snapshot_ids), (
"Did not find all async snapshots in AsyncSnapshot.list() results"
)

item_ids: list[str] = []
async for listed in AsyncSnapshot.list(limit=10, since=since).iter_items():
item_ids.append(listed.id)
if all(snapshot_id in item_ids for snapshot_id in async_snapshot_ids):
break

print(f" Iterated snapshot IDs: {item_ids}")
assert all(snapshot_id in item_ids for snapshot_id in async_snapshot_ids), (
"Did not find all async snapshots while iterating AsyncSnapshot.list() items"
)

# Step 6: Retrieve and delete the snapshots
print("\n[8] Retrieving and deleting snapshots...")
for snapshot_id in async_snapshot_ids:
fetched = await AsyncSnapshot.get(snapshot_id=snapshot_id)
try:
await fetched.delete()
print(f" Deleted snapshot {snapshot_id}, status: {fetched.status}")
assert fetched.status == "deleted", "async snapshot should be deleted"
finally:
await fetched.client.aclose()


def sync_demo() -> None:
print("\n" + "=" * 60)
print("SYNC SNAPSHOT EXAMPLE")
print("=" * 60)

sync_snapshot_ids: list[str] = []

# Step 1: Create a sandbox and set it up
print("\n[1] Creating initial sandbox...")
sandbox1 = Sandbox.create(timeout=120_000)
Expand All @@ -115,12 +191,20 @@ def sync_demo() -> None:
config = sandbox1.read_file("config.json")
print(f" Created config.json: {config.decode()}")

# Step 2: Create a snapshot (this STOPS the sandbox)
print("\n[3] Creating snapshot...")
snapshot = sandbox1.snapshot()
# Step 2: Create a snapshot with an explicit expiration (this STOPS the sandbox)
print("\n[3] Creating snapshot with expiration=86400000...")
snapshot = sandbox1.snapshot(expiration=SNAPSHOT_EXPIRATION_MS)
sync_snapshot_ids.append(snapshot.snapshot_id)
print(f" Snapshot ID: {snapshot.snapshot_id}")
print(f" Status: {snapshot.status}")
print(f" Created At: {snapshot.created_at}")
print(f" Expires At: {snapshot.expires_at}")
print(f" Sandbox status after snapshot: {sandbox1.status}")
assert_snapshot_expires_at_expected_time(
created_at=snapshot.created_at,
expires_at=snapshot.expires_at,
label="expiring sync snapshot",
)

finally:
sandbox1.client.close()
Expand All @@ -145,20 +229,61 @@ def sync_demo() -> None:

assert config2 == b'{"version": "1.0", "env": "sync"}', "config.json mismatch!"
assert users2 == b"alice\nbob\ncharlie", "users.txt mismatch!"
print("\n✓ Sync assertions passed!")
print("\n✓ Sync restore assertions passed!")

print("\n[6] Creating second snapshot with expiration=0...")
persistent_snapshot = sandbox2.snapshot(expiration=0)
sync_snapshot_ids.append(persistent_snapshot.snapshot_id)
print(f" Snapshot ID: {persistent_snapshot.snapshot_id}")
print(f" Created At: {persistent_snapshot.created_at}")
print(f" Expires At: {persistent_snapshot.expires_at}")
assert persistent_snapshot.expires_at is None, (
"persistent sync snapshot should not have an expires_at timestamp"
)

finally:
sandbox2.stop()
sandbox2.client.close()

# Step 5: Retrieve and delete the snapshot
print("\n[6] Retrieving and deleting snapshot...")
fetched = Snapshot.get(snapshot_id=snapshot.snapshot_id)
try:
fetched.delete()
print(f" Deleted snapshot, status: {fetched.status}")
finally:
fetched.client.close()
print("\n[7] Listing recent snapshots...")
since = snapshot.created_at - 1
first_page = Snapshot.list(limit=10, since=since)
first_page_ids = [listed.id for listed in first_page.snapshots]
print(f" First page snapshot IDs: {first_page_ids}")

paged_ids: list[list[str]] = []
found_ids: set[str] = set()
for page in first_page.iter_pages():
page_ids = [listed.id for listed in page.snapshots]
paged_ids.append(page_ids)
found_ids.update(page_ids)
if all(snapshot_id in found_ids for snapshot_id in sync_snapshot_ids):
break

print(f" Visited pages: {paged_ids}")
assert all(snapshot_id in found_ids for snapshot_id in sync_snapshot_ids), (
"Did not find all sync snapshots in Snapshot.list() pages"
)

item_ids: list[str] = []
for listed in Snapshot.list(limit=10, since=since).iter_items():
item_ids.append(listed.id)
if all(snapshot_id in item_ids for snapshot_id in sync_snapshot_ids):
break

print(f" Iterated snapshot IDs: {item_ids}")
assert all(snapshot_id in item_ids for snapshot_id in sync_snapshot_ids), (
"Did not find all sync snapshots while iterating Snapshot.list() items"
)

print("\n[8] Retrieving and deleting snapshots...")
for snapshot_id in sync_snapshot_ids:
fetched = Snapshot.get(snapshot_id=snapshot_id)
try:
fetched.delete()
print(f" Deleted snapshot {snapshot_id}, status: {fetched.status}")
assert fetched.status == "deleted", "sync snapshot should be deleted"
finally:
fetched.client.close()


if __name__ == "__main__":
Expand Down
65 changes: 53 additions & 12 deletions src/vercel/_internal/sandbox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import tarfile
from collections.abc import AsyncGenerator, Generator
from importlib.metadata import version as _pkg_version
from typing import Any
from typing import Any, TypeAlias, cast

import httpx

Expand Down Expand Up @@ -44,6 +44,7 @@
SandboxesResponse,
SandboxResponse,
SnapshotResponse,
SnapshotsResponse,
WriteFile,
)
from vercel._internal.sandbox.network_policy import (
Expand All @@ -61,6 +62,10 @@
f"vercel/sandbox/{VERSION} (Python/{sys.version}; {PLATFORM.system}/{PLATFORM.machine})"
)

JSONScalar: TypeAlias = str | int | float | bool | None
JSONValue: TypeAlias = JSONScalar | dict[str, "JSONValue"] | list["JSONValue"]
RequestQuery: TypeAlias = dict[str, str | int | float | bool | None]


# ---------------------------------------------------------------------------
# Request client — error handling + request_json convenience
Expand All @@ -83,11 +88,11 @@ async def request(
path: str,
*,
headers: dict[str, str] | None = None,
query: dict[str, Any] | None = None,
query: RequestQuery | None = None,
body: JSONBody | BytesBody | None = None,
stream: bool = False,
) -> httpx.Response:
params: dict[str, Any] | None = None
params: RequestQuery | None = None
if query:
params = {k: v for k, v in query.items() if v is not None}

Expand All @@ -113,7 +118,7 @@ async def request(
error_body = None

# Parse a helpful error message
parsed: Any | None = None
parsed: JSONValue | None = None
message = f"HTTP {response.status_code}"
if error_body:
try:
Expand Down Expand Up @@ -148,19 +153,30 @@ async def request_json(
self,
method: str,
path: str,
**kwargs: Any,
) -> Any:
headers = kwargs.pop("headers", None) or {}
*,
headers: dict[str, str] | None = None,
query: RequestQuery | None = None,
body: JSONBody | BytesBody | None = None,
stream: bool = False,
) -> JSONValue:
headers = dict(headers or {})
headers.setdefault("content-type", "application/json")
r = await self.request(method, path, headers=headers, **kwargs)
return r.json()
r = await self.request(
method,
path,
headers=headers,
query=query,
body=body,
stream=stream,
)
return cast(JSONValue, r.json())


def _build_sandbox_error(
response: httpx.Response,
message: str,
*,
data: Any | None = None,
data: JSONValue | None = None,
) -> APIError:
status_code = response.status_code
if status_code == 401:
Expand Down Expand Up @@ -410,9 +426,14 @@ async def extend_timeout(self, *, sandbox_id: str, duration: int) -> SandboxResp
)
return SandboxResponse.model_validate(data)

async def create_snapshot(self, *, sandbox_id: str) -> CreateSnapshotResponse:
async def create_snapshot(
self, *, sandbox_id: str, expiration: int | None = None
) -> CreateSnapshotResponse:
body = None if expiration is None else JSONBody({"expiration": expiration})
data = await self._request_client.request_json(
"POST", f"/v1/sandboxes/{sandbox_id}/snapshot"
"POST",
f"/v1/sandboxes/{sandbox_id}/snapshot",
body=body,
)
return CreateSnapshotResponse.model_validate(data)

Expand All @@ -422,6 +443,26 @@ async def get_snapshot(self, *, snapshot_id: str) -> SnapshotResponse:
)
return SnapshotResponse.model_validate(data)

async def list_snapshots(
self,
*,
project_id: str | None = None,
limit: int | None = None,
since: int | None = None,
until: int | None = None,
) -> SnapshotsResponse:
data = await self._request_client.request_json(
"GET",
"/v1/sandboxes/snapshots",
query={
"project": project_id,
"limit": limit,
"since": since,
"until": until,
},
)
return SnapshotsResponse.model_validate(data)

async def delete_snapshot(self, *, snapshot_id: str) -> SnapshotResponse:
data = await self._request_client.request_json(
"DELETE", f"/v1/sandboxes/snapshots/{snapshot_id}"
Expand Down
Loading
Loading