diff --git a/examples/sandbox_11_snapshots.py b/examples/sandbox_11_snapshots.py index 952a070..ef8b616 100644 --- a/examples/sandbox_11_snapshots.py +++ b/examples/sandbox_11_snapshots.py @@ -22,12 +22,33 @@ load_dotenv() +SNAPSHOT_EXPIRATION_MS = 86_400_000 +SNAPSHOT_EXPIRATION_DRIFT_MS = 1_000 + + +def assert_time_near(*, actual: int | None, expected: int, drift: int, label: str) -> None: + assert actual is not None, f"{label} should report an expires_at timestamp" + assert abs(actual - expected) <= drift, f"{label} should be within {drift}ms of {expected}" + + +def assert_snapshot_expires_at_expected_time( + *, created_at: int, expires_at: int | None, label: str +) -> None: + assert_time_near( + actual=expires_at, + expected=created_at + SNAPSHOT_EXPIRATION_MS, + drift=SNAPSHOT_EXPIRATION_DRIFT_MS, + label=label, + ) + async def async_demo() -> None: print("=" * 60) print("ASYNC SNAPSHOT EXAMPLE") print("=" * 60) + async_snapshot_ids: list[str] = [] + # Step 1: Create a sandbox and set it up print("\n[1] Creating initial sandbox...") sandbox1 = await AsyncSandbox.create(timeout=120_000) @@ -46,12 +67,20 @@ async def async_demo() -> None: config = await sandbox1.read_file("config.json") print(f" Created config.json: {config.decode()}") - # Step 2: Create a snapshot (this STOPS the sandbox) - print("\n[3] Creating snapshot...") - snapshot = await sandbox1.snapshot() + # Step 2: Create a snapshot with an explicit expiration (this STOPS the sandbox) + print("\n[3] Creating snapshot with expiration=86400000...") + snapshot = await sandbox1.snapshot(expiration=SNAPSHOT_EXPIRATION_MS) + async_snapshot_ids.append(snapshot.snapshot_id) print(f" Snapshot ID: {snapshot.snapshot_id}") print(f" Status: {snapshot.status}") + print(f" Created At: {snapshot.created_at}") + print(f" Expires At: {snapshot.expires_at}") print(f" Sandbox status after snapshot: {sandbox1.status}") + assert_snapshot_expires_at_expected_time( + created_at=snapshot.created_at, + 
expires_at=snapshot.expires_at, + label="expiring async snapshot", + ) finally: await sandbox1.client.aclose() @@ -76,20 +105,65 @@ async def async_demo() -> None: assert config2 == b'{"version": "1.0", "env": "async"}', "config.json mismatch!" assert users2 == b"alice\nbob\ncharlie", "users.txt mismatch!" - print("\n✓ Async assertions passed!") + print("\n✓ Async restore assertions passed!") + + # Create a second snapshot with expiration=0 to preserve it indefinitely. + print("\n[6] Creating second snapshot with expiration=0...") + persistent_snapshot = await sandbox2.snapshot(expiration=0) + async_snapshot_ids.append(persistent_snapshot.snapshot_id) + print(f" Snapshot ID: {persistent_snapshot.snapshot_id}") + print(f" Created At: {persistent_snapshot.created_at}") + print(f" Expires At: {persistent_snapshot.expires_at}") + assert persistent_snapshot.expires_at is None, ( + "persistent async snapshot should not have an expires_at timestamp" + ) finally: - await sandbox2.stop() await sandbox2.client.aclose() - # Step 5: Retrieve and delete the snapshot - print("\n[6] Retrieving and deleting snapshot...") - fetched = await AsyncSnapshot.get(snapshot_id=snapshot.snapshot_id) - try: - await fetched.delete() - print(f" Deleted snapshot, status: {fetched.status}") - finally: - await fetched.client.aclose() + # Step 5: List snapshots and confirm the new snapshots are discoverable. 
+ print("\n[7] Listing recent snapshots...") + since = snapshot.created_at - 1 + pager = AsyncSnapshot.list(limit=10, since=since) + first_page = await pager + first_page_ids = [listed.id for listed in first_page.snapshots] + print(f" First page snapshot IDs: {first_page_ids}") + + paged_ids: list[list[str]] = [] + found_ids: set[str] = set() + async for page in pager.iter_pages(): + page_ids = [listed.id for listed in page.snapshots] + paged_ids.append(page_ids) + found_ids.update(page_ids) + if all(snapshot_id in found_ids for snapshot_id in async_snapshot_ids): + break + + print(f" Visited pages: {paged_ids}") + assert all(snapshot_id in found_ids for snapshot_id in async_snapshot_ids), ( + "Did not find all async snapshots in AsyncSnapshot.list() results" + ) + + item_ids: list[str] = [] + async for listed in AsyncSnapshot.list(limit=10, since=since).iter_items(): + item_ids.append(listed.id) + if all(snapshot_id in item_ids for snapshot_id in async_snapshot_ids): + break + + print(f" Iterated snapshot IDs: {item_ids}") + assert all(snapshot_id in item_ids for snapshot_id in async_snapshot_ids), ( + "Did not find all async snapshots while iterating AsyncSnapshot.list() items" + ) + + # Step 6: Retrieve and delete the snapshots + print("\n[8] Retrieving and deleting snapshots...") + for snapshot_id in async_snapshot_ids: + fetched = await AsyncSnapshot.get(snapshot_id=snapshot_id) + try: + await fetched.delete() + print(f" Deleted snapshot {snapshot_id}, status: {fetched.status}") + assert fetched.status == "deleted", "async snapshot should be deleted" + finally: + await fetched.client.aclose() def sync_demo() -> None: @@ -97,6 +171,8 @@ def sync_demo() -> None: print("SYNC SNAPSHOT EXAMPLE") print("=" * 60) + sync_snapshot_ids: list[str] = [] + # Step 1: Create a sandbox and set it up print("\n[1] Creating initial sandbox...") sandbox1 = Sandbox.create(timeout=120_000) @@ -115,12 +191,20 @@ def sync_demo() -> None: config = sandbox1.read_file("config.json") 
print(f" Created config.json: {config.decode()}") - # Step 2: Create a snapshot (this STOPS the sandbox) - print("\n[3] Creating snapshot...") - snapshot = sandbox1.snapshot() + # Step 2: Create a snapshot with an explicit expiration (this STOPS the sandbox) + print("\n[3] Creating snapshot with expiration=86400000...") + snapshot = sandbox1.snapshot(expiration=SNAPSHOT_EXPIRATION_MS) + sync_snapshot_ids.append(snapshot.snapshot_id) print(f" Snapshot ID: {snapshot.snapshot_id}") print(f" Status: {snapshot.status}") + print(f" Created At: {snapshot.created_at}") + print(f" Expires At: {snapshot.expires_at}") print(f" Sandbox status after snapshot: {sandbox1.status}") + assert_snapshot_expires_at_expected_time( + created_at=snapshot.created_at, + expires_at=snapshot.expires_at, + label="expiring sync snapshot", + ) finally: sandbox1.client.close() @@ -145,20 +229,61 @@ def sync_demo() -> None: assert config2 == b'{"version": "1.0", "env": "sync"}', "config.json mismatch!" assert users2 == b"alice\nbob\ncharlie", "users.txt mismatch!" 
- print("\n✓ Sync assertions passed!") + print("\n✓ Sync restore assertions passed!") + + print("\n[6] Creating second snapshot with expiration=0...") + persistent_snapshot = sandbox2.snapshot(expiration=0) + sync_snapshot_ids.append(persistent_snapshot.snapshot_id) + print(f" Snapshot ID: {persistent_snapshot.snapshot_id}") + print(f" Created At: {persistent_snapshot.created_at}") + print(f" Expires At: {persistent_snapshot.expires_at}") + assert persistent_snapshot.expires_at is None, ( + "persistent sync snapshot should not have an expires_at timestamp" + ) finally: - sandbox2.stop() sandbox2.client.close() - # Step 5: Retrieve and delete the snapshot - print("\n[6] Retrieving and deleting snapshot...") - fetched = Snapshot.get(snapshot_id=snapshot.snapshot_id) - try: - fetched.delete() - print(f" Deleted snapshot, status: {fetched.status}") - finally: - fetched.client.close() + print("\n[7] Listing recent snapshots...") + since = snapshot.created_at - 1 + first_page = Snapshot.list(limit=10, since=since) + first_page_ids = [listed.id for listed in first_page.snapshots] + print(f" First page snapshot IDs: {first_page_ids}") + + paged_ids: list[list[str]] = [] + found_ids: set[str] = set() + for page in first_page.iter_pages(): + page_ids = [listed.id for listed in page.snapshots] + paged_ids.append(page_ids) + found_ids.update(page_ids) + if all(snapshot_id in found_ids for snapshot_id in sync_snapshot_ids): + break + + print(f" Visited pages: {paged_ids}") + assert all(snapshot_id in found_ids for snapshot_id in sync_snapshot_ids), ( + "Did not find all sync snapshots in Snapshot.list() pages" + ) + + item_ids: list[str] = [] + for listed in Snapshot.list(limit=10, since=since).iter_items(): + item_ids.append(listed.id) + if all(snapshot_id in item_ids for snapshot_id in sync_snapshot_ids): + break + + print(f" Iterated snapshot IDs: {item_ids}") + assert all(snapshot_id in item_ids for snapshot_id in sync_snapshot_ids), ( + "Did not find all sync snapshots 
while iterating Snapshot.list() items" + ) + + print("\n[8] Retrieving and deleting snapshots...") + for snapshot_id in sync_snapshot_ids: + fetched = Snapshot.get(snapshot_id=snapshot_id) + try: + fetched.delete() + print(f" Deleted snapshot {snapshot_id}, status: {fetched.status}") + assert fetched.status == "deleted", "sync snapshot should be deleted" + finally: + fetched.client.close() if __name__ == "__main__": diff --git a/src/vercel/_internal/sandbox/core.py b/src/vercel/_internal/sandbox/core.py index bea5e66..83c8a9e 100644 --- a/src/vercel/_internal/sandbox/core.py +++ b/src/vercel/_internal/sandbox/core.py @@ -16,7 +16,7 @@ import tarfile from collections.abc import AsyncGenerator, Generator from importlib.metadata import version as _pkg_version -from typing import Any +from typing import Any, TypeAlias, cast import httpx @@ -44,6 +44,7 @@ SandboxesResponse, SandboxResponse, SnapshotResponse, + SnapshotsResponse, WriteFile, ) from vercel._internal.sandbox.network_policy import ( @@ -61,6 +62,10 @@ f"vercel/sandbox/{VERSION} (Python/{sys.version}; {PLATFORM.system}/{PLATFORM.machine})" ) +JSONScalar: TypeAlias = str | int | float | bool | None +JSONValue: TypeAlias = JSONScalar | dict[str, "JSONValue"] | list["JSONValue"] +RequestQuery: TypeAlias = dict[str, str | int | float | bool | None] + # --------------------------------------------------------------------------- # Request client — error handling + request_json convenience @@ -83,11 +88,11 @@ async def request( path: str, *, headers: dict[str, str] | None = None, - query: dict[str, Any] | None = None, + query: RequestQuery | None = None, body: JSONBody | BytesBody | None = None, stream: bool = False, ) -> httpx.Response: - params: dict[str, Any] | None = None + params: RequestQuery | None = None if query: params = {k: v for k, v in query.items() if v is not None} @@ -113,7 +118,7 @@ async def request( error_body = None # Parse a helpful error message - parsed: Any | None = None + parsed: JSONValue | 
None = None message = f"HTTP {response.status_code}" if error_body: try: @@ -148,19 +153,30 @@ async def request_json( self, method: str, path: str, - **kwargs: Any, - ) -> Any: - headers = kwargs.pop("headers", None) or {} + *, + headers: dict[str, str] | None = None, + query: RequestQuery | None = None, + body: JSONBody | BytesBody | None = None, + stream: bool = False, + ) -> JSONValue: + headers = dict(headers or {}) headers.setdefault("content-type", "application/json") - r = await self.request(method, path, headers=headers, **kwargs) - return r.json() + r = await self.request( + method, + path, + headers=headers, + query=query, + body=body, + stream=stream, + ) + return cast(JSONValue, r.json()) def _build_sandbox_error( response: httpx.Response, message: str, *, - data: Any | None = None, + data: JSONValue | None = None, ) -> APIError: status_code = response.status_code if status_code == 401: @@ -410,9 +426,14 @@ async def extend_timeout(self, *, sandbox_id: str, duration: int) -> SandboxResp ) return SandboxResponse.model_validate(data) - async def create_snapshot(self, *, sandbox_id: str) -> CreateSnapshotResponse: + async def create_snapshot( + self, *, sandbox_id: str, expiration: int | None = None + ) -> CreateSnapshotResponse: + body = None if expiration is None else JSONBody({"expiration": expiration}) data = await self._request_client.request_json( - "POST", f"/v1/sandboxes/{sandbox_id}/snapshot" + "POST", + f"/v1/sandboxes/{sandbox_id}/snapshot", + body=body, ) return CreateSnapshotResponse.model_validate(data) @@ -422,6 +443,26 @@ async def get_snapshot(self, *, snapshot_id: str) -> SnapshotResponse: ) return SnapshotResponse.model_validate(data) + async def list_snapshots( + self, + *, + project_id: str | None = None, + limit: int | None = None, + since: int | None = None, + until: int | None = None, + ) -> SnapshotsResponse: + data = await self._request_client.request_json( + "GET", + "/v1/sandboxes/snapshots", + query={ + "project": project_id, 
+ "limit": limit, + "since": since, + "until": until, + }, + ) + return SnapshotsResponse.model_validate(data) + async def delete_snapshot(self, *, snapshot_id: str) -> SnapshotResponse: data = await self._request_client.request_json( "DELETE", f"/v1/sandboxes/snapshots/{snapshot_id}" diff --git a/src/vercel/_internal/sandbox/errors.py b/src/vercel/_internal/sandbox/errors.py index 3f4e108..62bb96d 100644 --- a/src/vercel/_internal/sandbox/errors.py +++ b/src/vercel/_internal/sandbox/errors.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Any - import httpx @@ -23,11 +21,11 @@ class SandboxError(Exception): class APIError(SandboxError): - def __init__(self, response: httpx.Response, message: str, *, data: Any | None = None): + def __init__(self, response: httpx.Response, message: str, *, data: object | None = None): super().__init__(message) self.response = response self.status_code = response.status_code - self.data = data + self.data: object | None = data class SandboxAuthError(APIError): @@ -44,7 +42,7 @@ def __init__( response: httpx.Response, message: str, *, - data: Any | None = None, + data: object | None = None, retry_after: str | int | None = None, ) -> None: super().__init__(response, message, data=data) diff --git a/src/vercel/_internal/sandbox/models.py b/src/vercel/_internal/sandbox/models.py index ef3e325..94ff510 100644 --- a/src/vercel/_internal/sandbox/models.py +++ b/src/vercel/_internal/sandbox/models.py @@ -193,11 +193,18 @@ class Snapshot(BaseModel): region: str status: Literal["created", "deleted", "failed"] size_bytes: int = Field(alias="sizeBytes") - expires_at: int = Field(alias="expiresAt") + expires_at: int | None = Field(default=None, alias="expiresAt") created_at: int = Field(alias="createdAt") updated_at: int = Field(alias="updatedAt") +class SnapshotsResponse(BaseModel): + """API response containing a list of snapshots.""" + + snapshots: list[Snapshot] + pagination: Pagination + + class 
SnapshotResponse(BaseModel): """API response containing a snapshot.""" diff --git a/src/vercel/_internal/sandbox/pagination.py b/src/vercel/_internal/sandbox/pagination.py index d9bb0cd..1f48a73 100644 --- a/src/vercel/_internal/sandbox/pagination.py +++ b/src/vercel/_internal/sandbox/pagination.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from datetime import datetime, timezone from vercel._internal.sandbox.models import Pagination @@ -10,13 +11,28 @@ class SandboxPageInfo: until: int -@dataclass(frozen=True, slots=True) -class SandboxListParams: - project_id: str | None = None - limit: int | None = None - since: int | None = None - until: int | None = None +@dataclass(frozen=True, slots=True, init=False) +class _BaseListParams: + project_id: str | None + limit: int | None + since: int | None + until: int | None + + def __init__( + self, + project_id: str | None = None, + limit: int | None = None, + since: datetime | int | None = None, + until: datetime | int | None = None, + ) -> None: + object.__setattr__(self, "project_id", project_id) + object.__setattr__(self, "limit", limit) + object.__setattr__(self, "since", normalize_list_timestamp(since)) + object.__setattr__(self, "until", normalize_list_timestamp(until)) + +@dataclass(frozen=True, slots=True, init=False) +class SandboxListParams(_BaseListParams): def with_until(self, until: int) -> SandboxListParams: return SandboxListParams( project_id=self.project_id, @@ -32,4 +48,46 @@ def next_sandbox_page_info(pagination: Pagination) -> SandboxPageInfo | None: return SandboxPageInfo(until=pagination.next) -__all__ = ["SandboxListParams", "SandboxPageInfo", "next_sandbox_page_info"] +@dataclass(frozen=True, slots=True) +class SnapshotPageInfo: + until: int + + +@dataclass(frozen=True, slots=True, init=False) +class SnapshotListParams(_BaseListParams): + def with_until(self, until: int) -> SnapshotListParams: + return SnapshotListParams( + project_id=self.project_id, + 
limit=self.limit, + since=self.since, + until=until, + ) + + +def next_snapshot_page_info(pagination: Pagination) -> SnapshotPageInfo | None: + if pagination.next is None: + return None + return SnapshotPageInfo(until=pagination.next) + + +def normalize_list_timestamp(value: datetime | int | None) -> int | None: + if value is None: + return None + if isinstance(value, int): + return value + if isinstance(value, datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return int(value.timestamp() * 1000) + raise TypeError("List timestamps must be datetime or integer milliseconds") + + +__all__ = [ + "SandboxListParams", + "SandboxPageInfo", + "SnapshotListParams", + "SnapshotPageInfo", + "next_sandbox_page_info", + "next_snapshot_page_info", + "normalize_list_timestamp", +] diff --git a/src/vercel/sandbox/__init__.py b/src/vercel/sandbox/__init__.py index 366d50f..fe6bf58 100644 --- a/src/vercel/sandbox/__init__.py +++ b/src/vercel/sandbox/__init__.py @@ -16,9 +16,14 @@ from .command import AsyncCommand, AsyncCommandFinished, Command, CommandFinished from .models import GitSource, SnapshotSource, Source, TarballSource -from .page import AsyncSandboxPage, SandboxPage +from .page import AsyncSandboxPage, AsyncSnapshotPage, SandboxPage, SnapshotPage from .sandbox import AsyncSandbox, Sandbox -from .snapshot import AsyncSnapshot, Snapshot +from .snapshot import ( + MIN_SNAPSHOT_EXPIRATION_MS, + AsyncSnapshot, + Snapshot, + SnapshotExpiration, +) __all__ = [ "SandboxError", @@ -29,10 +34,14 @@ "SandboxServerError", "AsyncSandbox", "AsyncSandboxPage", + "AsyncSnapshotPage", "AsyncSnapshot", "Sandbox", "SandboxPage", + "SnapshotPage", "Snapshot", + "SnapshotExpiration", + "MIN_SNAPSHOT_EXPIRATION_MS", "AsyncCommand", "AsyncCommandFinished", "Command", diff --git a/src/vercel/sandbox/models.py b/src/vercel/sandbox/models.py index d9496c8..a539b0c 100644 --- a/src/vercel/sandbox/models.py +++ b/src/vercel/sandbox/models.py @@ -16,6 +16,7 @@ 
Snapshot, SnapshotResponse, SnapshotSource, + SnapshotsResponse, Source, TarballSource, WriteFile, @@ -37,6 +38,7 @@ "SandboxRoute", "SandboxesResponse", "Snapshot", + "SnapshotsResponse", "SnapshotResponse", "SnapshotSource", "Source", diff --git a/src/vercel/sandbox/page.py b/src/vercel/sandbox/page.py index 5dd61cf..37d673e 100644 --- a/src/vercel/sandbox/page.py +++ b/src/vercel/sandbox/page.py @@ -5,8 +5,17 @@ from typing import Any from vercel._internal.pagination import PageController -from vercel._internal.sandbox.models import Pagination, Sandbox as SandboxModel -from vercel._internal.sandbox.pagination import SandboxPageInfo, next_sandbox_page_info +from vercel._internal.sandbox.models import ( + Pagination, + Sandbox as SandboxModel, + Snapshot as SnapshotModel, +) +from vercel._internal.sandbox.pagination import ( + SandboxPageInfo, + SnapshotPageInfo, + next_sandbox_page_info, + next_snapshot_page_info, +) @dataclass(slots=True) @@ -118,9 +127,126 @@ async def iter_items(self) -> AsyncIterator[SandboxModel]: yield item +@dataclass(slots=True) +class SnapshotPage: + snapshots: list[SnapshotModel] + pagination: Pagination + _controller: PageController[SnapshotPage, SnapshotModel, SnapshotPageInfo] = field( + init=False, + repr=False, + ) + + @classmethod + def create( + cls, + *, + snapshots: list[SnapshotModel], + pagination: Pagination, + fetch_next_page: Callable[[SnapshotPageInfo], Awaitable[SnapshotPage]], + ) -> SnapshotPage: + page = cls(snapshots=list(snapshots), pagination=pagination) + page._controller = PageController( + get_items=lambda current_page: current_page.snapshots, + get_next_page_info=lambda current_page: next_snapshot_page_info( + current_page.pagination + ), + fetch_next_page=fetch_next_page, + ) + return page + + def has_next_page(self) -> bool: + return self._controller.has_next_page(self) + + def next_page_info(self) -> SnapshotPageInfo | None: + return self._controller.next_page_info(self) + + def get_next_page(self) -> 
SnapshotPage | None: + return self._controller.get_next_page_sync(self) + + def iter_pages(self) -> Iterator[SnapshotPage]: + return self._controller.iter_pages_sync(self) + + def iter_items(self) -> Iterator[SnapshotModel]: + return self._controller.iter_items_sync(self) + + +@dataclass(slots=True) +class AsyncSnapshotPage: + snapshots: list[SnapshotModel] + pagination: Pagination + _controller: PageController[AsyncSnapshotPage, SnapshotModel, SnapshotPageInfo] = field( + init=False, + repr=False, + ) + + @classmethod + def create( + cls, + *, + snapshots: list[SnapshotModel], + pagination: Pagination, + fetch_next_page: Callable[[SnapshotPageInfo], Awaitable[AsyncSnapshotPage]], + ) -> AsyncSnapshotPage: + page = cls(snapshots=list(snapshots), pagination=pagination) + page._controller = PageController( + get_items=lambda current_page: current_page.snapshots, + get_next_page_info=lambda current_page: next_snapshot_page_info( + current_page.pagination + ), + fetch_next_page=fetch_next_page, + ) + return page + + def has_next_page(self) -> bool: + return self._controller.has_next_page(self) + + def next_page_info(self) -> SnapshotPageInfo | None: + return self._controller.next_page_info(self) + + async def get_next_page(self) -> AsyncSnapshotPage | None: + return await self._controller.get_next_page(self) + + def iter_pages(self) -> AsyncIterator[AsyncSnapshotPage]: + return self._controller.iter_pages(self) + + def iter_items(self) -> AsyncIterator[SnapshotModel]: + return self._controller.iter_items(self) + + +@dataclass(slots=True) +class AsyncSnapshotPager: + _fetch_first_page: Callable[[], Awaitable[AsyncSnapshotPage]] + _first_page: AsyncSnapshotPage | None = field(init=False, default=None, repr=False) + + async def _get_first_page(self) -> AsyncSnapshotPage: + if self._first_page is None: + self._first_page = await self._fetch_first_page() + return self._first_page + + def __await__(self) -> Generator[Any, None, AsyncSnapshotPage]: + return 
self._get_first_page().__await__() + + def __aiter__(self) -> AsyncIterator[SnapshotModel]: + return self.iter_items() + + async def iter_pages(self) -> AsyncIterator[AsyncSnapshotPage]: + first_page = await self._get_first_page() + async for page in first_page.iter_pages(): + yield page + + async def iter_items(self) -> AsyncIterator[SnapshotModel]: + first_page = await self._get_first_page() + async for item in first_page.iter_items(): + yield item + + __all__ = [ + "AsyncSnapshotPager", + "AsyncSnapshotPage", "AsyncSandboxPager", "AsyncSandboxPage", "SandboxPage", + "SnapshotPage", "SandboxPageInfo", + "SnapshotPageInfo", ] diff --git a/src/vercel/sandbox/sandbox.py b/src/vercel/sandbox/sandbox.py index 9fb5dea..3285662 100644 --- a/src/vercel/sandbox/sandbox.py +++ b/src/vercel/sandbox/sandbox.py @@ -3,7 +3,7 @@ import builtins import time from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime from typing import Any from vercel._internal.iter_coroutine import iter_coroutine @@ -30,7 +30,12 @@ ) from .page import AsyncSandboxPage, AsyncSandboxPager, SandboxPage from .pty.shell import start_interactive_shell -from .snapshot import AsyncSnapshot, Snapshot as SnapshotClass +from .snapshot import ( + AsyncSnapshot, + Snapshot as SnapshotClass, + SnapshotExpiration, + normalize_snapshot_expiration, +) def _normalize_source(source: Source | None) -> dict[str, Any] | None: @@ -104,18 +109,6 @@ async def fetch_next_page(page_info) -> SandboxPage: ) -def _normalize_list_timestamp(value: datetime | int | None) -> int | None: - if value is None: - return None - if isinstance(value, int): - return value - if isinstance(value, datetime): - if value.tzinfo is None: - value = value.replace(tzinfo=timezone.utc) - return int(value.timestamp() * 1000) - raise TypeError("Sandbox list timestamps must be datetime or integer milliseconds") - - @dataclass class AsyncSandbox: client: AsyncSandboxOpsClient @@ -256,8 +249,8 @@ def list( 
params = SandboxListParams( project_id=creds.project_id, limit=limit, - since=_normalize_list_timestamp(since), - until=_normalize_list_timestamp(until), + since=since, + until=until, ) return AsyncSandboxPager( _fetch_first_page=lambda: _build_async_sandbox_page(creds=creds, params=params) @@ -416,14 +409,20 @@ async def extend_timeout(self, duration: int) -> None: response = await self.client.extend_timeout(sandbox_id=self.sandbox.id, duration=duration) self.sandbox = response.sandbox - async def snapshot(self) -> AsyncSnapshot: + async def snapshot( + self, *, expiration: int | SnapshotExpiration | None = None + ) -> AsyncSnapshot: """ Create a snapshot from this currently running sandbox. New sandboxes can then be created from this snapshot. Note: this sandbox will be stopped as part of the snapshot creation process. """ - response = await self.client.create_snapshot(sandbox_id=self.sandbox.id) + normalized_expiration = normalize_snapshot_expiration(expiration) + response = await self.client.create_snapshot( + sandbox_id=self.sandbox.id, + expiration=normalized_expiration, + ) self.sandbox = response.sandbox return AsyncSnapshot(client=self.client, snapshot=response.snapshot) @@ -614,8 +613,8 @@ def list( params = SandboxListParams( project_id=creds.project_id, limit=limit, - since=_normalize_list_timestamp(since), - until=_normalize_list_timestamp(until), + since=since, + until=until, ) return iter_coroutine(_build_sync_sandbox_page(creds=creds, params=params)) @@ -776,14 +775,20 @@ def extend_timeout(self, duration: int) -> None: ) self.sandbox = response.sandbox - def snapshot(self) -> SnapshotClass: + def snapshot(self, *, expiration: int | SnapshotExpiration | None = None) -> SnapshotClass: """ Create a snapshot from this currently running sandbox. New sandboxes can then be created from this snapshot. Note: this sandbox will be stopped as part of the snapshot creation process. 
""" - response = iter_coroutine(self.client.create_snapshot(sandbox_id=self.sandbox.id)) + normalized_expiration = normalize_snapshot_expiration(expiration) + response = iter_coroutine( + self.client.create_snapshot( + sandbox_id=self.sandbox.id, + expiration=normalized_expiration, + ) + ) self.sandbox = response.sandbox return SnapshotClass(client=self.client, snapshot=response.snapshot) diff --git a/src/vercel/sandbox/snapshot.py b/src/vercel/sandbox/snapshot.py index bdc1350..a36125e 100644 --- a/src/vercel/sandbox/snapshot.py +++ b/src/vercel/sandbox/snapshot.py @@ -1,13 +1,102 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal +from datetime import datetime +from typing import Final, Literal from vercel._internal.iter_coroutine import iter_coroutine from vercel._internal.sandbox import AsyncSandboxOpsClient, SyncSandboxOpsClient from vercel._internal.sandbox.models import Snapshot as SnapshotModel +from vercel._internal.sandbox.pagination import SnapshotListParams from ..oidc import Credentials, get_credentials +from .page import AsyncSnapshotPage, AsyncSnapshotPager, SnapshotPage + +MIN_SNAPSHOT_EXPIRATION_MS: Final[int] = 86_400_000 + + +class SnapshotExpiration(int): + """Snapshot expiration in milliseconds. + + Valid values are ``0`` for no expiration or any value greater than or equal + to ``86_400_000`` (24 hours). 
+ """ + + def __new__(cls, value: int) -> SnapshotExpiration: + value = int(value) + if value != 0 and value < MIN_SNAPSHOT_EXPIRATION_MS: + raise ValueError( + "Snapshot expiration must be 0 for no expiration or >= 86400000 milliseconds" + ) + return int.__new__(cls, value) + + +def normalize_snapshot_expiration( + expiration: int | SnapshotExpiration | None, +) -> SnapshotExpiration | None: + if expiration is None: + return None + if isinstance(expiration, SnapshotExpiration): + return expiration + return SnapshotExpiration(expiration) + + +async def _build_async_snapshot_page( + *, + creds: Credentials, + params: SnapshotListParams, +) -> AsyncSnapshotPage: + client = AsyncSandboxOpsClient(team_id=creds.team_id, token=creds.token) + try: + response = await client.list_snapshots( + project_id=params.project_id, + limit=params.limit, + since=params.since, + until=params.until, + ) + finally: + await client.aclose() + + async def fetch_next_page(page_info) -> AsyncSnapshotPage: + return await _build_async_snapshot_page( + creds=creds, + params=params.with_until(page_info.until), + ) + + return AsyncSnapshotPage.create( + snapshots=response.snapshots, + pagination=response.pagination, + fetch_next_page=fetch_next_page, + ) + + +async def _build_sync_snapshot_page( + *, + creds: Credentials, + params: SnapshotListParams, +) -> SnapshotPage: + client = SyncSandboxOpsClient(team_id=creds.team_id, token=creds.token) + try: + response = await client.list_snapshots( + project_id=params.project_id, + limit=params.limit, + since=params.since, + until=params.until, + ) + finally: + client.close() + + async def fetch_next_page(page_info) -> SnapshotPage: + return await _build_sync_snapshot_page( + creds=creds, + params=params.with_until(page_info.until), + ) + + return SnapshotPage.create( + snapshots=response.snapshots, + pagination=response.pagination, + fetch_next_page=fetch_next_page, + ) @dataclass @@ -38,8 +127,13 @@ def size_bytes(self) -> int: return 
self.snapshot.size_bytes @property - def expires_at(self) -> int: - """Timestamp when the snapshot expires.""" + def created_at(self) -> int: + """Timestamp when the snapshot was created.""" + return self.snapshot.created_at + + @property + def expires_at(self) -> int | None: + """Timestamp when the snapshot expires, or None for no expiration.""" return self.snapshot.expires_at @staticmethod @@ -56,6 +150,28 @@ async def get( resp = await client.get_snapshot(snapshot_id=snapshot_id) return AsyncSnapshot(client=client, snapshot=resp.snapshot) + @staticmethod + def list( + *, + limit: int | None = None, + since: datetime | int | None = None, + until: datetime | int | None = None, + token: str | None = None, + project_id: str | None = None, + team_id: str | None = None, + ) -> AsyncSnapshotPager: + """List snapshots and return the first page.""" + creds: Credentials = get_credentials(token=token, project_id=project_id, team_id=team_id) + params = SnapshotListParams( + project_id=creds.project_id, + limit=limit, + since=since, + until=until, + ) + return AsyncSnapshotPager( + _fetch_first_page=lambda: _build_async_snapshot_page(creds=creds, params=params) + ) + async def delete(self) -> None: """Delete this snapshot.""" resp = await self.client.delete_snapshot(snapshot_id=self.snapshot.id) @@ -90,8 +206,13 @@ def size_bytes(self) -> int: return self.snapshot.size_bytes @property - def expires_at(self) -> int: - """Timestamp when the snapshot expires.""" + def created_at(self) -> int: + """Timestamp when the snapshot was created.""" + return self.snapshot.created_at + + @property + def expires_at(self) -> int | None: + """Timestamp when the snapshot expires, or None for no expiration.""" return self.snapshot.expires_at @staticmethod @@ -108,6 +229,26 @@ def get( resp = iter_coroutine(client.get_snapshot(snapshot_id=snapshot_id)) return Snapshot(client=client, snapshot=resp.snapshot) + @staticmethod + def list( + *, + limit: int | None = None, + since: datetime | int | 
None = None, + until: datetime | int | None = None, + token: str | None = None, + project_id: str | None = None, + team_id: str | None = None, + ) -> SnapshotPage: + """List snapshots and return the first page.""" + creds: Credentials = get_credentials(token=token, project_id=project_id, team_id=team_id) + params = SnapshotListParams( + project_id=creds.project_id, + limit=limit, + since=since, + until=until, + ) + return iter_coroutine(_build_sync_snapshot_page(creds=creds, params=params)) + def delete(self) -> None: """Delete this snapshot.""" resp = iter_coroutine(self.client.delete_snapshot(snapshot_id=self.snapshot.id)) diff --git a/tests/integration/test_sandbox_sync_async.py b/tests/integration/test_sandbox_sync_async.py index 3f2e146..494c24c 100644 --- a/tests/integration/test_sandbox_sync_async.py +++ b/tests/integration/test_sandbox_sync_async.py @@ -38,6 +38,19 @@ def _sandbox_with_id( return sandbox +def _snapshot_with_id( + base_snapshot: dict, + snapshot_id: str, + *, + created_at: int, +) -> dict: + snapshot = dict(base_snapshot) + snapshot["id"] = snapshot_id + snapshot["createdAt"] = created_at + snapshot["updatedAt"] = created_at + return snapshot + + async def _collect_async_pages(page) -> list: return [current_page async for current_page in page.iter_pages()] @@ -1070,6 +1083,197 @@ def handler(request: httpx.Request) -> httpx.Response: assert [sandbox.id for sandbox in items] == ["sbx_async_terminal"] assert requests == [{"teamId": "team_test123", "project": "prj_test123"}] + +class TestSnapshotList: + """Test snapshot listing operations.""" + + @respx.mock + def test_list_snapshot_sync_serializes_filters_and_iterates_pages( + self, mock_env_clear, mock_sandbox_snapshot_response + ): + from vercel.sandbox import Snapshot + + project = "snapshot-project" + limit = 2 + since = datetime(2024, 1, 15, 12, 0, tzinfo=timezone.utc) + until = datetime(2024, 1, 15, 12, 30, tzinfo=timezone.utc) + expected_since = str(int(since.timestamp() * 1000)) + 
expected_until = str(int(until.timestamp() * 1000)) + next_until = "1705320000000" + + first_page = { + "snapshots": [ + _snapshot_with_id( + mock_sandbox_snapshot_response, + "snap_list_1", + created_at=1705320600000, + ), + _snapshot_with_id( + mock_sandbox_snapshot_response, + "snap_list_2", + created_at=1705320300000, + ), + ], + "pagination": { + "count": 3, + "next": int(next_until), + "prev": None, + }, + } + second_page = { + "snapshots": [ + _snapshot_with_id( + mock_sandbox_snapshot_response, + "snap_list_3", + created_at=1705320000000, + ), + ], + "pagination": { + "count": 3, + "next": None, + "prev": 1705320600000, + }, + } + requests: list[dict[str, str]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + params = dict(request.url.params) + requests.append(params) + if params.get("until") == next_until: + return httpx.Response(200, json=second_page) + return httpx.Response(200, json=first_page) + + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/snapshots").mock(side_effect=handler) + + page = Snapshot.list( + token="test_token", + team_id="team_test123", + project_id=project, + limit=limit, + since=since, + until=until, + ) + + assert requests == [ + { + "teamId": "team_test123", + "project": project, + "limit": str(limit), + "since": expected_since, + "until": expected_until, + } + ] + assert [ + (snapshot.id, snapshot.created_at, snapshot.expires_at) for snapshot in page.snapshots + ] == [ + ( + "snap_list_1", + 1705320600000, + mock_sandbox_snapshot_response["expiresAt"], + ), + ( + "snap_list_2", + 1705320300000, + mock_sandbox_snapshot_response["expiresAt"], + ), + ] + assert page.pagination.count == 3 + assert page.next_page_info() is not None + assert page.next_page_info().until == int(next_until) + + @respx.mock + @pytest.mark.asyncio + async def test_list_snapshot_async_serializes_integer_filters_and_iterates_pages( + self, mock_env_clear, mock_sandbox_snapshot_response + ): + from vercel.sandbox import AsyncSnapshot + + project = 
"snapshot-project" + limit = 2 + since = 1705321200000 + until = 1705323000000 + next_until = "1705319400000" + + first_page = { + "snapshots": [ + _snapshot_with_id( + mock_sandbox_snapshot_response, + "snap_async_1", + created_at=1705320600000, + ), + _snapshot_with_id( + mock_sandbox_snapshot_response, + "snap_async_2", + created_at=1705320000000, + ), + ], + "pagination": { + "count": 3, + "next": int(next_until), + "prev": None, + }, + } + second_page = { + "snapshots": [ + _snapshot_with_id( + mock_sandbox_snapshot_response, + "snap_async_3", + created_at=1705319400000, + ), + ], + "pagination": { + "count": 3, + "next": None, + "prev": 1705320600000, + }, + } + requests: list[dict[str, str]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + params = dict(request.url.params) + requests.append(params) + if params.get("until") == next_until: + return httpx.Response(200, json=second_page) + return httpx.Response(200, json=first_page) + + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/snapshots").mock(side_effect=handler) + + page = await AsyncSnapshot.list( + token="test_token", + team_id="team_test123", + project_id=project, + limit=limit, + since=since, + until=until, + ) + + assert requests == [ + { + "teamId": "team_test123", + "project": project, + "limit": str(limit), + "since": str(since), + "until": str(until), + } + ] + assert [ + (snapshot.id, snapshot.created_at, snapshot.expires_at) for snapshot in page.snapshots + ] == [ + ( + "snap_async_1", + 1705320600000, + mock_sandbox_snapshot_response["expiresAt"], + ), + ( + "snap_async_2", + 1705320000000, + mock_sandbox_snapshot_response["expiresAt"], + ), + ] + assert page.pagination.count == 3 + assert page.next_page_info() is not None + assert page.next_page_info().until == int(next_until) + @respx.mock def test_get_sandbox_sync_exposes_mode_network_policy( self, mock_env_clear, mock_sandbox_get_response_with_mode_network_policy @@ -2335,10 +2539,10 @@ class TestSandboxSnapshot: """Test 
sandbox snapshot operations.""" @respx.mock - def test_create_snapshot_sync( + def test_create_snapshot_sync_without_expiration( self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_snapshot_response ): - """Test synchronous snapshot creation.""" + """Test sync snapshot creation omits the body when expiration is unset.""" from vercel.sandbox import Sandbox sandbox_id = "sbx_test123456" @@ -2377,12 +2581,260 @@ def test_create_snapshot_sync( snapshot = sandbox.snapshot() assert route.called + assert route.calls.last.request.content == b"" assert snapshot.snapshot_id == mock_sandbox_snapshot_response["id"] + assert snapshot.created_at == mock_sandbox_snapshot_response["createdAt"] + assert snapshot.expires_at == mock_sandbox_snapshot_response["expiresAt"] # Sandbox should be stopped after snapshot assert sandbox.status == "stopped" sandbox.client.close() + @respx.mock + def test_create_snapshot_sync_with_expiration( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_snapshot_response + ): + """Test sync snapshot creation forwards a non-zero expiration.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + stopped_response = dict(mock_sandbox_get_response) + stopped_response["status"] = "stopped" + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/snapshot").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": stopped_response, + "snapshot": mock_sandbox_snapshot_response, + }, + ) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + snapshot = sandbox.snapshot(expiration=86_400_000) + + assert route.called + body = json.loads(route.calls.last.request.content) + assert body == {"expiration": 86_400_000} + assert 
snapshot.created_at == mock_sandbox_snapshot_response["createdAt"] + assert sandbox.status == "stopped" + + sandbox.client.close() + + @respx.mock + def test_create_snapshot_sync_with_zero_expiration( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_snapshot_response + ): + """Test sync snapshot creation preserves explicit zero expiration.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + stopped_response = dict(mock_sandbox_get_response) + stopped_response["status"] = "stopped" + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/snapshot").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": stopped_response, + "snapshot": mock_sandbox_snapshot_response, + }, + ) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + sandbox.snapshot(expiration=0) + + assert route.called + body = json.loads(route.calls.last.request.content) + assert body == {"expiration": 0} + assert sandbox.status == "stopped" + + sandbox.client.close() + + @respx.mock + @pytest.mark.asyncio + async def test_create_snapshot_async_without_expiration( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_snapshot_response + ): + """Test async snapshot creation omits the body when expiration is unset.""" + from vercel.sandbox import AsyncSandbox + + sandbox_id = "sbx_test123456" + + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + stopped_response = dict(mock_sandbox_get_response) + stopped_response["status"] = "stopped" + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/snapshot").mock( + return_value=httpx.Response( + 200, 
json={ + "sandbox": stopped_response, + "snapshot": mock_sandbox_snapshot_response, + }, + ) + ) + + sandbox = await AsyncSandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + snapshot = await sandbox.snapshot() + + assert route.called + assert route.calls.last.request.content == b"" + assert snapshot.snapshot_id == mock_sandbox_snapshot_response["id"] + assert snapshot.created_at == mock_sandbox_snapshot_response["createdAt"] + assert snapshot.expires_at == mock_sandbox_snapshot_response["expiresAt"] + assert sandbox.status == "stopped" + + await sandbox.client.aclose() + + @respx.mock + @pytest.mark.asyncio + async def test_create_snapshot_async_with_expiration( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_snapshot_response + ): + """Test async snapshot creation forwards a non-zero expiration.""" + from vercel.sandbox import AsyncSandbox + + sandbox_id = "sbx_test123456" + + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + stopped_response = dict(mock_sandbox_get_response) + stopped_response["status"] = "stopped" + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/snapshot").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": stopped_response, + "snapshot": mock_sandbox_snapshot_response, + }, + ) + ) + + sandbox = await AsyncSandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + snapshot = await sandbox.snapshot(expiration=86_400_000) + + assert route.called + body = json.loads(route.calls.last.request.content) + assert body == {"expiration": 86_400_000} + assert snapshot.created_at == mock_sandbox_snapshot_response["createdAt"] + assert sandbox.status == "stopped" + + await sandbox.client.aclose() + + @respx.mock + @pytest.mark.asyncio + async def 
test_create_snapshot_async_with_zero_expiration_and_optional_expires_at( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_snapshot_response + ): + """Test async snapshot creation keeps zero expiration and optional expiry metadata.""" + from vercel.sandbox import AsyncSandbox + + sandbox_id = "sbx_test123456" + snapshot_response = dict(mock_sandbox_snapshot_response) + snapshot_response.pop("expiresAt") + + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + stopped_response = dict(mock_sandbox_get_response) + stopped_response["status"] = "stopped" + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/snapshot").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": stopped_response, + "snapshot": snapshot_response, + }, + ) + ) + + sandbox = await AsyncSandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + snapshot = await sandbox.snapshot(expiration=0) + + assert route.called + body = json.loads(route.calls.last.request.content) + assert body == {"expiration": 0} + assert snapshot.created_at == snapshot_response["createdAt"] + assert snapshot.expires_at is None + assert sandbox.status == "stopped" + + await sandbox.client.aclose() + class TestSandboxExtendTimeout: """Test sandbox timeout extension.""" diff --git a/tests/test_examples.py b/tests/test_examples.py index 4f3549e..2b6aecc 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os import subprocess import sys diff --git a/tests/unit/test_sandbox_page.py b/tests/unit/test_sandbox_page.py new file mode 100644 index 0000000..cc159c8 --- /dev/null +++ b/tests/unit/test_sandbox_page.py @@ -0,0 +1,135 @@ +"""Tests for sandbox page helpers.""" + +from __future__ import annotations + +import pytest + +from 
vercel._internal.sandbox.models import Pagination, Snapshot as SnapshotModel +from vercel.sandbox.page import AsyncSnapshotPage, SnapshotPage + + +def _snapshot_model(snapshot_id: str, *, created_at: int) -> SnapshotModel: + return SnapshotModel.model_validate( + { + "id": snapshot_id, + "sourceSandboxId": "sbx_test123456", + "region": "iad1", + "status": "created", + "sizeBytes": 1024, + "expiresAt": created_at + 86_400_000, + "createdAt": created_at, + "updatedAt": created_at, + } + ) + + +async def _collect_async_pages(page: AsyncSnapshotPage) -> list[AsyncSnapshotPage]: + return [current_page async for current_page in page.iter_pages()] + + +async def _collect_async_items(page: AsyncSnapshotPage) -> list: + return [snapshot async for snapshot in page.iter_items()] + + +class TestSnapshotPage: + def test_iterates_pages_and_items(self) -> None: + second_page = SnapshotPage.create( + snapshots=[_snapshot_model("snap_2", created_at=1705320000000)], + pagination=Pagination(count=2, next=None, prev=1705320600000), + fetch_next_page=self._fetch_unreachable_page, + ) + + async def fetch_next_page(_page_info): + return second_page + + first_page = SnapshotPage.create( + snapshots=[_snapshot_model("snap_1", created_at=1705320600000)], + pagination=Pagination(count=2, next=1705320000000, prev=None), + fetch_next_page=fetch_next_page, + ) + + assert first_page.has_next_page() is True + next_page_info = first_page.next_page_info() + assert next_page_info is not None + assert next_page_info.until == 1705320000000 + assert [ + [snapshot.id for snapshot in page.snapshots] for page in first_page.iter_pages() + ] == [ + ["snap_1"], + ["snap_2"], + ] + assert [snapshot.id for snapshot in first_page.iter_items()] == ["snap_1", "snap_2"] + + def test_terminal_page_does_not_fetch_more(self) -> None: + page = SnapshotPage.create( + snapshots=[_snapshot_model("snap_terminal", created_at=1705320600000)], + pagination=Pagination(count=1, next=None, prev=None), + 
fetch_next_page=self._fetch_unreachable_page, + ) + + assert page.has_next_page() is False + assert page.next_page_info() is None + assert page.get_next_page() is None + assert [ + [snapshot.id for snapshot in current.snapshots] for current in page.iter_pages() + ] == [ + ["snap_terminal"], + ] + assert [snapshot.id for snapshot in page.iter_items()] == ["snap_terminal"] + + @staticmethod + async def _fetch_unreachable_page(_page_info) -> SnapshotPage: + raise AssertionError("fetch_next_page should not be called") + + +class TestAsyncSnapshotPage: + @pytest.mark.asyncio + async def test_iterates_pages_and_items(self) -> None: + second_page = AsyncSnapshotPage.create( + snapshots=[_snapshot_model("snap_async_2", created_at=1705320000000)], + pagination=Pagination(count=2, next=None, prev=1705320600000), + fetch_next_page=self._fetch_unreachable_page, + ) + + async def fetch_next_page(_page_info): + return second_page + + first_page = AsyncSnapshotPage.create( + snapshots=[_snapshot_model("snap_async_1", created_at=1705320600000)], + pagination=Pagination(count=2, next=1705320000000, prev=None), + fetch_next_page=fetch_next_page, + ) + + assert first_page.has_next_page() is True + next_page_info = first_page.next_page_info() + assert next_page_info is not None + assert next_page_info.until == 1705320000000 + pages = await _collect_async_pages(first_page) + assert [[snapshot.id for snapshot in page.snapshots] for page in pages] == [ + ["snap_async_1"], + ["snap_async_2"], + ] + items = await _collect_async_items(first_page) + assert [snapshot.id for snapshot in items] == ["snap_async_1", "snap_async_2"] + + @pytest.mark.asyncio + async def test_terminal_page_does_not_fetch_more(self) -> None: + page = AsyncSnapshotPage.create( + snapshots=[_snapshot_model("snap_async_terminal", created_at=1705320600000)], + pagination=Pagination(count=1, next=None, prev=None), + fetch_next_page=self._fetch_unreachable_page, + ) + + assert page.has_next_page() is False + assert 
page.next_page_info() is None + assert await page.get_next_page() is None + pages = await _collect_async_pages(page) + assert [[snapshot.id for snapshot in current.snapshots] for current in pages] == [ + ["snap_async_terminal"], + ] + items = await _collect_async_items(page) + assert [snapshot.id for snapshot in items] == ["snap_async_terminal"] + + @staticmethod + async def _fetch_unreachable_page(_page_info) -> AsyncSnapshotPage: + raise AssertionError("fetch_next_page should not be called") diff --git a/tests/unit/test_sandbox_snapshot.py b/tests/unit/test_sandbox_snapshot.py new file mode 100644 index 0000000..3854646 --- /dev/null +++ b/tests/unit/test_sandbox_snapshot.py @@ -0,0 +1,30 @@ +"""Tests for snapshot helpers.""" + +from __future__ import annotations + +import pytest + +from vercel.sandbox import MIN_SNAPSHOT_EXPIRATION_MS, SnapshotExpiration +from vercel.sandbox.snapshot import normalize_snapshot_expiration + + +class TestSnapshotExpiration: + def test_allows_zero(self) -> None: + expiration = SnapshotExpiration(0) + + assert expiration == 0 + + def test_allows_minimum(self) -> None: + expiration = SnapshotExpiration(MIN_SNAPSHOT_EXPIRATION_MS) + + assert expiration == MIN_SNAPSHOT_EXPIRATION_MS + + def test_rejects_values_below_minimum(self) -> None: + with pytest.raises(ValueError, match="0 for no expiration or >= 86400000"): + SnapshotExpiration(MIN_SNAPSHOT_EXPIRATION_MS - 1) + + def test_normalize_coerces_int(self) -> None: + expiration = normalize_snapshot_expiration(MIN_SNAPSHOT_EXPIRATION_MS) + + assert expiration == SnapshotExpiration(MIN_SNAPSHOT_EXPIRATION_MS) + assert isinstance(expiration, SnapshotExpiration)