diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..1cbe7a7 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,100 @@ +# Changelog + +All notable changes to this project are documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +A round of platform improvements to `dexpace-sdk-core`: new optional building +blocks (typed serialization, webhook verification, pagination, two pipeline +policies), tightened retry and tracing behaviour, and a batch of correctness +fixes across bodies, SSE parsing, Digest auth, and error reporting. Everything +lands in `core`; the transport packages are unchanged. No public symbol was +removed, so existing code continues to work without modification. + +### Added + +- **Tristate values** (`serde.tristate`). A three-way type distinguishing + "set to a value", "explicitly set to null", and "absent", so partial updates + (PATCH-style payloads) can round-trip an explicit `null` without conflating it + with an omitted field. +- **Typed model codec** (`serde.codec`). A small encode/decode layer over the + existing `Serde` protocol for converting between typed models and wire bytes, + built on the standard library only. This is the largest new surface and is + worth a careful read before depending on it. +- **Webhook signature verification** (`http.webhooks`). Helpers to verify the + authenticity of inbound webhook payloads using constant-time comparison. +- **Pagination** (`pagination`). A paginator abstraction with pluggable + next-page strategies, a `Link` header parser, and a page model, so list + endpoints can be iterated without each caller re-implementing cursor handling. +- **Idempotency-key policy** (`pipeline.policies.idempotency`, plus its async + twin). 
Stamps a generated idempotency key onto retriable, non-idempotent + requests so safe automatic retries don't double-apply a side effect. +- **Client-identity policy** (`pipeline.policies.client_identity`, plus its + async twin). Sets a consistent `User-Agent` / client-identity header derived + from the configured application id and SDK version. +- **HTTP tracer** (`instrumentation.http_tracer`). An adapter-style tracer base + whose per-event methods default to no-ops, so a subclass overrides only the + events it cares about. Wired through the tracing policy for span emission. +- **Log correlation** (`instrumentation.correlation`). A `contextvar`-backed + correlation id that flows through the pipeline and is attached to log records, + so logs from a single logical request can be tied together. + +### Changed + +- **Retry tuning** (`pipeline.policies.retry` / `async_retry`). More + configurable backoff and clearer rules for which responses and exceptions are + retried, including respecting `Retry-After`. The async retry path now observes + cancellation cleanly between attempts. +- **Tracing and redirect policies** now emit tracer events and carry correlation + through redirects, with credentials stripped on cross-origin redirects. +- **Default pipelines** (`pipeline.defaults`). The standard sync/async stacks now + assemble the new idempotency and client-identity policies alongside the + existing retry, redirect, logging, and tracing policies. +- **Loggable bodies** (`http.request.loggable_request_body`, + `http.response.loggable_response_body`). Capture is bounded and repeatable + reads behave correctly; the byte cap is honoured on the tap without truncating + the primary write path. +- **Error reporting** (`errors.http`). HTTP errors now expose whether they are + `retryable` and carry a bounded body snapshot for diagnostics, with the + snapshot capped so an error never holds an unbounded payload. 
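The `retryable` rule described under **Error reporting** can be sketched as below. This is a simplified stand-in, not the SDK's `HttpResponseError`: the names `SketchHttpError` and `should_retry`, and the plain-`int` status, are assumptions made for illustration. The status set mirrors `_DEFAULT_RETRYABLE_STATUS` from `errors/http.py` in this change.

```python
from __future__ import annotations

# Mirrors the default set added in errors/http.py in this change.
DEFAULT_RETRYABLE_STATUS = frozenset({408, 429, 500, 502, 503, 504})


class SketchHttpError(Exception):
    """Hypothetical stand-in for HttpResponseError (illustration only)."""

    def __init__(self, status: int | None, *, retryable: bool | None = None) -> None:
        super().__init__(f"Operation returned a non-success status: {status}")
        self.status = status
        # An explicit override wins; otherwise derive the flag from the status.
        if retryable is None:
            self.retryable = status is not None and status in DEFAULT_RETRYABLE_STATUS
        else:
            self.retryable = bool(retryable)


def should_retry(error: SketchHttpError) -> bool:
    """A retry loop can read the flag instead of re-deriving it from the status."""
    return error.retryable
```

Keeping the derivation on the error means a retry loop and any logging code consult the same flag, and a caller who knows better (for example, that a particular 404 is transient) can force it with `retryable=True`.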
+ +### Fixed + +- **SSE parsing** (`http.sse.parser`) now strips a leading UTF-8 byte-order mark + and cleans up the async stream deterministically on cancellation or exit. +- **Digest auth** (`http.auth.digest`) honours the server-advertised charset + when computing the digest: credentials are hashed as UTF-8 when the challenge + advertises `charset=UTF-8` and as ISO-8859-1 otherwise, fixing authentication + with non-ASCII usernames and passwords. +- **MediaType** (`http.common.media_type`) handles parameter parsing edge cases + (quoting, casing, and whitespace) more robustly. +- **Async response cancellation** (`http.response.async_response`, + `async_response_body`). Cancelling an in-flight read now releases the + underlying resources instead of leaking them, and re-raises `CancelledError` + after cleanup. + +### Verified + +- `mypy --strict`, `ruff check`, `ruff format --check`, and `pytest` run in CI + across the supported Python matrix (3.12–3.14). New modules ship with tests + under each package's `tests/` tree, and `py.typed` continues to ship so + downstream type-checkers consume the annotations. + +### Honest scope boundaries + +The following were intentionally left out of this round and are **not** included: + +- **Default error map** — error classification beyond the `retryable` + flag and body snapshot was deferred; callers still map status codes to domain + errors themselves. +- **`sendfile` fast-path** — file bodies are streamed via the existing + `iter_bytes` path; no zero-copy `sendfile` transport optimisation was added. +- **MCP support** — no Model Context Protocol integration is included. +- **Java SDK items** — the Java counterpart lives in a separate repository and + was out of scope here. +- **Code generation** — no client/model code generation was added; all surfaces + in this release are hand-written.
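The byte-order-mark handling listed under **Fixed** has one subtle case: the three BOM bytes can arrive split across chunks, so the decision has to wait until enough bytes are buffered or the stream ends. A minimal sketch of that idea follows. `BomStripper` is a hypothetical helper written for illustration, not the SDK's `SseParser`, and it only forwards bytes rather than parsing events.

```python
UTF8_BOM = b"\xef\xbb\xbf"


class BomStripper:
    """Hypothetical helper (not the SDK's SseParser) that drops one leading BOM."""

    def __init__(self) -> None:
        self._buffer = bytearray()
        self._decided = False

    def feed(self, chunk: bytes) -> bytes:
        """Return bytes that are safe to pass on, with any leading BOM removed."""
        if self._decided:
            return chunk
        self._buffer.extend(chunk)
        # Fewer than three bytes that still match a BOM prefix: keep waiting.
        if len(self._buffer) < len(UTF8_BOM) and UTF8_BOM.startswith(bytes(self._buffer)):
            return b""
        return self._flush()

    def end(self) -> bytes:
        """Force the decision at end of stream; a partial BOM is ordinary data."""
        if self._decided:
            return b""
        return self._flush()

    def _flush(self) -> bytes:
        # Decide exactly once, then release everything buffered so far.
        self._decided = True
        data = bytes(self._buffer)
        self._buffer.clear()
        return data[len(UTF8_BOM):] if data.startswith(UTF8_BOM) else data
```

Only the first BOM is removed; a later `EF BB BF` sequence is payload and passes through untouched.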
+ +[Unreleased]: https://github.com/dexpace/python-sdk/compare/main...HEAD diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/errors/http.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/errors/http.py index 0738dc1..49f57fc 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/errors/http.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/errors/http.py @@ -5,8 +5,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar +from ..http.response.loggable_response_body import LoggableResponseBody from .base import SdkError if TYPE_CHECKING: @@ -26,6 +27,12 @@ else: ModelT = TypeVar("ModelT") +# Status codes for which a retry is worthwhile by default: request timeout, +# rate limiting, and the transient 5xx family. Mirrors the retry policy's +# ``_DEFAULT_STATUS_RETRIES`` so ``retryable`` and the policy agree out of +# the box; callers can override per error via the ``retryable`` kwarg. +_DEFAULT_RETRYABLE_STATUS: Final[frozenset[int]] = frozenset({408, 429, 500, 502, 503, 504}) + # UP046 wants PEP 695 ``class Foo[T = Any](...)`` form, but that syntax # requires Python 3.13+ at runtime; we still support 3.12. @@ -49,12 +56,18 @@ class HttpResponseError(SdkError, Generic[ModelT]): # noqa: UP046 model: Optional deserialised body payload (set by consumer libraries when they parse the error body). Typed as ``ModelT | None``. + retryable: Whether retrying the request might succeed. Derived from + the response status by default (request timeout, rate limiting, + and transient 5xx are retryable) so the retry policy can read the + flag directly instead of re-deriving it; callers may override it + explicitly via the ``retryable`` constructor keyword. 
""" status: Status | None reason: str | None response: _AnyResponse | None model: ModelT | None + retryable: bool def __init__( self, @@ -70,17 +83,60 @@ def __init__( response: The HTTP response that triggered the error. **kwargs: Forwarded to ``SdkError`` (``error``, ``continuation_token``). The ``model`` key is consumed - separately for caller-supplied deserialised bodies. + separately for caller-supplied deserialised bodies. The + ``retryable`` key, if given, overrides the status-derived + default (pass ``True``/``False`` to force it). """ self.response = response self.status = response.status if response is not None else None self.reason = response.reason if response is not None else None self.model = kwargs.pop("model", None) + retryable_override = kwargs.pop("retryable", None) + self.retryable = ( + self._status_is_retryable() if retryable_override is None else bool(retryable_override) + ) if message is None: label = self.status.name if self.status is not None else "unknown" message = f"Operation returned a non-success status: {label}" super().__init__(message, **kwargs) + def _status_is_retryable(self) -> bool: + """Return whether this error's status is retryable by default. + + Returns: + ``True`` when the captured status is one of the default + retryable codes, ``False`` when no status was captured. + """ + return self.status is not None and int(self.status) in _DEFAULT_RETRYABLE_STATUS + + def body_snapshot(self, max_bytes: int | None = None) -> bytes: + """Preview the error response body without consuming it. + + Safe to call from logging and post-mortem paths: it never drains a + single-use stream. Bytes are only returned when the body has already + been captured for repeatable reads (a ``LoggableResponseBody``); for + any other body — or when no response/body is present — an empty + ``bytes`` is returned rather than destroying the payload. + + Args: + max_bytes: If given, return at most this many bytes from the + front of the captured body. 
``None`` returns the full + capture. + + Returns: + The captured body bytes, optionally truncated to ``max_bytes``; + empty when no non-consuming preview is available. + + Raises: + ValueError: If ``max_bytes`` is negative. + """ + if max_bytes is not None and max_bytes < 0: + raise ValueError(f"max_bytes must be non-negative, got {max_bytes}") + body = self.response.body if self.response is not None else None + if isinstance(body, LoggableResponseBody): + return body.snapshot(max_bytes) + return b"" + class DecodeError(HttpResponseError[ModelT]): """The response body could not be decoded as the expected format.""" diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/digest.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/digest.py index 0df558a..c5e50c0 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/digest.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/digest.py @@ -107,6 +107,7 @@ def handle( realm = selected.parameters.get("realm", "") nonce = selected.parameters.get("nonce", "") opaque = selected.parameters.get("opaque") + charset = _select_charset(selected.parameters.get("charset")) qop = self._pick_qop(selected.parameters.get("qop")) if qop is None and "qop" in selected.parameters: # The server advertised qop but did not include ``auth``: we @@ -125,6 +126,7 @@ def handle( nc=nc, cnonce=cnonce, qop=qop, + charset=charset, ) header_value = _format_header( username=self._username, @@ -188,9 +190,10 @@ def _compute_response( nc: str, cnonce: str, qop: str | None, + charset: str, ) -> str: def h(data: str) -> str: - return hasher(data.encode("utf-8")).hexdigest() + return hasher(data.encode(charset)).hexdigest() ha1 = h(f"{self._username}:{realm}:{self._password}") if algorithm.endswith("-SESS"): @@ -202,6 +205,28 @@ def h(data: str) -> str: return h(f"{ha1}:{nonce}:{ha2}") +def _select_charset(charset_param: str | None) -> str: + """Choose the encoding for credential hashing per RFC 7616 
§3.4. + + RFC 7616 defines exactly one valid ``charset`` value — ``UTF-8`` — which + a server advertises to request that ``username`` and ``password`` be + encoded as UTF-8 before hashing. When the directive is absent (or carries + any other value), the legacy RFC 2617 default of ISO-8859-1 applies. + + Args: + charset_param: The raw ``charset`` directive from the challenge, or + ``None`` if the server did not send one. Matched case-insensitively + against ``UTF-8``. + + Returns: + The Python codec name to pass to ``str.encode`` — ``"utf-8"`` when the + server advertised ``charset=UTF-8``, otherwise ``"iso-8859-1"``. + """ + if charset_param is not None and charset_param.strip().upper() == "UTF-8": + return "utf-8" + return "iso-8859-1" + + def _request_uri(url: Url) -> str: """Compute the ``uri`` parameter — path plus query, per RFC 7616 §3.4.6.""" path = url.path or "/" diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/media_type.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/media_type.py index 942b16c..d5c9677 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/media_type.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/media_type.py @@ -5,6 +5,7 @@ from __future__ import annotations +import codecs from collections.abc import Mapping from dataclasses import dataclass from typing import Self @@ -75,9 +76,19 @@ def full_type(self) -> str: @property def charset(self) -> str | None: - """The ``charset`` parameter, or ``None`` if absent.""" + """The ``charset`` parameter as a known codec name, or ``None``. + + Returns ``None`` when the parameter is absent *or* names an encoding + that the Python codec registry does not recognise. Degrading an + unknown charset to ``None`` (rather than raising) lets callers fall + back to a default encoding instead of failing to decode a body. 
+ """ for key, value in self.parameters: if key == "charset": + try: + codecs.lookup(value) + except (LookupError, ValueError): + return None return value return None diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/streaming.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/streaming.py index c494553..7800a09 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/streaming.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/streaming.py @@ -13,7 +13,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator from typing import Any -from ...errors import DeserializationError +from ...errors.serialization import DeserializationError def iter_jsonl(chunks: Iterable[bytes]) -> Iterator[Any]: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/loggable_request_body.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/loggable_request_body.py index 918617c..e149a56 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/loggable_request_body.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/loggable_request_body.py @@ -76,9 +76,26 @@ def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: self._tap.write(chunk[:remaining]) yield chunk - def snapshot(self) -> bytes: - """Return an immutable copy of the captured bytes.""" - return self._tap.getvalue() + def snapshot(self, max_bytes: int | None = None) -> bytes: + """Return an immutable copy of the captured bytes. + + Args: + max_bytes: If given, copy at most this many bytes from the front + of the tap. A ``memoryview`` bounds the slice so no more than + ``max_bytes`` are ever materialised, even when the tap holds a + large payload. ``None`` returns the full tap. + + Returns: + The captured bytes, optionally truncated to ``max_bytes``. + + Raises: + ValueError: If ``max_bytes`` is negative. 
+ """ + if max_bytes is None: + return self._tap.getvalue() + if max_bytes < 0: + raise ValueError(f"max_bytes must be non-negative, got {max_bytes}") + return bytes(self._tap.getbuffer()[:max_bytes]) @property def captured_size(self) -> int: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/async_response.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/async_response.py index 1d22f70..a7b6df7 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/async_response.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/async_response.py @@ -12,6 +12,7 @@ from ..common.headers import Headers from ..common.http_header_name import HttpHeaderName from ..common.protocol import Protocol +from .async_response_body import _shielded_cleanup from .status import Status if TYPE_CHECKING: @@ -38,9 +39,14 @@ class AsyncResponse: body: AsyncResponseBody | None = None async def close(self) -> None: - """Close the response body. Idempotent.""" + """Close the response body. Idempotent. + + When invoked from ``__aexit__`` while an ``asyncio.CancelledError`` is + propagating out of an ``async with`` block, the body close is shielded + so the transport handle is released before cancellation continues. 
+ """ if self.body is not None: - await self.body.close() + await _shielded_cleanup(self.body.close()) async def __aenter__(self) -> Self: return self diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/async_response_body.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/async_response_body.py index d6d023a..34fb4db 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/async_response_body.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/async_response_body.py @@ -5,8 +5,9 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable from types import TracebackType from typing import Self @@ -17,6 +18,60 @@ _bytes = bytes +async def _shielded_cleanup(cleanup: Awaitable[object]) -> None: + """Run a cleanup coroutine without letting cancellation interrupt it. + + This is the single cancellation convention used by the async response + bodies: a ``finally`` block that releases transport resources may run + while an ``asyncio.CancelledError`` is already propagating through the + enclosing task. Awaiting the cleanup directly would let that + cancellation interrupt it mid-way, leaking the underlying connection. + + The cleanup is wrapped in ``asyncio.shield`` so it always runs to + completion. If the surrounding scope is cancelled, the ``CancelledError`` + raised by ``shield`` is caught and the wait retried until the shielded + cleanup finishes; the cancellation is then re-raised so it continues to + propagate. Cleanup never swallows cancellation — it merely defers it + until the resource is released. A ``CancelledError`` raised because the + cleanup *itself* was cancelled is propagated immediately. 
+ + A pending outer cancellation always wins: if the cleanup runs to + completion but raises an ordinary exception while a cancellation is + waiting, the cancellation is re-raised (the cleanup error does not mask + it). When no cancellation is pending, a cleanup failure surfaces to the + caller unchanged. + + Args: + cleanup: The resource-release coroutine to run to completion. + + Raises: + asyncio.CancelledError: Re-raised after the cleanup completes when + the enclosing scope was cancelled while the cleanup ran. + Exception: Whatever the cleanup coroutine raised, when no outer + cancellation is pending. + """ + inner = asyncio.ensure_future(cleanup) + cancelled = False + while not inner.done(): + try: + await asyncio.shield(inner) + except asyncio.CancelledError: + if inner.cancelled(): + # The cleanup itself was cancelled, not just our wait on it. + raise + # An outer cancellation hit our wait, not the shielded cleanup. + # Keep waiting until the cleanup finishes, then re-raise so the + # cancellation continues to propagate. + cancelled = True + except Exception: + # The cleanup failed; ``inner`` retains the exception, surfaced + # below. A pending cancellation still takes precedence. + break + if cancelled: + raise asyncio.CancelledError + inner.result() + + class AsyncResponseBody(ABC): """Async twin of ``ResponseBody``. @@ -59,7 +114,7 @@ async def bytes(self) -> _bytes: async for chunk in self.aiter_bytes(): chunks.append(chunk) finally: - await self.close() + await _shielded_cleanup(self.close()) return b"".join(chunks) async def string(self, encoding: str | None = None) -> str: @@ -124,19 +179,15 @@ class _AsyncStreamResponseBody(AsyncResponseBody): Note: Cancellation contract. The generator returned by ``aiter_bytes`` - relies on a ``finally: await self.close()`` clause to release the - underlying stream. 
If the consuming task is cancelled - mid-iteration, that ``finally`` block runs while a - ``CancelledError`` is already in flight; depending on the host - transport, the inner ``await self._stream.close()`` may itself be - cancelled before it completes, leaving the stream open. The async - generator's ``aclose()`` is best-effort under cancellation. - - Callers that may be cancelled mid-stream are responsible for - deterministic cleanup: wrap the body in ``async with body:`` or - invoke ``await body.close()`` explicitly from a cancellation-safe - scope (for example, an ``asyncio.shield`` inside a ``finally`` - block) to guarantee the transport handle is released. + relies on a ``finally`` clause to release the underlying stream. If + the consuming task is cancelled mid-iteration, that ``finally`` block + runs while a ``CancelledError`` is already in flight. The cleanup is + routed through ``_shielded_cleanup``, which wraps the inner + ``await self._stream.close()`` in ``asyncio.shield`` so the close runs + to completion before the ``CancelledError`` is re-raised. The transport + handle is therefore released even when the iterating task is cancelled + mid-stream, and the cancellation continues to propagate afterwards — + cleanup never swallows it. 
""" __slots__ = ("_closed", "_consumed", "_length", "_media_type", "_stream") @@ -177,7 +228,7 @@ async def close(self) -> None: if self._closed: return self._closed = True - await self._stream.close() + await _shielded_cleanup(self._stream.close()) __all__ = ["AsyncResponseBody"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/loggable_response_body.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/loggable_response_body.py index 0c043d4..e259d7c 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/loggable_response_body.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/loggable_response_body.py @@ -5,6 +5,7 @@ from __future__ import annotations +import threading from collections.abc import Iterator from typing import Final @@ -25,9 +26,20 @@ class LoggableResponseBody(ResponseBody): Cap semantics: bytes beyond ``max_capture_bytes`` are dropped from the cache, but the underlying body is still fully drained and closed. + + Mid-drain failure: if the underlying body raises part-way through the + one-time drain, the bytes read so far are retained in the cache and the + originating exception is stored. ``iter_bytes`` re-raises that exception + on every call so callers cannot mistake a truncated read for success, + while ``snapshot`` still returns the partial bytes for post-mortem + logging. + + Thread-safe first read: the one-time drain is guarded by a lock plus a + double-checked flag so concurrent first readers consume the underlying + single-use stream exactly once. 
""" - __slots__ = ("_cached", "_closed", "_drained", "_inner", "_max") + __slots__ = ("_cached", "_closed", "_drained", "_error", "_inner", "_lock", "_max") def __init__( self, @@ -48,8 +60,10 @@ def __init__( self._inner = inner self._max = max_capture_bytes self._cached: bytes = b"" + self._error: BaseException | None = None self._drained = False self._closed = False + self._lock = threading.Lock() def media_type(self) -> MediaType | None: return self._inner.media_type() @@ -59,6 +73,8 @@ def content_length(self) -> int: def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: self._drain() + if self._error is not None: + raise self._error view = memoryview(self._cached) for start in range(0, len(view), chunk_size): yield bytes(view[start : start + chunk_size]) @@ -69,10 +85,31 @@ def close(self) -> None: self._closed = True self._inner.close() - def snapshot(self) -> bytes: - """Return an immutable copy of the captured bytes (draining if needed).""" + def snapshot(self, max_bytes: int | None = None) -> bytes: + """Return an immutable copy of the captured bytes (draining if needed). + + On a mid-drain failure the partial bytes read before the error are + returned for post-mortem logging; the stored exception is not raised + here (it surfaces from ``iter_bytes``). + + Args: + max_bytes: If given, copy at most this many bytes from the front + of the capture. A ``memoryview`` bounds the slice so no more + than ``max_bytes`` are ever materialised. ``None`` returns the + full capture. + + Returns: + The captured bytes, optionally truncated to ``max_bytes``. + + Raises: + ValueError: If ``max_bytes`` is negative. 
+ """ self._drain() - return self._cached + if max_bytes is None: + return self._cached + if max_bytes < 0: + raise ValueError(f"max_bytes must be non-negative, got {max_bytes}") + return bytes(memoryview(self._cached)[:max_bytes]) @property def captured_size(self) -> int: @@ -81,15 +118,22 @@ def captured_size(self) -> int: def _drain(self) -> None: if self._drained: return - self._drained = True - chunks: list[bytes] = [] - captured = 0 - for chunk in self._inner.iter_bytes(): - if captured < self._max: - take = min(self._max - captured, len(chunk)) - chunks.append(chunk[:take]) - captured += take - self._cached = b"".join(chunks) + with self._lock: + if self._drained: + return + chunks: list[bytes] = [] + captured = 0 + try: + for chunk in self._inner.iter_bytes(): + if captured < self._max: + take = min(self._max - captured, len(chunk)) + chunks.append(chunk[:take]) + captured += take + except Exception as exc: + self._error = exc + finally: + self._cached = b"".join(chunks) + self._drained = True __all__ = ["LoggableResponseBody"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/sse/parser.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/sse/parser.py index 4062e88..628e412 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/sse/parser.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/sse/parser.py @@ -24,16 +24,61 @@ from __future__ import annotations +import asyncio from collections import deque -from collections.abc import AsyncIterable, Iterable, Iterator +from collections.abc import AsyncIterable, Awaitable, Iterable, Iterator from dataclasses import dataclass, field -from typing import Final +from types import TracebackType +from typing import Final, Self from dexpace.sdk.core.errors import StreamingError + +async def _shielded_aclose(cleanup: Awaitable[None]) -> None: + """Run an async-stream close to completion under cancellation. 
+ + Mirrors the cancellation convention used by the async response bodies: + a close that runs while an ``asyncio.CancelledError`` is in flight is + wrapped in ``asyncio.shield`` so it finishes releasing the upstream + transport before the cancellation is re-raised. The close never swallows + cancellation — it only defers it until the upstream stream is released. + + A pending outer cancellation always wins: if the close runs to completion + but raises an ordinary exception while a cancellation is waiting, the + cancellation is re-raised rather than masked by the close error. When no + cancellation is pending, a close failure surfaces to the caller unchanged. + + Args: + cleanup: The upstream-close coroutine to run to completion. + + Raises: + asyncio.CancelledError: Re-raised after the close completes when the + enclosing scope was cancelled while the close ran. + Exception: Whatever the close coroutine raised, when no outer + cancellation is pending. + """ + inner = asyncio.ensure_future(cleanup) + cancelled = False + while not inner.done(): + try: + await asyncio.shield(inner) + except asyncio.CancelledError: + if inner.cancelled(): + raise + cancelled = True + except Exception: + # The close failed; ``inner`` retains the exception, surfaced + # below. A pending cancellation still takes precedence. + break + if cancelled: + raise asyncio.CancelledError + inner.result() + + _LF: Final[int] = 0x0A _CR: Final[int] = 0x0D _COLON: Final[int] = 0x3A +_UTF8_BOM: Final[bytes] = b"\xef\xbb\xbf" @dataclass(frozen=True, slots=True) @@ -78,6 +123,7 @@ class SseParser: _last_id: str | None = None _retry: int | None = None _pending: deque[SseEvent] = field(default_factory=deque) + _bom_stripped: bool = False max_line_bytes: int = 1 << 20 # 1 MiB def feed(self, chunk: bytes) -> None: @@ -90,6 +136,8 @@ def feed(self, chunk: bytes) -> None: if not chunk: return self._buffer.extend(chunk) + if not self._strip_leading_bom(): + return # Buffer too short to decide on the BOM yet. 
while True: line, consumed = _read_line(self._buffer) if line is None: @@ -111,6 +159,7 @@ def end(self) -> Iterator[SseEvent]: StreamingError: If the trailing buffer ends mid-codepoint and cannot be decoded as UTF-8. """ + self._strip_leading_bom(at_end=True) if self._buffer: try: line = self._buffer.decode("utf-8") @@ -122,6 +171,35 @@ def end(self) -> Iterator[SseEvent]: self._dispatch() yield from self.drain() + def _strip_leading_bom(self, *, at_end: bool = False) -> bool: + """Remove a single leading UTF-8 BOM (``EF BB BF``) once at stream start. + + The check runs exactly once: after the first byte arrives. Because the + three BOM bytes may span chunk boundaries, the decision is deferred + until either the buffer holds at least three bytes or the stream ends. + + Args: + at_end: When ``True``, force the decision even if fewer than three + bytes are buffered (no further input is coming). + + Returns: + ``True`` once the BOM has been handled (stripped or ruled out) and + line consumption may proceed; ``False`` while still waiting for + enough bytes to decide. + """ + if self._bom_stripped: + return True + if ( + not at_end + and len(self._buffer) < len(_UTF8_BOM) + and self._buffer == _UTF8_BOM[: len(self._buffer)] + ): + return False # Possible partial BOM — wait for more bytes. + self._bom_stripped = True + if self._buffer[: len(_UTF8_BOM)] == _UTF8_BOM: + del self._buffer[: len(_UTF8_BOM)] + return True + def _process_line(self, line: str) -> None: if not line: self._dispatch() @@ -202,15 +280,25 @@ class AsyncSseStream: """Async iterator that drives an ``SseParser`` from an async byte stream. Construct via :func:`parse_async_events` or directly. Use as - ``async for event in stream``. + ``async for event in stream``, ideally inside ``async with`` so the + upstream byte stream is released deterministically. + + Note: + Cancellation contract. 
If the consuming task is cancelled + mid-stream, :meth:`aclose` (run from ``__aexit__`` or directly) + releases the upstream byte iterator. The upstream ``aclose`` is + routed through ``_shielded_aclose`` so it runs to completion even + while a ``CancelledError`` is in flight; the cancellation is then + re-raised and continues to propagate — closing never swallows it. """ - __slots__ = ("_chunks", "_parser", "_pending") + __slots__ = ("_chunks", "_closed", "_parser", "_pending") def __init__(self, chunks: AsyncIterable[bytes]) -> None: self._chunks = aiter(chunks) self._parser = SseParser() self._pending: Iterator[SseEvent] = iter(()) + self._closed = False def __aiter__(self) -> AsyncSseStream: return self @@ -232,6 +320,31 @@ async def __anext__(self) -> SseEvent: self._parser.feed(chunk) self._pending = self._parser.drain() + async def aclose(self) -> None: + """Release the upstream byte stream. Idempotent. + + Closes the wrapped async iterator's ``aclose`` (when it exposes one) + under ``asyncio.shield`` so the transport handle is released even when + the consuming task is cancelled mid-stream. + """ + if self._closed: + return + self._closed = True + upstream_aclose = getattr(self._chunks, "aclose", None) + if upstream_aclose is not None: + await _shielded_aclose(upstream_aclose()) + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + await self.aclose() + def parse_async_events(chunks: AsyncIterable[bytes]) -> AsyncSseStream: """Build an async SSE event stream from an async byte iterable. 
diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/webhooks/__init__.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/webhooks/__init__.py new file mode 100644 index 0000000..3711214 --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/webhooks/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Standard Webhooks signature verification (standardwebhooks.com). + +Stdlib-only HMAC-SHA256 verification of inbound webhooks: rebuild the signed +content ``{id}.{timestamp}.{body}``, compare constant-time against any of the +provided ``v1,`` signatures, and reject deliveries outside a ±5-minute +timestamp window. :class:`WebhookVerifier.unwrap` additionally parses the +verified JSON body. +""" + +from __future__ import annotations + +from .verification import ( + DEFAULT_TOLERANCE_SECONDS, + InvalidWebhookSignatureError, + WebhookVerifier, +) + +__all__ = [ + "DEFAULT_TOLERANCE_SECONDS", + "InvalidWebhookSignatureError", + "WebhookVerifier", +] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/webhooks/verification.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/webhooks/verification.py new file mode 100644 index 0000000..f866e3b --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/webhooks/verification.py @@ -0,0 +1,255 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Standard Webhooks signature verification. + +Implements the `Standard Webhooks `_ scheme +for verifying inbound webhook payloads, using the standard library only. + +A producer signs each webhook by computing ``HMAC-SHA256`` over the content +``{id}.{timestamp}.{body}`` keyed by a shared secret, base64-encoding the +digest, and prefixing it with the scheme version (``v1,``). 
The three pieces a +receiver needs travel in headers: + +- ``webhook-id``: an opaque message identifier, also part of the signed content + so a signature cannot be replayed against a different message id. +- ``webhook-timestamp``: the Unix epoch second the message was signed, used to + reject stale deliveries outside a tolerance window. +- ``webhook-signature``: one or more space-separated ``v1,`` tokens. A + producer may publish several (e.g. during secret rotation); a match against + any one is sufficient. + +The shared secret is supplied in its on-the-wire form, ``whsec_``; the +prefix is stripped and the remainder base64-decoded to recover the raw HMAC +key. + +Verification is constant-time (``hmac.compare_digest``) and rejects timestamps +that are too old or too far in the future relative to an injected +:class:`~dexpace.sdk.core.util.Clock`, defaulting to the process clock. + +Example: + >>> verifier = WebhookVerifier("whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw") + >>> headers = { + ... "webhook-id": "msg_2KWPBgLlAfxdpx2AI54pPJ85f4W", + ... "webhook-timestamp": "1690000000", + ... "webhook-signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", + ... 
} + >>> payload = verifier.unwrap(headers, request_body) # parsed JSON dict +""" + +from __future__ import annotations + +import base64 +import binascii +import hashlib +import hmac +import json +from collections.abc import Mapping +from typing import Final + +from dexpace.sdk.core.util import SYSTEM_CLOCK, Clock + +__all__ = [ + "DEFAULT_TOLERANCE_SECONDS", + "InvalidWebhookSignatureError", + "WebhookVerifier", +] + +_SECRET_PREFIX: Final[str] = "whsec_" +_SIGNATURE_VERSION: Final[str] = "v1" +_ID_HEADER: Final[str] = "webhook-id" +_TIMESTAMP_HEADER: Final[str] = "webhook-timestamp" +_SIGNATURE_HEADER: Final[str] = "webhook-signature" + +DEFAULT_TOLERANCE_SECONDS: Final[int] = 5 * 60 +"""Default ``±`` timestamp tolerance in seconds (5 minutes per the spec).""" + + +class InvalidWebhookSignatureError(ValueError): + """Raised when a webhook payload fails verification. + + A ``ValueError`` subclass so callers can catch it as either the specific + type or the broad input-validation category. The message never echoes the + secret or the expected signature, only the reason the check failed. + """ + + +def _require_header(headers: Mapping[str, str], name: str) -> str: + """Return the value of ``name`` from ``headers`` (case-insensitive). + + Args: + headers: Inbound request headers. + name: Lower-case header name to look up. + + Returns: + The header value. + + Raises: + InvalidWebhookSignatureError: If the header is missing. + """ + value = headers.get(name) + if value is None: + lowered = {key.lower(): val for key, val in headers.items()} + value = lowered.get(name) + if value is None: + raise InvalidWebhookSignatureError(f"missing required header: {name}") + return value + + +def _decode_secret(secret: str) -> bytes: + """Decode a ``whsec_``-prefixed base64 secret into the raw HMAC key. + + Args: + secret: The shared secret, with or without the ``whsec_`` prefix. The + prefix is optional so a caller that already stripped it still works. 
+ + Returns: + The raw key bytes. + + Raises: + InvalidWebhookSignatureError: If the base64 body is malformed. + """ + body = secret[len(_SECRET_PREFIX) :] if secret.startswith(_SECRET_PREFIX) else secret + try: + return base64.b64decode(body, validate=True) + except (binascii.Error, ValueError) as exc: + raise InvalidWebhookSignatureError("malformed webhook secret") from exc + + +def _as_bytes(body: str | bytes) -> bytes: + """Return ``body`` as UTF-8 bytes, passing through ``bytes`` unchanged.""" + return body if isinstance(body, bytes) else body.encode("utf-8") + + +class WebhookVerifier: + """Verifies inbound webhooks against the Standard Webhooks scheme. + + Immutable after construction: the decoded key, tolerance, and clock are + fixed. Safe to share across threads — verification holds no mutable state. + + Args: + secret: The shared signing secret in ``whsec_`` form (the + prefix is optional). + tolerance_seconds: Maximum absolute difference, in seconds, allowed + between the signed timestamp and the current time. Defaults to + :data:`DEFAULT_TOLERANCE_SECONDS` (5 minutes). + clock: Time source used to evaluate the tolerance window. Defaults to + the process clock; inject a fake to test the replay window. + + Raises: + InvalidWebhookSignatureError: If ``secret`` is malformed. + ValueError: If ``tolerance_seconds`` is negative. + """ + + __slots__ = ("_clock", "_key", "_tolerance_seconds") + + def __init__( + self, + secret: str, + *, + tolerance_seconds: int = DEFAULT_TOLERANCE_SECONDS, + clock: Clock = SYSTEM_CLOCK, + ) -> None: + if tolerance_seconds < 0: + raise ValueError(f"tolerance_seconds must be non-negative, got {tolerance_seconds}") + self._key: Final[bytes] = _decode_secret(secret) + self._tolerance_seconds: Final[int] = tolerance_seconds + self._clock: Final[Clock] = clock + + def verify(self, headers: Mapping[str, str], body: str | bytes) -> None: + """Verify a webhook delivery, raising on any failure. 
+ + Looks up the ``webhook-id`` / ``webhook-timestamp`` / ``webhook- + signature`` headers (case-insensitively), checks the timestamp against + the tolerance window, recomputes the expected signature over + ``{id}.{timestamp}.{body}``, and compares it constant-time against each + provided signature token. Returns normally when at least one token + matches. + + Args: + headers: Inbound request headers. Lookups are case-insensitive. + body: The raw request body, exactly as received. Passing a + re-serialized form risks a byte mismatch and a spurious + failure, so prefer the original bytes. + + Raises: + InvalidWebhookSignatureError: If a required header is missing, the + timestamp is malformed or outside the tolerance window, or no + provided signature matches. + """ + webhook_id = _require_header(headers, _ID_HEADER) + timestamp = _require_header(headers, _TIMESTAMP_HEADER) + signature_header = _require_header(headers, _SIGNATURE_HEADER) + + self._check_timestamp(timestamp) + expected = self._sign(webhook_id, timestamp, _as_bytes(body)) + if not self._matches_any(signature_header, expected): + raise InvalidWebhookSignatureError("no matching signature") + + def unwrap(self, headers: Mapping[str, str], body: str | bytes) -> object: + """Verify a webhook and return its parsed JSON payload. + + A convenience over :meth:`verify` for the common case of a JSON body: + the signature is checked against the *raw* bytes first, then the same + bytes are parsed. Verifying before parsing guarantees only authentic + payloads are ever deserialized. + + Args: + headers: Inbound request headers. + body: The raw request body, exactly as received. + + Returns: + The decoded JSON value (typically a ``dict`` for an object body, + but any JSON value is returned as-is). + + Raises: + InvalidWebhookSignatureError: If verification fails or the verified + body is not valid JSON. 
+ """ + raw = _as_bytes(body) + self.verify(headers, raw) + try: + return json.loads(raw) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + raise InvalidWebhookSignatureError("verified body is not valid JSON") from exc + + def _check_timestamp(self, timestamp: str) -> None: + """Reject a timestamp that is malformed or outside the tolerance window. + + Raises: + InvalidWebhookSignatureError: If ``timestamp`` is not an integer or + differs from now by more than the configured tolerance. + """ + try: + signed_at = int(timestamp) + except ValueError as exc: + raise InvalidWebhookSignatureError("malformed webhook-timestamp") from exc + delta = self._clock.now() - signed_at + if delta > self._tolerance_seconds: + raise InvalidWebhookSignatureError("webhook timestamp is too old") + if delta < -self._tolerance_seconds: + raise InvalidWebhookSignatureError("webhook timestamp is in the future") + + def _sign(self, webhook_id: str, timestamp: str, body: bytes) -> str: + """Return the base64 ``HMAC-SHA256`` over ``{id}.{timestamp}.{body}``.""" + signed_content = b"%s.%s." % (webhook_id.encode("utf-8"), timestamp.encode("utf-8")) + digest = hmac.new(self._key, signed_content + body, hashlib.sha256).digest() + return base64.b64encode(digest).decode("ascii") + + def _matches_any(self, signature_header: str, expected: str) -> bool: + """Return whether any ``v1,`` token equals ``expected``. + + Tokens are space-separated. Unknown-version and malformed tokens are + skipped rather than raising, so a forward-compatible producer that adds + a future scheme version alongside ``v1`` still verifies. Comparison is + constant-time to avoid leaking the expected signature via timing.
+ """ + expected_bytes = expected.encode("ascii") + matched = False + for token in signature_header.split(" "): + version, _, candidate = token.partition(",") + if version != _SIGNATURE_VERSION or not candidate: + continue + if hmac.compare_digest(candidate.encode("utf-8"), expected_bytes): + matched = True + return matched diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/__init__.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/__init__.py index ce8d7be..6f09a69 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/__init__.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/__init__.py @@ -14,7 +14,15 @@ from __future__ import annotations -from .client_logger import ClientLogger +from .client_logger import ClientLogger, CorrelationFilter +from .correlation import ( + bind_correlation, + get_span_id, + get_trace_id, + set_span_id, + set_trace_id, +) +from .http_tracer import HttpTracer, HttpTracerFactory from .identifiers import SpanId, TraceFlags, TraceId, TraceIdType, TraceState from .instrumentation_context import InstrumentationContext from .log_level import LogLevel @@ -28,7 +36,13 @@ MetricsContext, UpDownCounter, ) -from .noop import NOOP_INSTRUMENTATION_CONTEXT, NOOP_SPAN, NOOP_TRACER +from .noop import ( + NOOP_HTTP_TRACER, + NOOP_HTTP_TRACER_FACTORY, + NOOP_INSTRUMENTATION_CONTEXT, + NOOP_SPAN, + NOOP_TRACER, +) from .span import Span from .tracer import Tracer from .tracing_scope import TracingScope @@ -38,14 +52,19 @@ "DEFAULT_QUERY_ALLOWLIST", "NOOP_COUNTER", "NOOP_HISTOGRAM", + "NOOP_HTTP_TRACER", + "NOOP_HTTP_TRACER_FACTORY", "NOOP_INSTRUMENTATION_CONTEXT", "NOOP_METRICS_CONTEXT", "NOOP_SPAN", "NOOP_TRACER", "NOOP_UPDOWN_COUNTER", "ClientLogger", + "CorrelationFilter", "Counter", "Histogram", + "HttpTracer", + "HttpTracerFactory", "InstrumentationContext", "LogLevel", "MetricsContext", @@ -59,4 +78,9 @@ "TracingScope", "UpDownCounter", "UrlRedactor", +
"bind_correlation", + "get_span_id", + "get_trace_id", + "set_span_id", + "set_trace_id", ] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/client_logger.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/client_logger.py index c998e65..eee0307 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/client_logger.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/client_logger.py @@ -8,6 +8,7 @@ import logging from typing import Any, Final +from .correlation import get_span_id, get_trace_id from .log_level import LogLevel _LEVEL_MAP: Final[dict[LogLevel, int]] = { @@ -18,6 +19,28 @@ } +class CorrelationFilter(logging.Filter): + """Stamps the active trace/span ids onto every record it sees. + + Reads the context-local ids from :mod:`correlation` and attaches them as + ``trace.id`` / ``span.id`` record attributes (plus the dotted-name-safe + ``trace_id`` / ``span_id`` aliases for ``%``-style format strings). When no + trace is bound the attributes are set to ``None`` so formatters referencing + them never raise. Because the ids live in ``contextvars``, this works across + ``await`` boundaries without any extra plumbing. + """ + + def filter(self, record: logging.LogRecord) -> bool: + """Attach correlation ids and always allow the record through.""" + trace_id = get_trace_id() + span_id = get_span_id() + setattr(record, "trace.id", trace_id) + setattr(record, "span.id", span_id) + record.trace_id = trace_id + record.span_id = span_id + return True + + class ClientLogger: """Thin facade over stdlib ``logging`` that emits structured key=value pairs. 
@@ -46,15 +69,15 @@ def __init__( self.name = name self._logger = logging.getLogger(name) self._static_fields = static_fields + _install_correlation_filter(self._logger) def log(self, level: LogLevel, message: str, **fields: Any) -> None: """Emit a structured log record at ``level``.""" py_level = _LEVEL_MAP[level] if not self._logger.isEnabledFor(py_level): return - self._logger.log( - py_level, "%s %s", message, _format_fields({**self._static_fields, **fields}) - ) + rendered = _format_fields({**self._static_fields, **_correlation_fields(), **fields}) + self._logger.log(py_level, "%s %s", message, rendered) def error(self, message: str, **fields: Any) -> None: """Emit a structured record at ``ERROR`` level.""" @@ -73,6 +96,25 @@ def verbose(self, message: str, **fields: Any) -> None: self.log(LogLevel.VERBOSE, message, **fields) +def _install_correlation_filter(logger: logging.Logger) -> None: + """Attach a :class:`CorrelationFilter` to ``logger`` exactly once.""" + if any(isinstance(existing, CorrelationFilter) for existing in logger.filters): + return + logger.addFilter(CorrelationFilter()) + + +def _correlation_fields() -> dict[str, str]: + """Return the bound trace/span ids as logfmt fields, omitting unset ones.""" + fields: dict[str, str] = {} + trace_id = get_trace_id() + if trace_id is not None: + fields["trace.id"] = trace_id + span_id = get_span_id() + if span_id is not None: + fields["span.id"] = span_id + return fields + + def _format_fields(fields: dict[str, Any]) -> str: parts: list[str] = [] for key, value in fields.items(): @@ -90,4 +132,4 @@ def _format_fields(fields: dict[str, Any]) -> str: return " ".join(parts) -__all__ = ["ClientLogger"] +__all__ = ["ClientLogger", "CorrelationFilter"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/correlation.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/correlation.py new file mode 100644 index 0000000..d7839f0 --- /dev/null +++ 
b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/correlation.py @@ -0,0 +1,101 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Context-local trace/span correlation identifiers. + +Holds the active trace id and span id in module-level ``contextvars`` so any +code on the same logical flow — including across ``await`` boundaries, which +asyncio propagates automatically — can read them without threading the values +through every call. The tracing policy sets them when it opens a span; +``ClientLogger`` reads them to stamp ``trace.id`` / ``span.id`` onto every log +record. + +Code that hops to a worker thread (``loop.run_in_executor``) does not inherit +the caller's context automatically; use :func:`bind_correlation` there to +re-establish the ids inside the worker. +""" + +from __future__ import annotations + +from contextlib import contextmanager +from contextvars import ContextVar +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterator + from contextvars import Token + +_trace_id: ContextVar[str | None] = ContextVar("dexpace_trace_id", default=None) +_span_id: ContextVar[str | None] = ContextVar("dexpace_span_id", default=None) + + +def get_trace_id() -> str | None: + """Return the active trace id, or ``None`` when no trace is bound.""" + return _trace_id.get() + + +def get_span_id() -> str | None: + """Return the active span id, or ``None`` when no span is bound.""" + return _span_id.get() + + +def set_trace_id(value: str | None) -> Token[str | None]: + """Set the active trace id. + + Args: + value: The trace id to bind, or ``None`` to clear it. + + Returns: + A reset token; pass it to ``ContextVar.reset`` to restore the prior + value. Prefer :func:`bind_correlation` for scoped use. + """ + return _trace_id.set(value) + + +def set_span_id(value: str | None) -> Token[str | None]: + """Set the active span id. 
+ + Args: + value: The span id to bind, or ``None`` to clear it. + + Returns: + A reset token; pass it to ``ContextVar.reset`` to restore the prior + value. Prefer :func:`bind_correlation` for scoped use. + """ + return _span_id.set(value) + + +@contextmanager +def bind_correlation( + *, + trace_id: str | None = None, + span_id: str | None = None, +) -> Iterator[None]: + """Bind trace/span ids for the duration of the ``with`` block. + + Restores the previous ids on exit, even if the body raises. Both ids are + always set; an omitted argument defaults to ``None`` and clears that id. + + Args: + trace_id: Trace id to bind for the block, or ``None`` to clear it. + span_id: Span id to bind for the block, or ``None`` to clear it. + + Yields: + Nothing; use as a plain scope guard. + """ + trace_token = _trace_id.set(trace_id) + span_token = _span_id.set(span_id) + try: + yield + finally: + _span_id.reset(span_token) + _trace_id.reset(trace_token) + + +__all__ = [ + "bind_correlation", + "get_span_id", + "get_trace_id", + "set_span_id", + "set_trace_id", +] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/http_tracer.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/http_tracer.py new file mode 100644 index 0000000..114b213 --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/http_tracer.py @@ -0,0 +1,122 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Fine-grained HTTP request telemetry callbacks. + +``HttpTracer`` is an event sink that pipeline policies notify at the moments +that matter to an operator: each attempt boundary, retry exhaustion, byte +counts on the wire, and connection acquisition. Every method is a no-op by +default, so a consumer overrides only the events their backend cares about and +inherits the rest.
``HttpTracerFactory`` mints a fresh tracer per logical +operation so per-call state (attempt counters, timers) stays isolated. + +Modeled on Google gax's ``ApiTracer``. The SDK ships the contract plus a no-op +default (:data:`NOOP_HTTP_TRACER` / :data:`NOOP_HTTP_TRACER_FACTORY`); consumers +plug in a real implementation per their metrics/tracing stack. +""" + +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Mapping + + +class HttpTracer(ABC): + """Sink for fine-grained HTTP request events. + + One instance tracks a single logical operation across its (possibly + retried) attempts. Pipeline policies call these methods at the appropriate + moments; the default implementations do nothing, so subclasses override + only the events they consume. + + Implementations are notified from whatever thread or task drives the + request and should keep callbacks cheap and non-blocking; they must not + raise. + """ + + def operation_started(self) -> None: + """The overall operation began (before the first attempt).""" + + def operation_succeeded(self) -> None: + """The operation completed successfully.""" + + def operation_failed(self, error: BaseException) -> None: + """The operation failed permanently. + + Args: + error: The exception that terminated the operation. + """ + + def attempt_started(self, attempt: int) -> None: + """A new attempt began. + + Args: + attempt: Zero-based index of the attempt about to be sent. + """ + + def attempt_failed(self, error: BaseException, next_delay: float) -> None: + """An attempt failed and a retry is scheduled. + + Args: + error: The exception that failed the attempt. + next_delay: Seconds the policy will wait before the next attempt.
+ """ + + def attempt_retries_exhausted(self) -> None: + """The retry budget was exhausted; no further attempts will be made.""" + + def request_url_resolved(self, url: str) -> None: + """The final request URL was resolved (post-redirect). + + Args: + url: The absolute URL the attempt targets. + """ + + def request_sent(self, byte_count: int) -> None: + """The request body finished writing to the wire. + + Args: + byte_count: Number of body bytes written. + """ + + def response_headers_received(self, status: int, headers: Mapping[str, str]) -> None: + """Response status and headers arrived (before the body). + + Args: + status: HTTP status code of the response. + headers: The response headers. + """ + + def response_received(self, byte_count: int) -> None: + """The response body finished reading from the wire. + + Args: + byte_count: Number of body bytes read. + """ + + def connection_acquired(self, host: str, port: int) -> None: + """A transport connection was acquired for the attempt. + + Args: + host: Remote host the connection targets. + port: Remote port the connection targets. + """ + + +@runtime_checkable +class HttpTracerFactory(Protocol): + """Mints a fresh :class:`HttpTracer` per logical operation. + + Policies create one tracer at the start of each operation so per-call state + (attempt counters, timers) never leaks across operations. + """ + + def create(self) -> HttpTracer: + """Return a new tracer for one operation.""" + ... 
+ + +__all__ = ["HttpTracer", "HttpTracerFactory"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/instrumentation_context.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/instrumentation_context.py index c480894..ebdccaa 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/instrumentation_context.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/instrumentation_context.py @@ -5,15 +5,27 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING from .identifiers import SpanId, TraceFlags, TraceId, TraceIdType, TraceState if TYPE_CHECKING: + from .http_tracer import HttpTracerFactory from .span import Span +def _default_http_tracer_factory() -> HttpTracerFactory: + """Return the shared no-op tracer factory. + + Imported lazily to avoid a circular import: ``noop`` imports this module to + build the no-op context singleton. + """ + from .noop import NOOP_HTTP_TRACER_FACTORY + + return NOOP_HTTP_TRACER_FACTORY + + @dataclass(frozen=True, slots=True) class InstrumentationContext: """Metadata carried with every traced operation. @@ -25,6 +37,10 @@ class InstrumentationContext: The shared no-op singleton :data:`NOOP_INSTRUMENTATION_CONTEXT` is used when tracing is disabled. + + ``http_tracer_factory`` mints a per-operation :class:`HttpTracer` for + fine-grained request telemetry; it defaults to the no-op factory so callers + that don't instrument pay nothing. 
""" trace_id_type: TraceIdType @@ -34,6 +50,7 @@ class InstrumentationContext: trace_flags: TraceFlags = TraceFlags.NOOP trace_state: TraceState = TraceState.NOOP is_remote: bool = False + http_tracer_factory: HttpTracerFactory = field(default_factory=_default_http_tracer_factory) @property def is_valid(self) -> bool: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/noop.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/noop.py index fb9669a..61eb6c6 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/noop.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/noop.py @@ -7,6 +7,7 @@ from typing import Any, Final +from .http_tracer import HttpTracer, HttpTracerFactory from .identifiers import SpanId, TraceFlags, TraceId, TraceIdType, TraceState from .instrumentation_context import InstrumentationContext from .span import Span @@ -25,6 +26,26 @@ def close(self) -> None: _NOOP_SCOPE = _NoopScope() +class _NoopHttpTracer(HttpTracer): + """No-op :class:`HttpTracer` — every event callback inherits the do-nothing + default. Use the shared :data:`NOOP_HTTP_TRACER` singleton.""" + + +#: Shared no-op :class:`HttpTracer` singleton. Use when tracing is disabled. +NOOP_HTTP_TRACER: Final[HttpTracer] = _NoopHttpTracer() + + +class _NoopHttpTracerFactory: + """No-op tracer factory — every :meth:`create` returns :data:`NOOP_HTTP_TRACER`.""" + + def create(self) -> HttpTracer: + return NOOP_HTTP_TRACER + + +#: Shared no-op ``HttpTracerFactory`` singleton. +NOOP_HTTP_TRACER_FACTORY: Final[HttpTracerFactory] = _NoopHttpTracerFactory() + + class _NoopSpan(Span): """No-op :class:`Span` — records nothing, returns ``self`` from every mutator. 
@@ -66,6 +87,7 @@ def end(self, error: BaseException | None = None) -> None: trace_flags=TraceFlags.NOOP, trace_state=TraceState.NOOP, is_remote=False, + http_tracer_factory=NOOP_HTTP_TRACER_FACTORY, ) @@ -84,4 +106,10 @@ def start_span( NOOP_TRACER: Final[Tracer] = _NoopTracer() -__all__ = ["NOOP_INSTRUMENTATION_CONTEXT", "NOOP_SPAN", "NOOP_TRACER"] +__all__ = [ + "NOOP_HTTP_TRACER", + "NOOP_HTTP_TRACER_FACTORY", + "NOOP_INSTRUMENTATION_CONTEXT", + "NOOP_SPAN", + "NOOP_TRACER", +] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/__init__.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/__init__.py new file mode 100644 index 0000000..c42310c --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/__init__.py @@ -0,0 +1,59 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Auto-pagination — drive a paged API through the pipeline. + +The public surface is a small set of pieces that compose: + +* :class:`Page` — a frozen page of items plus the request that reaches the + next page (and, when supported, the previous one). It is a context manager + so the underlying response closes deterministically. +* :class:`PaginationStrategy` — the SPI that turns one decoded response into a + :class:`Page`. Built-ins: :class:`CursorStrategy` (cursor / token), + :class:`PageNumberStrategy` (page index), and :class:`LinkHeaderStrategy` + (RFC 5988 ``Link`` header). +* :class:`Paginator` / :class:`AsyncPaginator` — iterate the sequence + item-by-item by default, or page-by-page via ``by_page``. Each page fetch + runs through the full pipeline, so retry, auth, redirect, and tracing apply + to every page. +* :func:`parse_link_header` / :func:`find_rel` — the standalone RFC 5988 + parser the link strategy is built on. 
+""" + +from __future__ import annotations + +from .link_header import ParsedLink, find_rel, parse_link_header +from .page import Page +from .paginator import ( + AsyncPaginator, + AsyncPipelineLike, + Paginator, + SendAsync, + SendSync, + SyncPipelineLike, +) +from .strategy import ( + CursorStrategy, + HasHeaders, + LinkHeaderStrategy, + PageNumberStrategy, + PaginationStrategy, +) + +__all__ = [ + "AsyncPaginator", + "AsyncPipelineLike", + "CursorStrategy", + "HasHeaders", + "LinkHeaderStrategy", + "Page", + "PageNumberStrategy", + "PaginationStrategy", + "Paginator", + "ParsedLink", + "SendAsync", + "SendSync", + "SyncPipelineLike", + "find_rel", + "parse_link_header", +] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/link_header.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/link_header.py new file mode 100644 index 0000000..c888ad8 --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/link_header.py @@ -0,0 +1,149 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""RFC 5988 ``Link`` header parser — pure string logic, no I/O. + +A ``Link`` header carries one or more comma-separated link-values, each a +URI-Reference in angle brackets followed by semicolon-separated parameters:: + + Link: ; rel="next", + ; rel="last" + +This module parses that grammar into ``(target, params)`` pairs and exposes a +convenience lookup keyed by ``rel`` value. It is deliberately standalone: the +standard library has no equivalent, and the paginator's link-header strategy +depends only on this pure function. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +#: A parsed link-value: the bracketed target URI plus its lower-cased +#: parameter map (parameter names are case-insensitive per RFC 5988 §5). 
+type ParsedLink = tuple[str, dict[str, str]] + + +def parse_link_header(value: str) -> tuple[ParsedLink, ...]: + """Parse an RFC 5988 ``Link`` header into its link-values. + + Args: + value: The raw header value (without the ``Link:`` name). An empty or + whitespace-only string yields an empty result. + + Returns: + One ``(target, params)`` pair per link-value, in source order. + ``params`` keys are lower-cased; quoted values are unquoted. + """ + return tuple(_iter_links(value)) + + +def find_rel(value: str, rel: str) -> str | None: + """Return the target URI of the first link-value whose ``rel`` matches. + + The ``rel`` parameter is a space-separated set of relation types + (RFC 5988 §5.3); a match succeeds when ``rel`` appears as one of those + types. Comparison is case-insensitive. + + Args: + value: The raw ``Link`` header value. + rel: The relation type to look for (e.g. ``"next"``). + + Returns: + The matching target URI, or ``None`` when no link-value carries the + requested relation. 
+ """ + wanted = rel.casefold() + for target, params in _iter_links(value): + rels = params.get("rel", "") + if any(token.casefold() == wanted for token in rels.split()): + return target + return None + + +def _iter_links(value: str) -> Iterator[ParsedLink]: + for segment in _split_links(value): + parsed = _parse_link_value(segment) + if parsed is not None: + yield parsed + + +def _split_links(value: str) -> Iterator[str]: + """Split a header into link-value segments, ignoring commas inside quotes.""" + buffer: list[str] = [] + in_quotes = False + escaped = False + for char in value: + if escaped: + buffer.append(char) + escaped = False + elif char == "\\": + buffer.append(char) + escaped = True + elif char == '"': + in_quotes = not in_quotes + buffer.append(char) + elif char == "," and not in_quotes: + yield "".join(buffer) + buffer = [] + else: + buffer.append(char) + if buffer: + yield "".join(buffer) + + +def _parse_link_value(segment: str) -> ParsedLink | None: + segment = segment.strip() + if not segment.startswith("<"): + return None + end = segment.find(">") + if end < 0: + return None + target = segment[1:end].strip() + params = _parse_params(segment[end + 1 :]) + return target, params + + +def _parse_params(raw: str) -> dict[str, str]: + params: dict[str, str] = {} + for part in _split_params(raw): + name, sep, val = part.partition("=") + name = name.strip().casefold() + if not name or not sep: + continue + params[name] = _unquote(val.strip()) + return params + + +def _split_params(raw: str) -> Iterator[str]: + buffer: list[str] = [] + in_quotes = False + escaped = False + for char in raw: + if escaped: + buffer.append(char) + escaped = False + elif char == "\\": + buffer.append(char) + escaped = True + elif char == '"': + in_quotes = not in_quotes + buffer.append(char) + elif char == ";" and not in_quotes: + if buffer: + yield "".join(buffer) + buffer = [] + else: + buffer.append(char) + if buffer: + yield "".join(buffer) + + +def _unquote(value: str) -> 
str: + if len(value) >= 2 and value[0] == '"' and value[-1] == '"': + inner = value[1:-1] + return inner.replace('\\"', '"').replace("\\\\", "\\") + return value + + +__all__ = ["ParsedLink", "find_rel", "parse_link_header"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/page.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/page.py new file mode 100644 index 0000000..47d641a --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/page.py @@ -0,0 +1,101 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""A single page of results produced by a ``PaginationStrategy``.""" + +from __future__ import annotations + +from collections.abc import Coroutine, Iterator, Sequence +from dataclasses import dataclass, field +from inspect import isawaitable +from types import TracebackType +from typing import TYPE_CHECKING, Self, cast + +if TYPE_CHECKING: + from ..http.request.request import Request + + +@dataclass(frozen=True, slots=True) +class Page[T]: + """One page of items plus the requests that reach its neighbours. + + A strategy parses a response into a ``Page``: the ``items`` it carried, + the ``next_request`` to fetch the following page (``None`` at the end of + the sequence), and — when the API supports backward paging — an optional + ``prev_request``. The originating response is retained as ``raw`` so the + page can be used as a context manager that releases the underlying + connection on exit. + + Both the synchronous and asynchronous context-manager protocols are + implemented. ``__exit__`` closes a synchronous response; ``__aexit__`` + awaits an asynchronous one. Closing is idempotent and tolerates a ``raw`` + that exposes neither hook (e.g. a hand-built page in a test). + + Attributes: + items: The items on this page, in server order. 
+ next_request: Request that fetches the next page, or ``None`` when + this is the final page. + prev_request: Request that fetches the previous page, when the API + exposes one; ``None`` otherwise. + raw: The originating response object (kept for connection cleanup and + for callers that need headers / status off the underlying page). + """ + + items: Sequence[T] + next_request: Request | None = None + prev_request: Request | None = None + raw: object | None = field(default=None, compare=False) + + @property + def has_next(self) -> bool: + """Whether a further page is reachable from this one.""" + return self.next_request is not None + + def close(self) -> None: + """Close the underlying synchronous response, if any. Idempotent.""" + close = getattr(self.raw, "close", None) + if close is None: + return + result = close() + if isawaitable(result): + # An async response was stored on a sync page; it cannot be + # awaited here, so close the coroutine to avoid a "never awaited" + # warning and defer real cleanup to the async exit path. + cast(Coroutine[object, object, object], result).close() + + async def aclose(self) -> None: + """Close the underlying asynchronous response, if any. 
Idempotent.""" + close = getattr(self.raw, "close", None) + if close is None: + return + result = close() + if isawaitable(result): + await result + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + await self.aclose() + + def __iter__(self) -> Iterator[T]: + return iter(self.items) + + +__all__ = ["Page"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/paginator.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/paginator.py new file mode 100644 index 0000000..8d2537c --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/paginator.py @@ -0,0 +1,216 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Paginators — iterate a paged operation item-by-item or page-by-page. + +A paginator drives the **pipeline**, not a bare transport: each page fetch is +sent through the full policy chain, so retry, auth-refresh, redirect, and +tracing apply to every page automatically. The caller supplies either a +``Pipeline`` / ``AsyncPipeline`` (the paginator runs it with a fresh dispatch +context per page) or a plain send-callable for full control. + +Iteration is item-by-item by default:: + + for item in Paginator(pipeline, strategy, first_request): + ... + +Use :meth:`Paginator.by_page` when the raw response or page boundaries matter:: + + for page in Paginator(pipeline, strategy, first_request).by_page(): + with page: + process(page.items, page.raw) + +The optional ``max_pages`` guard bounds how many pages are fetched — essential +when draining an unbounded server-side sequence. 
+""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +from ..http.context.dispatch_context import DispatchContext +from .page import Page + +if TYPE_CHECKING: + from ..http.request.request import Request + from ..http.response.async_response import AsyncResponse + from ..http.response.response import Response + from .strategy import PaginationStrategy + +#: A callable that sends one request through the pipeline and returns its +#: response. The paginator builds one of these from a pipeline when given one, +#: or the caller passes their own for full dispatch control. +type SendSync = Callable[["Request"], "Response"] +type SendAsync = Callable[["Request"], Awaitable["AsyncResponse"]] + + +@runtime_checkable +class SyncPipelineLike(Protocol): + """Structural view of a sync pipeline: just the ``run`` entry point. + + ``Pipeline`` satisfies this; the paginator depends only on the structural + shape so it stays decoupled from the concrete pipeline class (and so test + doubles can stand in without subclassing). ``runtime_checkable`` so the + paginator can tell a pipeline from a bare send-callable at construction. + """ + + def run(self, request: Request, dispatch: DispatchContext) -> Response: ... + + +@runtime_checkable +class AsyncPipelineLike(Protocol): + """Structural view of an async pipeline: just its ``run`` coroutine.""" + + async def run(self, request: Request, dispatch: DispatchContext) -> AsyncResponse: ... + + +def _decode_body(raw: str) -> object: + """Decode a JSON body string into a Python value (``None`` when empty).""" + text = raw.strip() + if not text: + return None + return json.loads(text) + + +class Paginator[T]: + """Synchronous paginator over a strategy-defined page sequence. 
+ + Args: + source: Either a ``Pipeline`` (run once per page with a fresh + dispatch context) or a send-callable ``Request -> Response``. + strategy: The :class:`PaginationStrategy` that parses each response + into a :class:`Page`. + initial_request: The request that fetches the first page. + max_pages: Optional cap on the number of pages fetched. ``None`` + means unbounded (drive it with care against open-ended APIs). + dispatch_factory: Builds the dispatch context for each page when + ``source`` is a ``Pipeline``. Defaults to ``DispatchContext.noop``. + """ + + __slots__ = ("_dispatch_factory", "_initial", "_max_pages", "_send", "_strategy") + + def __init__( + self, + source: SyncPipelineLike | SendSync, + strategy: PaginationStrategy[T], + initial_request: Request, + *, + max_pages: int | None = None, + dispatch_factory: Callable[[], DispatchContext] | None = None, + ) -> None: + self._strategy = strategy + self._initial = initial_request + self._max_pages = max_pages + self._dispatch_factory = dispatch_factory or DispatchContext.noop + self._send = self._normalise(source) + + def _normalise(self, source: SyncPipelineLike | SendSync) -> SendSync: + if isinstance(source, SyncPipelineLike): + pipeline = source + + def send(request: Request) -> Response: + return pipeline.run(request, self._dispatch_factory()) + + return send + return source + + def by_page(self) -> Iterator[Page[T]]: + """Yield each :class:`Page` in turn, honouring ``max_pages``. + + Yields: + Pages from first to last. Each page owns its response; iterate + within a ``with page:`` block, or call ``page.close()``, to + release the connection promptly. 
+ """ + request: Request | None = self._initial + count = 0 + while request is not None: + if self._max_pages is not None and count >= self._max_pages: + return + response = self._send(request) + page = self._parse(response) + count += 1 + yield page + request = page.next_request + + def _parse(self, response: Response) -> Page[T]: + payload = _decode_body(response.body.string()) if response.body is not None else None + return self._strategy.parse(response, payload, response.request) + + def __iter__(self) -> Iterator[T]: + for page in self.by_page(): + with page: + yield from page.items + + +class AsyncPaginator[T]: + """Asynchronous twin of :class:`Paginator`. + + Mirrors the sync paginator exactly with ``async`` iteration semantics. + ``source`` is an ``AsyncPipeline`` or an async send-callable. + """ + + __slots__ = ("_dispatch_factory", "_initial", "_max_pages", "_send", "_strategy") + + def __init__( + self, + source: AsyncPipelineLike | SendAsync, + strategy: PaginationStrategy[T], + initial_request: Request, + *, + max_pages: int | None = None, + dispatch_factory: Callable[[], DispatchContext] | None = None, + ) -> None: + self._strategy = strategy + self._initial = initial_request + self._max_pages = max_pages + self._dispatch_factory = dispatch_factory or DispatchContext.noop + self._send = self._normalise(source) + + def _normalise(self, source: AsyncPipelineLike | SendAsync) -> SendAsync: + if isinstance(source, AsyncPipelineLike): + pipeline = source + + async def send(request: Request) -> AsyncResponse: + return await pipeline.run(request, self._dispatch_factory()) + + return send + return source + + async def by_page(self) -> AsyncIterator[Page[T]]: + """Async-yield each :class:`Page` in turn, honouring ``max_pages``.""" + request: Request | None = self._initial + count = 0 + while request is not None: + if self._max_pages is not None and count >= self._max_pages: + return + response = await self._send(request) + page = await self._parse(response) + 
count += 1 + yield page + request = page.next_request + + async def _parse(self, response: AsyncResponse) -> Page[T]: + payload = _decode_body(await response.body.string()) if response.body is not None else None + return self._strategy.parse(response, payload, response.request) + + def __aiter__(self) -> AsyncIterator[T]: + return self._iterate_items() + + async def _iterate_items(self) -> AsyncIterator[T]: + async for page in self.by_page(): + async with page: + for item in page.items: + yield item + + +__all__ = [ + "AsyncPaginator", + "AsyncPipelineLike", + "Paginator", + "SendAsync", + "SendSync", + "SyncPipelineLike", +] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/strategy.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/strategy.py new file mode 100644 index 0000000..ece11e7 --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/strategy.py @@ -0,0 +1,266 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Pagination strategies — pure response-to-``Page`` translators. + +A strategy is the only place that understands a particular API's pagination +convention. It is deliberately I/O-free: the paginator performs the request, +decodes the body into a plain Python value, and hands the strategy that value +together with the response (for header inspection) and the template request +(to derive the next page's request from). Keeping strategies pure lets the +sync and async paginators share them verbatim. + +Three built-ins cover the common conventions: + +* ``CursorStrategy`` — read an opaque cursor / continuation token out of the + response body and resend it as a query parameter. One strategy covers both + "cursor" and "token" pagination; they differ only in field names. 
+* ``PageNumberStrategy`` — increment a page-index query parameter until the + server reports no more items (or an optional total-pages field is reached). +* ``LinkHeaderStrategy`` — follow the RFC 5988 ``Link`` header's ``rel="next"`` + target. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Protocol, cast, runtime_checkable +from urllib.parse import urljoin + +from ..http.common.url import Url +from .link_header import find_rel +from .page import Page + +if TYPE_CHECKING: + from ..http.common.headers import Headers + from ..http.request.request import Request + + +@runtime_checkable +class HasHeaders(Protocol): + """Structural view of a response: just the header surface a strategy reads.""" + + @property + def headers(self) -> Headers: ... + + +@runtime_checkable +class PaginationStrategy[T](Protocol): + """Translates one decoded response into a :class:`Page`. + + Implementations are pure functions of their inputs — they perform no I/O + and hold no mutable per-iteration state, so a single strategy instance is + safe to reuse across both the sync and async paginators. + """ + + def parse( + self, + response: HasHeaders, + payload: object, + template_request: Request, + ) -> Page[T]: + """Build the page that ``response`` represents. + + Args: + response: The response object, used for header inspection (e.g. + the ``Link`` header). Its body must already be decoded into + ``payload`` by the caller. + payload: The decoded response body (typically a ``dict`` or + ``list`` from JSON). ``None`` when the body was empty. + template_request: The request that produced ``response``; the + next page's request is derived from it so auth headers and + path survive. + + Returns: + A page whose ``next_request`` is ``None`` when the sequence is + exhausted. + """ + ... 
+ + +def _dig(payload: object, path: Sequence[str]) -> object: + """Walk a dotted key path into nested mappings; ``None`` if any step misses.""" + current = payload + for key in path: + if not isinstance(current, dict) or key not in current: + return None + current = current[key] + return current + + +def _items_at[T](payload: object, path: Sequence[str]) -> list[T]: + found = _dig(payload, path) + if isinstance(found, list): + return cast("list[T]", found) + return [] + + +def _with_query_param(request: Request, name: str, value: str) -> Request: + """Return ``request`` with query parameter ``name`` set to ``value``.""" + url = request.url + return request.with_url(url.with_query(url.query.with_set(name, value))) + + +@dataclass(frozen=True, slots=True) +class CursorStrategy[T]: + """Cursor / continuation-token pagination. + + Reads a cursor value out of the response body and resends it as a query + parameter on the next request. Covers both the "cursor" convention + (opaque string under e.g. ``next_cursor``) and the "token" convention + (``next_page_token``) — they differ only in field names, supplied here. + + Args: + items_field: Dotted path to the item list in the body (e.g. + ``"data"`` or ``"result.items"``). + cursor_response_field: Dotted path to the cursor in the body. An + absent, empty, or ``null`` value ends the sequence. + cursor_param: Query-parameter name to carry the cursor on the next + request. 
+ """ + + items_field: str = "items" + cursor_response_field: str = "next_cursor" + cursor_param: str = "cursor" + + def parse( + self, + response: HasHeaders, + payload: object, + template_request: Request, + ) -> Page[T]: + items: list[T] = _items_at(payload, self.items_field.split(".")) + cursor = _dig(payload, self.cursor_response_field.split(".")) + next_request: Request | None = None + if isinstance(cursor, str) and cursor: + next_request = _with_query_param(template_request, self.cursor_param, cursor) + return Page(items=items, next_request=next_request, raw=response) + + +@dataclass(frozen=True, slots=True) +class PageNumberStrategy[T]: + """Page-index pagination. + + Increments a numeric ``page`` query parameter each round. Termination is + determined by the body: when the current page yields fewer items than + ``page_size`` (when known) or an empty list, there is no next page. When + ``total_pages_field`` is given, it is honoured as an explicit bound. + + Args: + items_field: Dotted path to the item list in the body. + page_param: Query-parameter name carrying the 1-based page index. + start_page: The index of the first page (default ``1``). + page_size: Expected items per full page; a short page signals the + end. ``None`` disables the short-page heuristic. + total_pages_field: Optional dotted path to a total-page count in the + body; when present it bounds iteration explicitly. 
+ """ + + items_field: str = "items" + page_param: str = "page" + start_page: int = 1 + page_size: int | None = None + total_pages_field: str | None = None + + def parse( + self, + response: HasHeaders, + payload: object, + template_request: Request, + ) -> Page[T]: + items: list[T] = _items_at(payload, self.items_field.split(".")) + current = self._current_page(template_request) + if self._is_last_page(items, payload, current): + return Page(items=items, next_request=None, raw=response) + next_request = _with_query_param(template_request, self.page_param, str(current + 1)) + return Page(items=items, next_request=next_request, raw=response) + + def _current_page(self, request: Request) -> int: + raw = request.url.query.get(self.page_param) + if raw is None: + return self.start_page + try: + return int(raw) + except ValueError: + return self.start_page + + def _is_last_page( + self, + items: Sequence[object], + payload: object, + current: int, + ) -> bool: + if not items: + return True + if self.total_pages_field is not None: + total = _dig(payload, self.total_pages_field.split(".")) + if isinstance(total, int): + return current >= total + if self.page_size is not None: + return len(items) < self.page_size + return False + + +@dataclass(frozen=True, slots=True) +class LinkHeaderStrategy[T]: + """RFC 5988 ``Link``-header pagination. + + Follows the ``rel="next"`` target in the response's ``Link`` header and, + when present, exposes the ``rel="prev"`` target as the page's previous + request. The next request reuses the template request's method, headers, + and body, swapping only the URL. A relative target (permitted by + RFC 5988) is resolved against the template request's URL, so an API that + returns a relative reference rather than an absolute URI still paginates. + + Args: + items_field: Dotted path to the item list in the body. + link_header_name: Header to read link relations from (default + ``"Link"``). 
+ """ + + items_field: str = "items" + link_header_name: str = "Link" + + def parse( + self, + response: HasHeaders, + payload: object, + template_request: Request, + ) -> Page[T]: + items: list[T] = _items_at(payload, self.items_field.split(".")) + header = response.headers.get(self.link_header_name) or "" + next_request = self._request_for(header, "next", template_request) + prev_request = self._request_for(header, "prev", template_request) + return Page( + items=items, + next_request=next_request, + prev_request=prev_request, + raw=response, + ) + + @staticmethod + def _request_for(header: str, rel: str, template: Request) -> Request | None: + target = find_rel(header, rel) + if target is None: + return None + absolute = urljoin(str(template.url), target) + return template.with_url(Url.parse(absolute)) + + +if TYPE_CHECKING: + # Static structural-conformance checks: each built-in must satisfy the + # ``PaginationStrategy`` Protocol. Inheriting the Protocol would defeat + # ``slots=True`` (it pulls in ``__dict__``), so we verify by assignment. 
+ _cursor_conforms: PaginationStrategy[object] = CursorStrategy() + _page_conforms: PaginationStrategy[object] = PageNumberStrategy() + _link_conforms: PaginationStrategy[object] = LinkHeaderStrategy() + + +__all__ = [ + "CursorStrategy", + "HasHeaders", + "LinkHeaderStrategy", + "PageNumberStrategy", + "PaginationStrategy", +] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/defaults.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/defaults.py index e015767..2d30f25 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/defaults.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/defaults.py @@ -8,9 +8,13 @@ from typing import TYPE_CHECKING from .async_staged_builder import AsyncStagedPipelineBuilder +from .policies.async_client_identity import AsyncClientIdentityPolicy +from .policies.async_idempotency import AsyncIdempotencyPolicy from .policies.async_redirect import AsyncRedirectPolicy from .policies.async_retry import AsyncRetryPolicy from .policies.async_set_date import AsyncSetDatePolicy +from .policies.client_identity import ClientIdentityPolicy +from .policies.idempotency import IdempotencyPolicy from .policies.logging_policy import LoggingPolicy from .policies.redirect import RedirectPolicy from .policies.retry import RetryPolicy @@ -29,8 +33,10 @@ def default_pipeline( client: HttpClient, *, redirect: RedirectPolicy | None = None, + idempotency: IdempotencyPolicy | None = None, retry: RetryPolicy | None = None, set_date: SetDatePolicy | None = None, + client_identity: ClientIdentityPolicy | None = None, auth: Policy | None = None, logging: LoggingPolicy | None = None, tracing: TracingPolicy | None = None, @@ -38,15 +44,24 @@ def default_pipeline( """Pre-configured :class:`StagedPipelineBuilder` with the canonical stack. Wires the policies that most consumers want by default in the order their - stages dictate: redirect → retry → set-date → auth → logging → tracing. 
- Each policy is opt-out (pass ``None``) or opt-in-with-override (pass a - pre-configured instance to replace the default). + stages dictate: redirect → idempotency → retry → set-date → + client-identity → auth → logging → tracing. Passing ``None`` for a + policy selects its default; passing a pre-configured instance replaces + that default. ``auth`` has no default. + + Idempotency sits before retry so a write request's ``Idempotency-Key`` is + minted once and reused across every retry; ``set-date`` and + ``client-identity`` sit just inside the retry wrapper. Args: client: Terminal HTTP transport. redirect: Override for :class:`RedirectPolicy`. ``None`` uses defaults. + idempotency: Override for :class:`IdempotencyPolicy`. ``None`` uses + defaults. retry: Override for :class:`RetryPolicy`. ``None`` uses defaults. set_date: Override for :class:`SetDatePolicy`. ``None`` uses defaults. + client_identity: Override for :class:`ClientIdentityPolicy`. ``None`` + uses defaults. auth: Optional authentication policy (``BearerTokenPolicy``, ``BasicAuthPolicy``, ``KeyCredentialPolicy``, etc.). No default — requests pass without authentication when this is ``None``. @@ -59,8 +74,10 @@ def default_pipeline( """ builder = StagedPipelineBuilder(client) builder.append(redirect or RedirectPolicy()) + builder.append(idempotency or IdempotencyPolicy()) builder.append(retry or RetryPolicy()) builder.append(set_date or SetDatePolicy()) + builder.append(client_identity or ClientIdentityPolicy()) if auth is not None: builder.append(auth) builder.append(logging or LoggingPolicy()) @@ -72,8 +89,10 @@ def default_async_pipeline( client: AsyncHttpClient, *, redirect: AsyncRedirectPolicy | None = None, + idempotency: AsyncIdempotencyPolicy | None = None, retry: AsyncRetryPolicy | None = None, set_date: AsyncSetDatePolicy | None = None, + client_identity: AsyncClientIdentityPolicy | None = None, auth: AsyncPolicy | None = None, ) -> AsyncStagedPipelineBuilder: """Async twin of :func:`default_pipeline`. 
@@ -84,8 +103,10 @@ def default_async_pipeline( """ builder = AsyncStagedPipelineBuilder(client) builder.append(redirect or AsyncRedirectPolicy()) + builder.append(idempotency or AsyncIdempotencyPolicy()) builder.append(retry or AsyncRetryPolicy()) builder.append(set_date or AsyncSetDatePolicy()) + builder.append(client_identity or AsyncClientIdentityPolicy()) if auth is not None: builder.append(auth) return builder diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/__init__.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/__init__.py index faff9db..23ac7c1 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/__init__.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/__init__.py @@ -6,9 +6,13 @@ from __future__ import annotations from ._history import RequestHistory +from .async_client_identity import AsyncClientIdentityPolicy +from .async_idempotency import AsyncIdempotencyPolicy from .async_redirect import AsyncRedirectPolicy from .async_retry import AsyncRetryPolicy from .async_set_date import AsyncSetDatePolicy +from .client_identity import ClientIdentityPolicy, default_user_agent +from .idempotency import IdempotencyPolicy from .logging_policy import LoggingPolicy from .redirect import RedirectPolicy from .retry import RetryMode, RetryPolicy @@ -16,9 +20,13 @@ from .tracing_policy import TracingPolicy __all__ = [ + "AsyncClientIdentityPolicy", + "AsyncIdempotencyPolicy", "AsyncRedirectPolicy", "AsyncRetryPolicy", "AsyncSetDatePolicy", + "ClientIdentityPolicy", + "IdempotencyPolicy", "LoggingPolicy", "RedirectPolicy", "RequestHistory", @@ -26,4 +34,5 @@ "RetryPolicy", "SetDatePolicy", "TracingPolicy", + "default_user_agent", ] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_client_identity.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_client_identity.py new file mode 100644 index 0000000..dd3f012 --- 
/dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_client_identity.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Async twin of :class:`ClientIdentityPolicy`.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar, Literal + +from ...http.common.http_header_name import USER_AGENT +from ..async_policy import AsyncPolicy +from ..stage import Stage +from .client_identity import default_user_agent + +if TYPE_CHECKING: + from ...http.request.request import Request + from ...http.response.async_response import AsyncResponse + from ..context import PipelineContext + + +class AsyncClientIdentityPolicy(AsyncPolicy): + """Async variant of :class:`ClientIdentityPolicy`. + + Behaviour mirrors the sync twin: the ``User-Agent`` defaults to + :func:`default_user_agent`, append-vs-replace is selectable, and the token + is guaranteed non-blank. Building the value is synchronous, so :meth:`send` + differs only in the ``await`` on the downstream call. + + Attributes: + STAGE: Pinned to :attr:`Stage.POST_RETRY` at the type level so + mis-slotting is caught by ``mypy``. + """ + + STAGE: ClassVar[Literal[Stage.POST_RETRY]] = Stage.POST_RETRY + __slots__ = ("_replace", "_user_agent") + + def __init__(self, *, user_agent: str | None = None, replace: bool = False) -> None: + """Build the policy. + + Args: + user_agent: ``User-Agent`` token to stamp. ``None`` (the default) + uses :func:`default_user_agent`. An empty or whitespace-only + value is rejected so the header is never blank. + replace: When ``True``, overwrite any caller-set ``User-Agent``. + When ``False`` (the default), append after the caller's value. + + Raises: + ValueError: If ``user_agent`` is provided but empty or whitespace. 
+ """ + resolved = default_user_agent() if user_agent is None else user_agent + if not resolved.strip(): + raise ValueError("user_agent must be a non-empty token string") + self._user_agent = resolved + self._replace = replace + + async def send(self, request: Request, ctx: PipelineContext) -> AsyncResponse: + """Stamp ``request`` with the ``User-Agent`` header and dispatch. + + Args: + request: Outgoing request. + ctx: Pipeline context, forwarded unchanged. + + Returns: + The response from the downstream chain. + """ + existing = request.headers.get(USER_AGENT) + if self._replace or not existing or not existing.strip(): + value = self._user_agent + else: + value = f"{existing} {self._user_agent}" + return await self.next.send(request.with_header(USER_AGENT, value), ctx) + + +__all__ = ["AsyncClientIdentityPolicy"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_idempotency.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_idempotency.py new file mode 100644 index 0000000..2f60355 --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_idempotency.py @@ -0,0 +1,81 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. 
+ +"""Async twin of :class:`IdempotencyPolicy`.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, ClassVar, Literal + +from ...http.request.method import Method +from ..async_policy import AsyncPolicy +from ..stage import Stage +from .idempotency import _generate_key + +if TYPE_CHECKING: + from collections.abc import Iterable + + from ...http.request.request import Request + from ...http.response.async_response import AsyncResponse + from ..context import PipelineContext + +_DEFAULT_HEADER = "Idempotency-Key" +_DEFAULT_METHODS = frozenset({Method.POST, Method.PUT, Method.PATCH}) + + +class AsyncIdempotencyPolicy(AsyncPolicy): + """Async variant of :class:`IdempotencyPolicy`. + + Behaviour mirrors the sync twin: a key is minted once at + :attr:`Stage.POST_REDIRECT` (outside the retry wrapper) and reused across + every retry of the same write request, and a caller-supplied header is left + untouched. Key generation is synchronous, so :meth:`send` differs from the + sync version only in the ``await`` on the downstream call. + + Attributes: + STAGE: Pinned to :attr:`Stage.POST_REDIRECT` at the type level so + mis-slotting is caught by ``mypy``. + """ + + STAGE: ClassVar[Literal[Stage.POST_REDIRECT]] = Stage.POST_REDIRECT + __slots__ = ("_header", "_key_factory", "_methods") + + def __init__( + self, + *, + methods: Iterable[Method] = _DEFAULT_METHODS, + header: str = _DEFAULT_HEADER, + key_factory: Callable[[], str] = _generate_key, + ) -> None: + """Build the policy. + + Args: + methods: HTTP methods whose requests receive a key. Defaults to + ``POST``/``PUT``/``PATCH``. + header: Header name carrying the key. Defaults to + ``Idempotency-Key``. + key_factory: Zero-argument callable returning a fresh key string. + Defaults to a UUID4 generator; tests inject a deterministic + stub. 
+ """ + self._methods = frozenset(methods) + self._header = header + self._key_factory = key_factory + + async def send(self, request: Request, ctx: PipelineContext) -> AsyncResponse: + """Stamp ``request`` with an idempotency key when applicable and dispatch. + + Args: + request: Outgoing request. + ctx: Pipeline context, forwarded unchanged. + + Returns: + The response from the downstream chain. + """ + if request.method in self._methods and self._header not in request.headers: + request = request.with_header(self._header, self._key_factory()) + return await self.next.send(request, ctx) + + +__all__ = ["AsyncIdempotencyPolicy"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_redirect.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_redirect.py index 4d5eb9b..af5f40f 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_redirect.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_redirect.py @@ -16,7 +16,7 @@ from ...http.request.method import Method from ..async_policy import AsyncPolicy from ..stage import Stage -from .redirect import _REDIRECT_STATUSES, RedirectPolicy +from .redirect import _REDIRECT_STATUSES, RedirectPolicy, resolve_http_tracer if TYPE_CHECKING: from ...http.request.request import Request @@ -60,6 +60,8 @@ def __init__( async def send(self, request: Request, ctx: PipelineContext) -> AsyncResponse: cfg = self.config + tracer = resolve_http_tracer(ctx) + tracer.request_url_resolved(str(request.url)) visited: dict[str, None] = {str(request.url): None} hops = 0 current_request = request @@ -80,6 +82,7 @@ async def send(self, request: Request, ctx: PipelineContext) -> AsyncResponse: if next_key in visited: return response visited[next_key] = None + tracer.request_url_resolved(next_key) await response.close() current_request = next_request hops += 1 diff --git 
a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_retry.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_retry.py index 340c38e..540fb62 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_retry.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_retry.py @@ -11,6 +11,7 @@ from __future__ import annotations +import asyncio import logging import random from collections.abc import Iterable @@ -25,7 +26,8 @@ from ..async_policy import AsyncPolicy from ..stage import Stage from ._history import RequestHistory -from .retry import RetryMode, RetryPolicy, _parse_retry_after +from .redirect import resolve_http_tracer +from .retry import RetryMode, RetryPolicy, _StatusRetryError if TYPE_CHECKING: from ...http.request.request import Request @@ -64,6 +66,8 @@ def __init__( method_allowlist: Iterable[str] | None = None, retry_on_status_codes: Iterable[int] | None = None, respect_retry_after: bool = True, + retry_after_max: float | None = None, + full_jitter: bool = True, jitter: float = 0.25, clock: AsyncClock = ASYNC_SYSTEM_CLOCK, rand: random.Random | None = None, @@ -78,8 +82,11 @@ def __init__( "retry_mode": retry_mode, "timeout": timeout, "respect_retry_after": respect_retry_after, + "full_jitter": full_jitter, "jitter": jitter, } + if retry_after_max is not None: + kwargs["retry_after_max"] = retry_after_max if method_allowlist is not None: kwargs["method_allowlist"] = method_allowlist if retry_on_status_codes is not None: @@ -100,56 +107,43 @@ async def send(self, request: Request, ctx: PipelineContext) -> AsyncResponse: settings = cfg._configure_settings(ctx.options) absolute_deadline = self._clock.monotonic() + settings["timeout"] history: list[RequestHistory[AsyncResponse]] = settings["history"] + tracer = resolve_http_tracer(ctx) while True: + tracer.attempt_started(len(history)) try: response = await self.next.send(request, ctx) - if not 
cfg._is_retry(settings, request, response): - ctx.data["retry_history"] = tuple(history) - return response - history.append(RequestHistory(request=request, response=response)) - if not cfg._decrement_status(settings): - ctx.data["retry_history"] = tuple(history) - return response - ctx.data["retry_count"] = len(history) - await self._sleep_after_status(settings, response, absolute_deadline) - continue except ClientAuthenticationError: raise + except asyncio.CancelledError: + # CancelledError is a BaseException, not an SdkError, so the + # ``except SdkError`` below would not catch it — but an explicit + # re-raise documents and guarantees the invariant: a cancelled + # request is never retried, it propagates immediately. + raise except SdkError as err: history.append(RequestHistory(request=request, error=err)) if not cfg._decrement_for_error(settings, err): + tracer.attempt_retries_exhausted() ctx.data["retry_history"] = tuple(history) raise ctx.data["retry_count"] = len(history) - await self._sleep_after_error(settings, absolute_deadline) + delay = cfg._delay_for(settings, None) + tracer.attempt_failed(err, delay) + await self._sleep_bounded(delay, absolute_deadline) _LOGGER.debug("retrying after %s: %s", type(err).__name__, err) continue - - async def _sleep_after_status( - self, - settings: dict[str, Any], - response: AsyncResponse, - absolute_deadline: float, - ) -> None: - if self.config.respect_retry_after: - retry_after = _parse_retry_after(response.headers.get("Retry-After")) - if retry_after is not None: - await self._sleep_bounded(retry_after, absolute_deadline) - return - await self._sleep_bounded( - self.config._backoff_seconds(settings), - absolute_deadline, - ) - - async def _sleep_after_error( - self, - settings: dict[str, Any], - absolute_deadline: float, - ) -> None: - await self._sleep_bounded( - self.config._backoff_seconds(settings), - absolute_deadline, - ) + if not cfg._is_retry(settings, request, response): + ctx.data["retry_history"] = 
tuple(history) + return response + history.append(RequestHistory(request=request, response=response)) + if not cfg._decrement_status(settings): + tracer.attempt_retries_exhausted() + ctx.data["retry_history"] = tuple(history) + return response + ctx.data["retry_count"] = len(history) + delay = cfg._delay_for(settings, response) + tracer.attempt_failed(_StatusRetryError(int(response.status)), delay) + await self._sleep_bounded(delay, absolute_deadline) async def _sleep_bounded( self, diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/client_identity.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/client_identity.py new file mode 100644 index 0000000..8edefc8 --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/client_identity.py @@ -0,0 +1,125 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Pipeline policy that stamps each request with an identifying ``User-Agent``.""" + +from __future__ import annotations + +import platform +from importlib.metadata import PackageNotFoundError, version +from typing import TYPE_CHECKING, ClassVar, Final, Literal + +from ...http.common.http_header_name import USER_AGENT +from ..policy import Policy +from ..stage import Stage + +if TYPE_CHECKING: + from ...http.request.request import Request + from ...http.response.response import Response + from ..context import PipelineContext + +_DIST_NAME: Final[str] = "dexpace-sdk-core" +_FALLBACK_VERSION: Final[str] = "0.0.0" + + +def _sdk_version() -> str: + """Return the installed core distribution version, or a safe fallback. + + The version is read from installed package metadata via + ``importlib.metadata``. 
When the distribution is not installed (for + example, running from a source tree without an editable install), a + placeholder is returned rather than raising, so the policy never produces + a blank or error-laden ``User-Agent``. + + Returns: + The ``dexpace-sdk-core`` version string, or ``"0.0.0"`` if it cannot + be resolved. + """ + try: + return version(_DIST_NAME) + except PackageNotFoundError: + return _FALLBACK_VERSION + + +def default_user_agent() -> str: + """Build the SDK's default ``User-Agent`` token string. + + The shape is ``dexpace-sdk/<sdk-version> python/<python-version>`` — for + example ``dexpace-sdk/1.2.0 python/3.12.4``. Transport packages may append + their own ``<name>/<version>`` token by passing a longer ``user_agent`` to + :class:`ClientIdentityPolicy`. + + Returns: + A non-empty ``User-Agent`` string. + """ + return f"dexpace-sdk/{_sdk_version()} python/{platform.python_version()}" + + +class ClientIdentityPolicy(Policy): + """Stamps the outgoing request with an identifying ``User-Agent`` header. + + A consistent ``User-Agent`` lets servers and the SDK's own observability + attribute traffic to the toolkit and its version. The token string defaults + to :func:`default_user_agent` (``dexpace-sdk/<sdk-version> python/<python-version>``). + + Two modes control interaction with a caller-set header: + + - **append** (the default): a caller-supplied ``User-Agent`` is preserved + and this policy's token is appended after it, space-separated, so both + identities reach the wire. + - **replace**: any caller-supplied ``User-Agent`` is overwritten. + + The configured token is required to be non-blank, so the policy never emits + an empty ``User-Agent`` header. + + Attributes: + STAGE: Pinned to :attr:`Stage.POST_RETRY` at the type level so + mis-slotting is caught by ``mypy``. 
+ + Example: + ```python + Pipeline(transport, policies=[ClientIdentityPolicy()]) + ``` + """ + + STAGE: ClassVar[Literal[Stage.POST_RETRY]] = Stage.POST_RETRY + __slots__ = ("_replace", "_user_agent") + + def __init__(self, *, user_agent: str | None = None, replace: bool = False) -> None: + """Build the policy. + + Args: + user_agent: ``User-Agent`` token to stamp. ``None`` (the default) + uses :func:`default_user_agent`. An empty or whitespace-only + value is rejected so the header is never blank. + replace: When ``True``, overwrite any caller-set ``User-Agent``. + When ``False`` (the default), append after the caller's value. + + Raises: + ValueError: If ``user_agent`` is provided but empty or whitespace. + """ + resolved = default_user_agent() if user_agent is None else user_agent + if not resolved.strip(): + raise ValueError("user_agent must be a non-empty token string") + self._user_agent = resolved + self._replace = replace + + def send(self, request: Request, ctx: PipelineContext) -> Response: + """Stamp ``request`` with the ``User-Agent`` header and dispatch. + + Args: + request: Outgoing request. A new request is returned. + ctx: Pipeline context, forwarded unchanged. + + Returns: + The response from the downstream chain. + """ + existing = request.headers.get(USER_AGENT) + if self._replace or not existing or not existing.strip(): + value = self._user_agent + else: + value = f"{existing} {self._user_agent}" + return self.next.send(request.with_header(USER_AGENT, value), ctx) + + +__all__ = ["ClientIdentityPolicy", "default_user_agent"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/idempotency.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/idempotency.py new file mode 100644 index 0000000..22a4ded --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/idempotency.py @@ -0,0 +1,104 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. 
See LICENSE.md in the repository root for details. + +"""Pipeline policy that stamps write requests with a stable ``Idempotency-Key``.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, ClassVar, Final, Literal +from uuid import uuid4 + +from ...http.request.method import Method +from ..policy import Policy +from ..stage import Stage + +if TYPE_CHECKING: + from collections.abc import Iterable + + from ...http.request.request import Request + from ...http.response.response import Response + from ..context import PipelineContext + +_DEFAULT_HEADER: Final[str] = "Idempotency-Key" +_DEFAULT_METHODS: Final[frozenset[Method]] = frozenset({Method.POST, Method.PUT, Method.PATCH}) + + +def _generate_key() -> str: + """Return a fresh random idempotency key (a UUID4 string).""" + return str(uuid4()) + + +class IdempotencyPolicy(Policy): + """Adds an ``Idempotency-Key`` header to write requests. + + The key is generated **once**, before the request is dispatched, and the + same value is carried across every retry of that request. This lets a + server detect a retried ``POST``/``PUT``/``PATCH`` as a duplicate of an + earlier attempt and avoid processing it twice — turning an at-least-once + delivery into an effectively-exactly-once one. + + A caller-supplied header is left untouched: if the request already carries + the configured header, this policy does nothing. Only the configured + methods (``POST``/``PUT``/``PATCH`` by default) are stamped; idempotency + keys on ``GET``/``DELETE`` are meaningless to most servers. + + The policy is placed at :attr:`Stage.POST_REDIRECT`, which runs *outside* + the retry wrapper (:attr:`Stage.RETRY`). The key is therefore minted on the + first pass and reused on every retry re-send, rather than re-rolled per + attempt the way :class:`SetDatePolicy` re-stamps the ``Date`` header. 
+ + Attributes: + STAGE: Pinned to :attr:`Stage.POST_REDIRECT` at the type level so + mis-slotting is caught by ``mypy``. + + Example: + ```python + Pipeline(transport, policies=[RedirectPolicy(), IdempotencyPolicy(), RetryPolicy()]) + ``` + """ + + STAGE: ClassVar[Literal[Stage.POST_REDIRECT]] = Stage.POST_REDIRECT + __slots__ = ("_header", "_key_factory", "_methods") + + def __init__( + self, + *, + methods: Iterable[Method] = _DEFAULT_METHODS, + header: str = _DEFAULT_HEADER, + key_factory: Callable[[], str] = _generate_key, + ) -> None: + """Build the policy. + + Args: + methods: HTTP methods whose requests receive a key. Defaults to + ``POST``/``PUT``/``PATCH``, the standard write verbs (``PUT`` is + idempotent per RFC 9110 but is stamped by default as well). + header: Header name carrying the key. Defaults to + ``Idempotency-Key`` (the Stripe / IETF draft spelling). + key_factory: Zero-argument callable returning a fresh key string. + Defaults to a UUID4 generator; tests inject a deterministic + stub. + """ + self._methods = frozenset(methods) + self._header = header + self._key_factory = key_factory + + def send(self, request: Request, ctx: PipelineContext) -> Response: + """Stamp ``request`` with an idempotency key when applicable and dispatch. + + Args: + request: Outgoing request. A new request carrying the key is + returned when one is added; otherwise the request is forwarded + unchanged. + ctx: Pipeline context, forwarded unchanged. + + Returns: + The response from the downstream chain. 
+ """ + if request.method in self._methods and self._header not in request.headers: + request = request.with_header(self._header, self._key_factory()) + return self.next.send(request, ctx) + + +__all__ = ["IdempotencyPolicy"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/redirect.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/redirect.py index 8dcc3ec..1fd40e1 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/redirect.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/redirect.py @@ -25,7 +25,7 @@ from __future__ import annotations from dataclasses import replace -from typing import TYPE_CHECKING, ClassVar, Literal +from typing import TYPE_CHECKING, ClassVar, Literal, cast from urllib.parse import urljoin from ...http.common.url import Url @@ -36,12 +36,41 @@ if TYPE_CHECKING: from ...http.request.request import Request from ...http.response.response import Response + from ...instrumentation import HttpTracer from ..context import PipelineContext _REDIRECT_STATUSES: frozenset[int] = frozenset({301, 302, 303, 307, 308}) _CONTENT_HEADER_PREFIX: str = "content-" +#: ``ctx.data`` key holding the per-operation ``HttpTracer``. The first policy +#: in the chain to need it mints one from the call's +#: ``instrumentation_context.http_tracer_factory`` and stores it here so every +#: other policy (tracing, retry, redirect) emits onto the same instance. +HTTP_TRACER_KEY: str = "http_tracer" + + +def resolve_http_tracer(ctx: PipelineContext) -> HttpTracer: + """Return the per-operation ``HttpTracer``, minting one on first use. + + The tracer is cached in ``ctx.data[HTTP_TRACER_KEY]`` so every policy in + the chain shares a single instance for the operation. Defaults to the + no-op tracer when the call carries the no-op factory, so callers that do + not instrument pay nothing. + + Args: + ctx: The pipeline context for the in-flight operation. 
+ + Returns: + The shared ``HttpTracer`` for this operation. + """ + existing = ctx.data.get(HTTP_TRACER_KEY) + if existing is not None: + return cast("HttpTracer", existing) + tracer = ctx.call.instrumentation_context.http_tracer_factory.create() + ctx.data[HTTP_TRACER_KEY] = tracer + return tracer + class RedirectPolicy(Policy): """Follow HTTP redirects per RFC 7231 §6.4 with credential stripping. @@ -98,6 +127,8 @@ def __init__( # ----- main loop ------------------------------------------------------ def send(self, request: Request, ctx: PipelineContext) -> Response: + tracer = resolve_http_tracer(ctx) + tracer.request_url_resolved(str(request.url)) visited: dict[str, None] = {str(request.url): None} hops = 0 current_request = request @@ -118,6 +149,7 @@ def send(self, request: Request, ctx: PipelineContext) -> Response: if next_key in visited: return response visited[next_key] = None + tracer.request_url_resolved(next_key) # Close the intermediate response — we are not handing it back to # the caller. The terminal response is closed by the caller via # the ``with`` block. 
@@ -204,4 +236,4 @@ def _reissue_preserving_body(self, request: Request, next_url: Url) -> Request: return reissued -__all__ = ["RedirectPolicy"] +__all__ = ["HTTP_TRACER_KEY", "RedirectPolicy", "resolve_http_tracer"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/retry.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/retry.py index 871fdc1..a339425 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/retry.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/retry.py @@ -39,6 +39,7 @@ from ..policy import Policy from ..stage import Stage from ._history import RequestHistory +from .redirect import resolve_http_tracer if TYPE_CHECKING: from ...http.request.request import Request @@ -47,6 +48,21 @@ _LOGGER = logging.getLogger(__name__) +#: Default ceiling, in seconds, applied to a server-supplied ``Retry-After`` or +#: ``X-RateLimit-Reset`` delay so a buggy or hostile header cannot make the +#: client sleep for hours. One hour is generous for legitimate rate limits. +_DEFAULT_RETRY_AFTER_MAX: Final[float] = 3600.0 + +#: Header carrying the epoch second at which a rate-limit window resets. Sent by +#: GitHub, Stripe, Slack, and others alongside (or instead of) ``Retry-After``. +_RATE_LIMIT_RESET_HEADER: Final[str] = "X-RateLimit-Reset" + +#: Upward jitter fraction applied to an ``X-RateLimit-Reset`` wait. The delay is +#: multiplied by a random sample in ``[1.0, 1.0 + this]`` so a client never wakes +#: before the window resets, while a fleet that observed the same reset instant +#: spreads its retries instead of firing in lockstep. +_RATE_LIMIT_RESET_JITTER: Final[float] = 0.1 + @runtime_checkable class _ResponseLike(Protocol): @@ -103,12 +119,26 @@ class RetryPolicy(Policy): method_allowlist: HTTP methods that get full retry semantics. POST and PATCH are retried only on 500/503/504. retry_on_status_codes: Status codes that trigger a retry. 
- respect_retry_after: When ``True``, sleep for the ``Retry-After`` - header value (if present) instead of the computed backoff. - jitter: Fractional band applied to each computed backoff to break - thundering herds. ``0.25`` multiplies the backoff by a random - sample in ``[0.75, 1.25]``. Set to ``0`` for deterministic - backoff. + respect_retry_after: When ``True``, sleep for the server-supplied + delay (``Retry-After`` header in seconds/HTTP-date, or an + ``X-RateLimit-Reset`` epoch) when present, instead of the + computed backoff. The server delay is itself capped at + ``retry_after_max`` and an ``X-RateLimit-Reset`` wait gets a small + upward jitter so a fleet of clients does not all retry at the exact + reset instant (and never wakes before it). + retry_after_max: Ceiling, in seconds, on a server-supplied + ``Retry-After`` / ``X-RateLimit-Reset`` delay. Protects against a + buggy or hostile header forcing a multi-hour sleep. Defaults to + one hour. + full_jitter: When ``True`` (the default), the computed exponential + backoff is multiplied by a random sample in ``[0.5, 1.0]``. This + matches what AWS calls *equal jitter*; strict *full jitter* samples + ``[0, 1.0]``. When ``False``, the symmetric ``jitter`` band is + applied instead. + jitter: Symmetric fractional band applied to the computed backoff + when ``full_jitter`` is ``False``. ``0.25`` multiplies the + backoff by a random sample in ``[0.75, 1.25]``. Set to ``0`` for + deterministic backoff. 
Example: ```python @@ -137,6 +167,8 @@ def __init__( method_allowlist: Iterable[str] = _DEFAULT_METHOD_ALLOWLIST, retry_on_status_codes: Iterable[int] = _DEFAULT_STATUS_RETRIES, respect_retry_after: bool = True, + retry_after_max: float = _DEFAULT_RETRY_AFTER_MAX, + full_jitter: bool = True, jitter: float = 0.25, clock: Clock = SYSTEM_CLOCK, rand: random.Random | None = None, @@ -152,6 +184,8 @@ def __init__( self.method_allowlist = frozenset(m.upper() for m in method_allowlist) self.retry_on_status_codes = frozenset(retry_on_status_codes) self.respect_retry_after = respect_retry_after + self.retry_after_max = retry_after_max + self._full_jitter = full_jitter self._jitter = jitter self._clock = clock self._rand = rand if rand is not None else random.Random() @@ -171,7 +205,9 @@ def send(self, request: Request, ctx: PipelineContext) -> Response: settings = self._configure_settings(ctx.options) absolute_deadline = self._clock.monotonic() + settings["timeout"] history: list[RequestHistory[Response]] = settings["history"] + tracer = resolve_http_tracer(ctx) while True: + tracer.attempt_started(len(history)) try: response = self.next.send(request, ctx) if not self._is_retry(settings, request, response): @@ -179,20 +215,26 @@ def send(self, request: Request, ctx: PipelineContext) -> Response: return response history.append(RequestHistory(request=request, response=response)) if not self._decrement_status(settings): + tracer.attempt_retries_exhausted() ctx.data["retry_history"] = tuple(history) return response ctx.data["retry_count"] = len(history) - self._sleep_for(settings, response, absolute_deadline) + delay = self._delay_for(settings, response) + tracer.attempt_failed(_StatusRetryError(int(response.status)), delay) + self._sleep_bounded(delay, absolute_deadline) continue except ClientAuthenticationError: raise except SdkError as err: history.append(RequestHistory(request=request, error=err)) if not self._decrement_for_error(settings, err): + 
tracer.attempt_retries_exhausted() ctx.data["retry_history"] = tuple(history) raise ctx.data["retry_count"] = len(history) - self._sleep_for(settings, None, absolute_deadline) + delay = self._delay_for(settings, None) + tracer.attempt_failed(err, delay) + self._sleep_bounded(delay, absolute_deadline) _LOGGER.debug("retrying after %s: %s", type(err).__name__, err) continue @@ -277,19 +319,57 @@ def _decrement_for_error( # ----- backoff / sleep ------------------------------------------------ - def _sleep_for( + def _delay_for( self, settings: dict[str, Any], response: _ResponseLike | None, - absolute_deadline: float, - ) -> None: - """Sleep before the next attempt, respecting the absolute deadline.""" + ) -> float: + """Compute the delay before the next attempt, in seconds. + + Prefers a server-supplied signal (``Retry-After`` or + ``X-RateLimit-Reset``) when ``respect_retry_after`` is set and the + response carries one — capped at ``retry_after_max``. Otherwise falls + back to the jittered computed backoff. + + Args: + settings: Mutable per-call settings dict. + response: The response that triggered the retry, or ``None`` for a + network-side error (which carries no server timing header). + + Returns: + Non-negative seconds to wait. + """ if response is not None and self.respect_retry_after: - retry_after = _parse_retry_after(response.headers.get("Retry-After")) - if retry_after is not None: - self._sleep_bounded(retry_after, absolute_deadline) - return - self._sleep_bounded(self._backoff_seconds(settings), absolute_deadline) + server_delay = self._server_delay(response) + if server_delay is not None: + return server_delay + return self._backoff_seconds(settings) + + def _server_delay(self, response: _ResponseLike) -> float | None: + """Resolve a server-supplied retry delay, capped and jittered. + + ``Retry-After`` (seconds or HTTP-date) takes precedence; an + ``X-RateLimit-Reset`` epoch is the fallback and gets a slight *upward* + jitter. 
The jitter only ever lengthens the wait — retrying before the + window actually resets just earns another rate-limit response — while + the small positive spread keeps a fleet of clients that observed the + same reset instant from retrying in lockstep. Both are capped at + ``retry_after_max``. + + Returns: + Seconds to wait, or ``None`` when neither header is present. + """ + retry_after = _parse_retry_after(response.headers.get("Retry-After")) + if retry_after is not None: + return min(retry_after, self.retry_after_max) + reset = _parse_rate_limit_reset( + response.headers.get(_RATE_LIMIT_RESET_HEADER), + self._clock.now(), + ) + if reset is None: + return None + jittered = reset * self._rand.uniform(1.0, 1.0 + _RATE_LIMIT_RESET_JITTER) + return min(jittered, self.retry_after_max) def _backoff_seconds(self, settings: dict[str, Any]) -> float: attempts = len(settings["history"]) @@ -300,6 +380,8 @@ def _backoff_seconds(self, settings: dict[str, Any]) -> float: else: backoff = float(settings["backoff"]) * (2 ** (attempts - 1)) bounded = min(float(settings["max_backoff"]), backoff) + if self._full_jitter: + return bounded * self._rand.uniform(0.5, 1.0) if self._jitter == 0: return bounded return bounded * self._rand.uniform(1 - self._jitter, 1 + self._jitter) @@ -336,6 +418,20 @@ def _sleep_bounded(self, duration: float, absolute_deadline: float) -> None: raise ServiceResponseTimeoutError("Retry budget exhausted (timeout reached)") +class _StatusRetryError(Exception): + """Marker error passed to ``HttpTracer.attempt_failed`` for status retries. + + A retryable HTTP status response is not itself an exception, but the + tracer's ``attempt_failed`` callback wants a ``BaseException`` describing + why the attempt failed. This lightweight wrapper carries the status code so + consumers can distinguish a status-driven retry from a transport error. 
+ """ + + def __init__(self, status: int) -> None: + super().__init__(f"retryable HTTP status {status}") + self.status = status + + _RETRY_AFTER_DELTA_PATTERN = re.compile(r"^\s*\d+(\.\d+)?\s*$") @@ -363,6 +459,34 @@ def _parse_retry_after(value: str | None) -> float | None: return max(0.0, delta) +_RATE_LIMIT_RESET_PATTERN = re.compile(r"^\s*\d+(\.\d+)?\s*$") + + +def _parse_rate_limit_reset(value: str | None, now: float) -> float | None: + """Parse an ``X-RateLimit-Reset`` epoch header into a delay. + + The header carries the wall-clock second at which the rate-limit window + resets (GitHub, Stripe, Slack). The delay is the difference between that + instant and ``now``, floored at zero (a reset already in the past means + retry immediately). + + Args: + value: Raw header value (epoch seconds). ``None`` or unparseable + returns ``None``. + now: Current wall-clock time, in seconds since the epoch, used to + compute the delta. Injected so the value is deterministic in tests. + + Returns: + Seconds to wait (>= 0), or ``None`` when the header is missing or not a + plain epoch number. + """ + if value is None or not value.strip(): + return None + if not _RATE_LIMIT_RESET_PATTERN.match(value): + return None + return max(0.0, float(value) - now) + + def _has_budget(settings: dict[str, Any]) -> bool: """Return ``True`` while every retry counter is non-negative.""" counts: tuple[int, ...] = ( diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/tracing_policy.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/tracing_policy.py index 05420ee..4e5760e 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/tracing_policy.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/tracing_policy.py @@ -1,20 +1,39 @@ # Copyright (c) 2026 dexpace and Omar Aljarrah. # Licensed under the MIT License. See LICENSE.md in the repository root for details. 
-"""Pipeline policy that opens a span around the downstream chain.""" +"""Pipeline policy that opens a span around the downstream chain. + +Besides the OpenTelemetry span, the policy drives two correlation seams: + +- It mints a per-operation ``HttpTracer`` from + ``ctx.call.instrumentation_context.http_tracer_factory`` and emits the + fine-grained operation/request/response lifecycle events an SRE wants + (``operation_started``, ``request_sent``, ``response_headers_received``, + ``response_received``, ``operation_succeeded`` / ``operation_failed``). + Per-attempt events (``attempt_started`` / ``attempt_failed`` / + ``attempt_retries_exhausted``) are owned by the retry policy. +- It binds the active trace/span ids into ``contextvars`` for the duration of + the request so ``ClientLogger`` can stamp ``trace.id`` / ``span.id`` onto + every log record emitted downstream. + +Both seams are no-op-safe: the default tracer factory returns +``NOOP_HTTP_TRACER`` and the no-op span carries the sentinel trace ids. +""" from __future__ import annotations from typing import TYPE_CHECKING -from ...instrumentation import NOOP_TRACER, Tracer +from ...instrumentation import NOOP_TRACER, Tracer, bind_correlation from ..policy import Policy from ..stage import Stage +from .redirect import resolve_http_tracer if TYPE_CHECKING: from ...http.common.url import Url from ...http.request.request import Request from ...http.response.response import Response + from ...instrumentation import HttpTracer, Span from ..context import PipelineContext @@ -32,6 +51,12 @@ class TracingPolicy(Policy): - ``http.request.resend_count``: Retry attempt count from ``ctx.data["retry_count"]`` (when retry policy is upstream). 
+ While the span is open the active trace/span ids are bound into the + correlation ``contextvars`` so downstream log records carry them, and a + per-operation ``HttpTracer`` (from the call's + ``instrumentation_context.http_tracer_factory``) receives the + operation/request/response lifecycle events. + Disable per-call by setting ``ctx.options["tracing_enabled"] = False``. """ @@ -45,29 +70,86 @@ def send(self, request: Request, ctx: PipelineContext) -> Response: if not ctx.options.get("tracing_enabled", True): return self.next.send(request, ctx) parent = ctx.call.instrumentation_context + # Share one per-operation tracer with the redirect / retry policies via + # ``ctx.data`` (whichever policy runs first mints it). + http_tracer = resolve_http_tracer(ctx) span = self._tracer.start_span(f"HTTP {request.method}", parent=parent) - host, port = _split_host(request.url) - span.set_attribute("http.request.method", str(request.method)) - span.set_attribute("url.full", str(request.url)) - if host: - span.set_attribute("server.address", host) - if port is not None: - span.set_attribute("server.port", port) + _set_request_attributes(span, request) + http_tracer.operation_started() + with bind_correlation(trace_id=_trace_id(span), span_id=_span_id(span)): + return self._dispatch(request, ctx, span, http_tracer) + + def _dispatch( + self, + request: Request, + ctx: PipelineContext, + span: Span, + http_tracer: HttpTracer, + ) -> Response: + """Run the downstream chain, emitting tracer events around it.""" + _notify_request_sent(http_tracer, request) try: with span.make_current(): response = self.next.send(request, ctx) except BaseException as err: span.set_error(type(err).__name__) span.end(error=err) + http_tracer.operation_failed(err) raise + _notify_response(http_tracer, response) span.set_attribute("http.response.status_code", int(response.status)) retry_count = ctx.data.get("retry_count") if isinstance(retry_count, int) and retry_count > 0: 
span.set_attribute("http.request.resend_count", retry_count) span.end() + http_tracer.operation_succeeded() return response +def _set_request_attributes(span: Span, request: Request) -> None: + """Stamp the OpenTelemetry request attributes onto the span.""" + host, port = _split_host(request.url) + span.set_attribute("http.request.method", str(request.method)) + span.set_attribute("url.full", str(request.url)) + if host: + span.set_attribute("server.address", host) + if port is not None: + span.set_attribute("server.port", port) + + +def _notify_request_sent(http_tracer: HttpTracer, request: Request) -> None: + """Emit ``request_sent`` with the known body byte count, if any.""" + body = request.body + if body is None: + http_tracer.request_sent(0) + return + length = body.content_length() + if length >= 0: + http_tracer.request_sent(length) + + +def _notify_response(http_tracer: HttpTracer, response: Response) -> None: + """Emit ``response_headers_received`` then ``response_received``.""" + headers = {name: ", ".join(values) for name, values in response.headers.items()} + http_tracer.response_headers_received(int(response.status), headers) + body = response.body + if body is None: + http_tracer.response_received(0) + return + length = body.content_length() + if length >= 0: + http_tracer.response_received(length) + + +def _trace_id(span: Span) -> str | None: + value = span.context.trace_id.value + return value if span.context.is_valid else None + + +def _span_id(span: Span) -> str | None: + return span.context.span_id.value if span.context.is_valid else None + + def _split_host(url: Url) -> tuple[str | None, int | None]: return url.host or None, url.port diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/__init__.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/__init__.py index 2cceabd..6a99137 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/__init__.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/__init__.py 
@@ -5,15 +5,55 @@ from __future__ import annotations +from .codec import ( + ALIAS_KEY, + DISCRIMINATOR_KEY, + REGISTRY_KEY, + Codec, + CodecError, + discriminated, + field_alias, + variant, +) from .json_serde import JSON_SERDE, JsonDeserializer, JsonSerde, JsonSerializer from .serde import Deserializer, Serde, Serializer +from .tristate import ( + ABSENT, + NULL, + Present, + Tristate, + fold, + is_absent, + is_null, + is_present, + of_optional, + present, +) __all__ = [ + "ABSENT", + "ALIAS_KEY", + "DISCRIMINATOR_KEY", "JSON_SERDE", + "NULL", + "REGISTRY_KEY", + "Codec", + "CodecError", "Deserializer", "JsonDeserializer", "JsonSerde", "JsonSerializer", + "Present", "Serde", "Serializer", + "Tristate", + "discriminated", + "field_alias", + "fold", + "is_absent", + "is_null", + "is_present", + "of_optional", + "present", + "variant", ] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/codec.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/codec.py new file mode 100644 index 0000000..11d0aba --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/codec.py @@ -0,0 +1,741 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Typed-model codec sitting above ``Serde``. + +``Serde`` is document-in / document-out: it turns wire bytes into plain +``dict`` / ``list`` / scalar documents and back. This module bridges that +plain-document layer and frozen dataclass models:: + + wire bytes -> Serde.deserializer -> document -> Codec.decode -> dataclass + dataclass -> Codec.encode -> document -> Serde.serializer -> wire bytes + +The codec never touches JSON (or any other) syntax, so it is format-agnostic +and reusable with any ``Serde``. It is deliberately validation-free: it +reconstructs declared types, handles aliases, ``Tristate`` fields, datetimes, +enums, containers and discriminated unions, but performs no schema checks or +scalar coercion. 
The model's own ``__post_init__`` invariants still run because +construction goes through the normal constructor. +""" + +from __future__ import annotations + +import collections.abc as cabc +import dataclasses +import datetime as _dt +import enum +import types +import typing +from typing import Final, Union, cast, get_args, get_origin, get_type_hints + +from ..errors import DeserializationError, SerializationError +from .tristate import ABSENT, NULL, Present, Tristate + +if typing.TYPE_CHECKING: + from collections.abc import Callable, Mapping + +ALIAS_KEY: Final = "alias" +"""``field.metadata`` key naming the wire name for a dataclass field.""" + +DISCRIMINATOR_KEY: Final = "__codec_discriminator__" +"""Class attribute naming a discriminated union's tag/discriminator field.""" + +REGISTRY_KEY: Final = "__codec_registry__" +"""Class attribute holding a discriminated union's ``tag -> concrete`` map.""" + + +class CodecError(DeserializationError): + """A document could not be decoded into the requested typed model. + + Carries a wire-name breadcrumb so failures point at the offending location + in the source document. Subclasses ``DeserializationError`` (hence + ``ValueError`` and ``SdkError``) so existing handlers continue to catch it. + + Attributes: + path: Wire-name breadcrumb to the offending location, e.g. + ``("methods", "[0]", "last4")``. + target_name: Name of the type that was being decoded, if known. + """ + + def __init__( + self, + reason: str, + *, + path: tuple[str, ...] = (), + target_name: str | None = None, + error: BaseException | None = None, + ) -> None: + """Initialise the error. + + Args: + reason: Human-readable failure description. + path: Wire-name breadcrumb to the offending location. + target_name: Name of the type being decoded, if known. + error: Underlying cause, if any. 
+ """ + self.path = path + self.target_name = target_name + super().__init__(_render(reason, path, target_name), error=error) + + +def _render(reason: str, path: tuple[str, ...], target_name: str | None) -> str: + """Render a codec error message with its path and target context.""" + rendered = _render_path(path) + prefix = f"field path '{rendered}': " if rendered else "" + suffix = f" (decoding {target_name})" if target_name else "" + return f"{prefix}{reason}{suffix}" + + +def _render_path(path: tuple[str, ...]) -> str: + """Join a breadcrumb into ``a.b[0].c`` form.""" + out = "" + for part in path: + if part.startswith("["): + out += part + else: + out += f".{part}" if out else part + return out + + +@dataclasses.dataclass(frozen=True, slots=True) +class _ModelInfo: + """Cached per-model decode metadata.""" + + hints: Mapping[str, object] + field_to_wire: Mapping[str, str] + wire_to_field: Mapping[str, str] + + +# Bounded by the (finite) set of model classes defined at import time. Models +# are not created dynamically at runtime, so this cache never grows unbounded +# in practice; the lack of an explicit size cap is acceptable for that reason. +_MODEL_CACHE: dict[type, _ModelInfo] = {} + + +def field_alias( + wire_name: str, + /, + *, + default: object = dataclasses.MISSING, + default_factory: Callable[[], object] | None = None, +) -> object: + """Declare a dataclass field whose wire name differs from its Python name. + + Sugar over ``dataclasses.field(metadata={ALIAS_KEY: wire_name}, ...)``. + Raw ``field(metadata={"alias": ...})`` works identically. + + Args: + wire_name: The key used in the wire document for this field. + default: Optional default value (mutually exclusive with + ``default_factory``). + default_factory: Optional zero-arg factory producing the default. + + Returns: + A ``dataclasses.Field`` carrying the alias metadata. 
+ """ + metadata = {ALIAS_KEY: wire_name} + if default is not dataclasses.MISSING: + return dataclasses.field(default=default, metadata=metadata) + if default_factory is not None: + return dataclasses.field(default_factory=default_factory, metadata=metadata) + return dataclasses.field(metadata=metadata) + + +def discriminated[T](tag_field: str, /) -> Callable[[type[T]], type[T]]: + """Mark a base/union class as a discriminated union. + + Attaches an empty variant registry and records which wire field carries the + discriminator tag. Apply to the base type; register concrete variants with + ``@variant``. + + Args: + tag_field: Wire name of the field carrying the discriminator value. + + Returns: + A class decorator returning the class unchanged. + """ + + def decorate(cls: type[T]) -> type[T]: + setattr(cls, DISCRIMINATOR_KEY, tag_field) + setattr(cls, REGISTRY_KEY, {}) + return cls + + return decorate + + +def variant[T](tag_value: str, /) -> Callable[[type[T]], type[T]]: + """Register a concrete dataclass under ``tag_value`` in its base's registry. + + Walks the decorated class's MRO to find the nearest base carrying a registry + (declared via ``@discriminated``) and registers the class there. + + Args: + tag_value: The discriminator value selecting this variant. + + Returns: + A class decorator returning the class unchanged. + + Raises: + TypeError: If no ``@discriminated`` base is found in the MRO. + ValueError: If ``tag_value`` is already registered under that base. 
+ """ + + def decorate(cls: type[T]) -> type[T]: + registry = _find_registry(cls) + if tag_value in registry: + raise ValueError( + f"discriminator value {tag_value!r} already registered " + f"(by {registry[tag_value].__name__})", + ) + registry[tag_value] = cls + return cls + + return decorate + + +def _find_registry(cls: type) -> dict[str, type]: + """Return the variant registry owned by the nearest ``@discriminated`` base.""" + for base in cls.__mro__[1:]: + registry = base.__dict__.get(REGISTRY_KEY) + if registry is not None: + return cast("dict[str, type]", registry) + raise TypeError( + f"{cls.__name__} has no @discriminated base; apply @discriminated to its union base first", + ) + + +class Codec: + """Stateless engine converting between documents and typed models. + + Constructed once and reused. Effectively immutable after construction; its + only mutable state is a shared module-level type-hint cache whose dict + operations are atomic under CPython's GIL, so instances are safe to share + across threads. + """ + + __slots__ = ("_tolerate_unknown",) + + def __init__(self, *, tolerate_unknown: bool = True) -> None: + """Configure the codec. + + Args: + tolerate_unknown: When ``True`` (default), wire keys not claimed by + any field are silently dropped on decode, so a growing server + payload does not break older clients. When ``False``, an + unclaimed key raises ``CodecError``. + """ + self._tolerate_unknown = tolerate_unknown + + def decode[T](self, data: object, target: type[T]) -> T: + """Decode a plain document into an instance of ``target``. + + Args: + data: A plain document (``dict`` / ``list`` / scalar) as produced by + ``Serde.deserializer``. + target: The type to reconstruct — a dataclass, a discriminated base, + ``list[X]`` / ``dict[str, X]``, a datetime/enum, or a scalar. + + Returns: + A fully constructed instance of ``target``. 
+ + Raises: + CodecError: On any structural mismatch or conversion failure, with a + wire-name path pointing at the offending location. + """ + return cast("T", _decode_value(data, target, (), self._tolerate_unknown)) + + def encode(self, value: object) -> object: + """Encode a typed value into a plain document. + + Args: + value: A dataclass, container, datetime, enum, ``Tristate`` field + value, or scalar. + + Returns: + A plain document (``dict`` / ``list`` / scalar) ready for + ``Serde.serializer``. + + Raises: + SerializationError: If ``value`` cannot be turned into a document. + """ + return _encode_value(value) + + +# --------------------------------------------------------------------------- # +# Decode # +# --------------------------------------------------------------------------- # + + +def _decode_value( + data: object, + target: object, + path: tuple[str, ...], + tolerate_unknown: bool, +) -> object: + """Decode ``data`` into ``target``, dispatching on the type's shape.""" + if target is object or target is typing.Any: + return data + origin = get_origin(target) + if origin is None: + return _decode_atomic(data, target, path, tolerate_unknown) + if _is_tristate(target): + return _decode_tristate(data, target, path, tolerate_unknown) + if origin in (Union, types.UnionType): + return _decode_union(data, target, path, tolerate_unknown) + return _decode_container(data, target, origin, path, tolerate_unknown) + + +def _decode_atomic( + data: object, + target: object, + path: tuple[str, ...], + tolerate_unknown: bool, +) -> object: + """Decode into a non-parametrised target (dataclass, datetime, enum, scalar).""" + if isinstance(target, type): + if REGISTRY_KEY in target.__dict__: + return _dispatch_union(data, target, path, tolerate_unknown) + if dataclasses.is_dataclass(target): + return _decode_dataclass(data, target, path, tolerate_unknown) + if issubclass(target, enum.Enum): + return _decode_enum(data, target, path) + if issubclass(target, 
(_dt.datetime, _dt.date, _dt.time)): + return _decode_temporal(data, target, path) + return data + + +def _decode_dataclass( + data: object, + target: type, + path: tuple[str, ...], + tolerate_unknown: bool, + *, + exempt_key: str | None = None, +) -> object: + """Decode a mapping into a plain dataclass, field by field. + + Args: + data: The wire mapping to decode. + target: The dataclass type to construct. + path: Wire-name breadcrumb to this location. + tolerate_unknown: Whether unclaimed keys are dropped or rejected. + exempt_key: A wire key always permitted under strict mode even when no + field claims it — used for a discriminated union's tag, which is a + structural key rather than a stray unknown one. + """ + if not isinstance(data, cabc.Mapping): + raise CodecError( + f"expected an object, got {type(data).__name__}", + path=path, + target_name=target.__name__, + ) + info = _resolve_info(target) + kwargs = _decode_fields(data, target, info, path, tolerate_unknown) + if not tolerate_unknown: + _reject_unknown(data, info, path, target.__name__, exempt_key=exempt_key) + try: + return target(**kwargs) + except (TypeError, ValueError) as err: + raise CodecError(str(err), path=path, target_name=target.__name__, error=err) from err + + +def _decode_fields( + data: Mapping[object, object], + target: type, + info: _ModelInfo, + path: tuple[str, ...], + tolerate_unknown: bool, +) -> dict[str, object]: + """Build the constructor kwargs for ``target`` from ``data``.""" + kwargs: dict[str, object] = {} + for f in dataclasses.fields(target): + wire = info.field_to_wire[f.name] + hint = info.hints[f.name] + if wire not in data: + _require_present_or_default(f, wire, path, target.__name__) + continue + kwargs[f.name] = _decode_value(data[wire], hint, (*path, wire), tolerate_unknown) + return kwargs + + +def _require_present_or_default( + f: dataclasses.Field[object], + wire: str, + path: tuple[str, ...], + target_name: str, +) -> None: + """Ensure a missing field has a 
default; otherwise raise ``CodecError``.""" + has_default = f.default is not dataclasses.MISSING + has_factory = f.default_factory is not dataclasses.MISSING + if not (has_default or has_factory): + raise CodecError( + f"missing required field {f.name!r} (wire {wire!r})", + path=path, + target_name=target_name, + ) + + +def _reject_unknown( + data: Mapping[object, object], + info: _ModelInfo, + path: tuple[str, ...], + target_name: str, + *, + exempt_key: str | None = None, +) -> None: + """Raise if ``data`` carries a key not claimed by any field. + + ``exempt_key`` names a wire key that is always permitted even when no field + claims it (a discriminated union's tag), so strict mode does not punish the + very key that drove variant dispatch. + """ + for key in data: + if key == exempt_key: + continue + if key not in info.wire_to_field: + raise CodecError( + f"unknown field {key!r}", + path=path, + target_name=target_name, + ) + + +def _decode_container( + data: object, + target: object, + origin: object, + path: tuple[str, ...], + tolerate_unknown: bool, +) -> object: + """Decode a parametrised container (list/tuple/set/dict/mapping).""" + args = get_args(target) + if origin is dict or _is_mapping_origin(origin): + return _decode_mapping(data, args, path, tolerate_unknown) + if origin is tuple: + return _decode_tuple(data, args, path, tolerate_unknown) + if origin in (list, set, frozenset) or _is_sequence_origin(origin): + return _decode_sequence(data, origin, args, path, tolerate_unknown) + return data + + +def _decode_sequence( + data: object, + origin: object, + args: tuple[object, ...], + path: tuple[str, ...], + tolerate_unknown: bool, +) -> object: + """Decode a homogeneous sequence into list/set/frozenset (default list).""" + if not isinstance(data, cabc.Iterable) or isinstance(data, (str, bytes, cabc.Mapping)): + raise CodecError(f"expected an array, got {type(data).__name__}", path=path) + elem = args[0] if args else object + items = [ + 
_decode_value(item, elem, (*path, f"[{i}]"), tolerate_unknown) + for i, item in enumerate(data) + ] + if origin in (set, cabc.Set, cabc.MutableSet): + return set(items) + if origin is frozenset: + return frozenset(items) + return items + + +def _decode_tuple( + data: object, + args: tuple[object, ...], + path: tuple[str, ...], + tolerate_unknown: bool, +) -> object: + """Decode a homogeneous (``tuple[X, ...]``) or fixed-arity tuple.""" + if not isinstance(data, cabc.Iterable) or isinstance(data, (str, bytes, cabc.Mapping)): + raise CodecError(f"expected an array, got {type(data).__name__}", path=path) + seq = list(data) + if len(args) == 2 and args[1] is Ellipsis: + elem = args[0] + return tuple( + _decode_value(v, elem, (*path, f"[{i}]"), tolerate_unknown) for i, v in enumerate(seq) + ) + arity = len(args) + if len(seq) != arity: + raise CodecError( + f"expected an array of {arity} element(s), got {len(seq)}", + path=path, + ) + return tuple( + _decode_value(v, t, (*path, f"[{i}]"), tolerate_unknown) + for i, (v, t) in enumerate(zip(seq, args, strict=True)) + ) + + +def _decode_mapping( + data: object, + args: tuple[object, ...], + path: tuple[str, ...], + tolerate_unknown: bool, +) -> object: + """Decode a mapping, recovering each key and value through its declared type. + + Wire object keys are always strings, so a declared key type such as ``int``, + an enum, or ``UUID`` is recovered by recursing on the key the same way the + codec recurses on values; ``str`` and ``object`` key types pass through. 
+ """ + if not isinstance(data, cabc.Mapping): + raise CodecError(f"expected an object, got {type(data).__name__}", path=path) + key_type = args[0] if len(args) == 2 else object + value_type = args[1] if len(args) == 2 else object + return { + _decode_value(key, key_type, (*path, str(key)), tolerate_unknown): _decode_value( + val, + value_type, + (*path, str(key)), + tolerate_unknown, + ) + for key, val in data.items() + } + + +def _decode_union( + data: object, + target: object, + path: tuple[str, ...], + tolerate_unknown: bool, +) -> object: + """Decode an ``X | None`` union: ``None`` passthrough, else decode ``X``. + + Only single-arm optionals (``X | None``) recover their inner type; ``None`` + is passed through only when ``NoneType`` is genuinely a union member, so a + non-optional union such as ``int | str`` does not silently accept ``None``. + Unions with two or more non-``None`` arms are tagless and cannot be resolved + structurally, so their payload passes through untouched — use a discriminated + union (``@discriminated`` / ``@variant``) when an arm must be reconstructed. + """ + all_args = get_args(target) + args = [a for a in all_args if a is not type(None)] + if data is None and type(None) in all_args: + return None + if len(args) == 1: + return _decode_value(data, args[0], path, tolerate_unknown) + return data + + +def _decode_tristate( + data: object, + target: object, + path: tuple[str, ...], + tolerate_unknown: bool, +) -> object: + """Decode a present key into ``NULL`` or ``Present(inner)``. + + A missing key is handled upstream in ``_decode_fields`` (the kwarg is + omitted so the field default applies); this path only runs for present keys. 
+ """ + if data is None: + return NULL + inner = _tristate_inner(target) + return Present(_decode_value(data, inner, path, tolerate_unknown)) + + +def _decode_enum(data: object, target: type[enum.Enum], path: tuple[str, ...]) -> object: + """Decode a value into an enum member by value.""" + try: + return target(data) + except ValueError as err: + raise CodecError( + f"{data!r} is not a valid {target.__name__}", + path=path, + target_name=target.__name__, + error=err, + ) from err + + +def _decode_temporal(data: object, target: type, path: tuple[str, ...]) -> object: + """Decode an ISO-8601 string into datetime/date/time.""" + if not isinstance(data, str): + raise CodecError( + f"expected an ISO-8601 string, got {type(data).__name__}", + path=path, + target_name=target.__name__, + ) + try: + return target.fromisoformat(data) # type: ignore[attr-defined] + except ValueError as err: + raise CodecError( + f"{data!r} is not a valid {target.__name__}", + path=path, + target_name=target.__name__, + error=err, + ) from err + + +def _dispatch_union( + data: object, + base: type, + path: tuple[str, ...], + tolerate_unknown: bool, +) -> object: + """Resolve a discriminated union to a concrete variant and decode it.""" + if not isinstance(data, cabc.Mapping): + raise CodecError( + f"expected an object, got {type(data).__name__}", + path=path, + target_name=base.__name__, + ) + tag_field: str = getattr(base, DISCRIMINATOR_KEY) + registry = cast("dict[str, type]", getattr(base, REGISTRY_KEY)) + if tag_field not in data: + raise CodecError( + f"missing discriminator field {tag_field!r}", + path=path, + target_name=base.__name__, + ) + tag = data[tag_field] + concrete = registry.get(cast("str", tag)) + if concrete is None: + known = sorted(registry) + raise CodecError( + f"unknown discriminator value {tag!r}; known: {known}", + path=path, + target_name=base.__name__, + ) + return _decode_dataclass( + data, + concrete, + path, + tolerate_unknown, + exempt_key=tag_field, + ) + + +# 
--------------------------------------------------------------------------- # +# Encode # +# --------------------------------------------------------------------------- # + + +def _encode_value(value: object) -> object: + """Encode a typed value into a plain document.""" + if isinstance(value, enum.Enum): + # Checked before the scalar branch: ``StrEnum`` members are ``str`` and + # ``IntEnum`` members are ``int``, so an earlier scalar check would + # return the member itself instead of collapsing it to ``value.value``. + return value.value + if value is None or isinstance(value, (bool, int, float, str)): + return value + if isinstance(value, Present): + return _encode_value(value.value) + if value is NULL or value is ABSENT: + # Bare tristate sentinels have no enclosing key to fold against; both + # collapse to ``None`` at the top level. The absent-vs-null distinction + # is only observable when folding a dataclass field (see + # ``_encode_dataclass``). + return None + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return _encode_dataclass(value) + if isinstance(value, (_dt.datetime, _dt.date, _dt.time)): + return value.isoformat() + if isinstance(value, bytes): + try: + return value.decode("utf-8") + except UnicodeDecodeError as err: + raise SerializationError("cannot encode non-UTF-8 bytes value") from err + if isinstance(value, cabc.Mapping): + return {k: _encode_value(v) for k, v in value.items()} + if isinstance(value, (list, tuple, set, frozenset)): + return [_encode_value(v) for v in value] + raise SerializationError(f"cannot encode value of type {type(value).__name__}") + + +def _encode_dataclass(value: object) -> dict[str, object]: + """Encode a dataclass into a document, folding tristate fields.""" + assert dataclasses.is_dataclass(value) and not isinstance(value, type) + info = _resolve_info(type(value)) + out: dict[str, object] = {} + for f in dataclasses.fields(value): + wire = info.field_to_wire[f.name] + attr = getattr(value, 
f.name) + if attr is ABSENT: + continue + if attr is NULL: + out[wire] = None + continue + out[wire] = _encode_value(attr) + return out + + +# --------------------------------------------------------------------------- # +# Type introspection helpers # +# --------------------------------------------------------------------------- # + + +def _resolve_info(target: type) -> _ModelInfo: + """Resolve and cache decode metadata for a dataclass ``target``.""" + info = _MODEL_CACHE.get(target) + if info is not None: + return info + hints = get_type_hints(target, include_extras=True) + field_to_wire: dict[str, str] = {} + wire_to_field: dict[str, str] = {} + for f in dataclasses.fields(target): + wire = f.metadata.get(ALIAS_KEY, f.name) + field_to_wire[f.name] = wire + wire_to_field[wire] = f.name + info = _ModelInfo(hints=hints, field_to_wire=field_to_wire, wire_to_field=wire_to_field) + _MODEL_CACHE[target] = info + return info + + +def _is_tristate(target: object) -> bool: + """Return whether ``target`` is a ``Tristate[X]`` (or its expanded union). + + ``get_type_hints`` resolves a ``type`` alias such as ``Tristate[str]`` to a + ``GenericAlias`` whose origin is the ``Tristate`` alias object itself, not a + ``Union``. Older / expanded forms surface as a ``Present``-bearing union, so + both shapes are recognised. 
+ """ + if get_origin(target) is Tristate: + return True + if get_origin(target) in (Union, types.UnionType): + return any(get_origin(arg) is Present or arg is Present for arg in get_args(target)) + return False + + +def _tristate_inner(target: object) -> object: + """Recover ``X`` from a ``Tristate[X]`` (or its expanded union form).""" + if get_origin(target) is Tristate: + args = get_args(target) + return args[0] if args else object + for arg in get_args(target): + if get_origin(arg) is Present: + inner = get_args(arg) + return inner[0] if inner else object + if arg is Present: + return object + return object + + +def _is_mapping_origin(origin: object) -> bool: + """Return whether ``origin`` is an abstract Mapping origin.""" + return origin in (cabc.Mapping, cabc.MutableMapping) + + +def _is_sequence_origin(origin: object) -> bool: + """Return whether ``origin`` is an abstract Sequence/Set origin.""" + return origin in ( + cabc.Sequence, + cabc.MutableSequence, + cabc.Set, + cabc.MutableSet, + cabc.Iterable, + cabc.Collection, + ) + + +__all__ = [ + "ALIAS_KEY", + "DISCRIMINATOR_KEY", + "REGISTRY_KEY", + "Codec", + "CodecError", + "discriminated", + "field_alias", + "variant", +] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/tristate.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/tristate.py new file mode 100644 index 0000000..cd529e9 --- /dev/null +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/tristate.py @@ -0,0 +1,177 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Three-valued optional type distinguishing "omitted" from "explicit null". + +A field typed ``T | None`` cannot tell apart two wire states that matter for +merge-update (``PATCH``) APIs: a key that was *omitted entirely* versus a key +that was *sent as JSON ``null``*. Omitting ``name`` means "leave it unchanged"; +sending ``name: null`` means "clear it". 
+ +``Tristate[T]`` is a sealed type with exactly three inhabitants: + +- ``ABSENT`` — the key was omitted; on serialize, skip it. +- ``NULL`` — the key was present with value ``null``; on serialize, write ``null``. +- ``Present(value)`` — a real value; on serialize, write ``value``. + +``fold`` forces callers to handle all three cases, so the absent-vs-null +distinction can never be silently dropped. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Final, TypeGuard, final + +if TYPE_CHECKING: + from collections.abc import Callable + + +@final +class _Absent: + """Singleton type for the ``ABSENT`` sentinel — the key was omitted.""" + + __slots__ = () + _instance: _Absent | None = None + + def __new__(cls) -> _Absent: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __repr__(self) -> str: + return "ABSENT" + + def __reduce__(self) -> str: + return "ABSENT" + + +@final +class _Null: + """Singleton type for the ``NULL`` sentinel — the key was an explicit null.""" + + __slots__ = () + _instance: _Null | None = None + + def __new__(cls) -> _Null: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __repr__(self) -> str: + return "NULL" + + def __reduce__(self) -> str: + return "NULL" + + +@final +@dataclass(frozen=True, slots=True) +class Present[T]: + """A real, present value within a ``Tristate``. + + Attributes: + value: The wrapped value. May itself be falsy (``0``, ``""``, ``[]``); + presence is encoded by the wrapper, never by truthiness. 
+ """ + + value: T + + +ABSENT: Final[_Absent] = _Absent() +"""The key was omitted entirely — serialize by skipping the key.""" + +NULL: Final[_Null] = _Null() +"""The key was present with an explicit null — serialize as ``null``.""" + + +type Tristate[T] = _Absent | _Null | Present[T] +"""Three-valued optional: ``ABSENT`` | ``NULL`` | ``Present[T]``.""" + + +def present[T](value: T) -> Present[T]: + """Wrap a concrete value as ``Present``. + + Args: + value: The value to wrap; falsy values are preserved as present. + + Returns: + A ``Present`` holding ``value``. + """ + return Present(value) + + +def of_optional[T](value: T | None) -> _Null | Present[T]: + """Lift a plain optional into a ``Tristate``, mapping ``None`` to ``NULL``. + + Use this when the source can only distinguish "value" from "no value" (a + bare ``T | None``) and ``None`` should mean an explicit null on the wire. + The result can never be ``ABSENT``; omission must be expressed by the caller + choosing ``ABSENT`` directly. + + Args: + value: A value or ``None``. + + Returns: + ``NULL`` if ``value is None``, otherwise ``Present(value)``. + """ + if value is None: + return NULL + return Present(value) + + +def fold[T, R]( + state: Tristate[T], + *, + on_absent: Callable[[], R], + on_null: Callable[[], R], + on_present: Callable[[T], R], +) -> R: + """Collapse a ``Tristate`` to a single value, handling every case. + + Exactly one branch runs. Because all three handlers are required, callers + cannot silently forget the absent-vs-null distinction. + + Args: + state: The tristate to inspect. + on_absent: Called with no arguments when ``state`` is ``ABSENT``. + on_null: Called with no arguments when ``state`` is ``NULL``. + on_present: Called with the wrapped value when ``state`` is ``Present``. + + Returns: + The result of whichever handler matched ``state``. 
+ """ + if isinstance(state, Present): + return on_present(state.value) + if state is NULL: + return on_null() + return on_absent() + + +def is_absent[T](state: Tristate[T]) -> TypeGuard[_Absent]: + """Return whether ``state`` is the ``ABSENT`` sentinel.""" + return state is ABSENT + + +def is_null[T](state: Tristate[T]) -> TypeGuard[_Null]: + """Return whether ``state`` is the ``NULL`` sentinel.""" + return state is NULL + + +def is_present[T](state: Tristate[T]) -> TypeGuard[Present[T]]: + """Return whether ``state`` is a ``Present`` value.""" + return isinstance(state, Present) + + +__all__ = [ + "ABSENT", + "NULL", + "Present", + "Tristate", + "fold", + "is_absent", + "is_null", + "is_present", + "of_optional", + "present", +] diff --git a/packages/dexpace-sdk-core/tests/auth/test_digest_charset.py b/packages/dexpace-sdk-core/tests/auth/test_digest_charset.py new file mode 100644 index 0000000..5f042b7 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/auth/test_digest_charset.py @@ -0,0 +1,108 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for Digest credential-charset selection (RFC 7616 §3.4). + +The ``charset`` directive controls how ``username`` and ``password`` are +encoded before hashing. ``charset=UTF-8`` selects UTF-8; its absence (or any +other value) falls back to the legacy ISO-8859-1 default. For a password with +a non-ASCII character the two encodings yield distinct ``response`` digests, +so the chosen branch is observable end to end. +""" + +from __future__ import annotations + +import re + +from dexpace.sdk.core.http.auth import ( + AuthenticateChallenge, + DigestChallengeHandler, +) +from dexpace.sdk.core.http.common.url import Url +from dexpace.sdk.core.http.request.method import Method + +_USERNAME = "Mufasa" +# ``é`` (U+00E9) encodes to one byte in ISO-8859-1 and two in UTF-8. 
+_PASSWORD = "Circle of Lifé" +_REALM = "http-auth@example.org" +_NONCE = "7ypf/xlj9XXwfDPEoM4URrv/xwf94BcCAzFZH4GiTo0v" +_CNONCE = "f2/wE4q74E6zIJEtWaHKaf5wv/H5QzzpXusqGemxURZJ" +_URL = Url(scheme="https", host="example.org", path="/dir/index.html") + +# Reference MD5/qop=auth responses computed for ``_PASSWORD`` under each codec. +_RESPONSE_UTF8 = "9c931f7ef105acbcc1e9f99ba923d170" +_RESPONSE_ISO_8859_1 = "0d97ea03337d88fc75698b9ef88d349d" + + +def _parse_auth(value: str) -> dict[str, str]: + assert value.startswith("Digest ") + body = value[len("Digest ") :] + out: dict[str, str] = {} + parts = re.split(r",\s*(?=[a-zA-Z][a-zA-Z0-9_-]*=)", body) + for part in parts: + key, _, raw = part.partition("=") + raw = raw.strip() + if raw.startswith('"') and raw.endswith('"'): + raw = raw[1:-1] + out[key.strip().lower()] = raw + return out + + +def _handle(parameters: dict[str, str]) -> dict[str, str]: + handler = DigestChallengeHandler( + _USERNAME, + _PASSWORD, + preferred_algorithms=("MD5",), + cnonce_factory=lambda: _CNONCE, + ) + challenge = AuthenticateChallenge(scheme="Digest", parameters=parameters) + result = handler.handle(Method.GET, _URL, [challenge], is_proxy=False) + assert result is not None + _, value = result + return _parse_auth(value) + + +def _base_params() -> dict[str, str]: + return { + "realm": _REALM, + "qop": "auth", + "nonce": _NONCE, + "algorithm": "MD5", + } + + +def test_uses_utf8_when_charset_advertised() -> None: + params = {**_base_params(), "charset": "UTF-8"} + + parsed = _handle(params) + + assert parsed["response"] == _RESPONSE_UTF8 + + +def test_charset_directive_is_case_insensitive() -> None: + params = {**_base_params(), "charset": "utf-8"} + + parsed = _handle(params) + + assert parsed["response"] == _RESPONSE_UTF8 + + +def test_uses_iso_8859_1_when_charset_absent() -> None: + parsed = _handle(_base_params()) + + assert parsed["response"] == _RESPONSE_ISO_8859_1 + + +def test_uses_iso_8859_1_for_unrecognised_charset() -> None: + 
params = {**_base_params(), "charset": "US-ASCII"} + + parsed = _handle(params) + + assert parsed["response"] == _RESPONSE_ISO_8859_1 + + +def test_charset_branches_diverge_for_non_ascii_secret() -> None: + with_utf8 = _handle({**_base_params(), "charset": "UTF-8"}) + without = _handle(_base_params()) + + assert with_utf8["response"] != without["response"] diff --git a/packages/dexpace-sdk-core/tests/errors/test_error_caps.py b/packages/dexpace-sdk-core/tests/errors/test_error_caps.py new file mode 100644 index 0000000..4c3626d --- /dev/null +++ b/packages/dexpace-sdk-core/tests/errors/test_error_caps.py @@ -0,0 +1,134 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for the ``retryable`` flag and non-consuming ``body_snapshot``. + +Covers the two capabilities added to ``HttpResponseError``: a status-derived +``retryable`` flag the retry policy can read directly, and a +``body_snapshot`` preview that never drains a single-use response body. 
+""" + +from __future__ import annotations + +import pytest + +from dexpace.sdk.core.errors import HttpResponseError, ResourceNotFoundError +from dexpace.sdk.core.http.common import MediaType, Protocol, Url +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.response import ( + LoggableResponseBody, + Response, + ResponseBody, + Status, +) + + +def _response(status: Status, *, body: ResponseBody | None = None) -> Response: + request = Request(method=Method.GET, url=Url.parse("https://example.com/")) + return Response( + request=request, + protocol=Protocol.HTTP_1_1, + status=status, + body=body, + ) + + +# ----- retryable flag ----------------------------------------------------- + + +@pytest.mark.parametrize( + "status", + [ + Status.REQUEST_TIMEOUT, + Status.TOO_MANY_REQUESTS, + Status.INTERNAL_SERVER_ERROR, + Status.BAD_GATEWAY, + Status.SERVICE_UNAVAILABLE, + Status.GATEWAY_TIMEOUT, + ], + ids=["408", "429", "500", "502", "503", "504"], +) +def test_retryable_is_true_for_transient_status(status: Status) -> None: + err = HttpResponseError(response=_response(status)) + assert err.retryable is True + + +@pytest.mark.parametrize( + "status", + [Status.BAD_REQUEST, Status.NOT_FOUND, Status.CONFLICT, Status.NOT_IMPLEMENTED], + ids=["400", "404", "409", "501"], +) +def test_retryable_is_false_for_terminal_status(status: Status) -> None: + err = HttpResponseError(response=_response(status)) + assert err.retryable is False + + +def test_retryable_is_false_when_no_response() -> None: + err = HttpResponseError("no response captured") + assert err.retryable is False + + +def test_retryable_override_forces_true() -> None: + err = HttpResponseError(response=_response(Status.NOT_FOUND), retryable=True) + assert err.retryable is True + + +def test_retryable_override_forces_false() -> None: + err = HttpResponseError(response=_response(Status.SERVICE_UNAVAILABLE), retryable=False) + assert err.retryable is False + + +def 
test_retryable_inherited_by_subclasses() -> None: + err = ResourceNotFoundError(response=_response(Status.NOT_FOUND)) + assert err.retryable is False + + +# ----- body_snapshot ------------------------------------------------------ + + +def test_body_snapshot_previews_loggable_body_without_consuming() -> None: + loggable = LoggableResponseBody( + ResponseBody.from_bytes(b'{"error":"boom"}', MediaType.parse("application/json")), + ) + err = HttpResponseError(response=_response(Status.BAD_REQUEST, body=loggable)) + + preview = err.body_snapshot() + + assert preview == b'{"error":"boom"}' + # Preview must not consume: the body is still readable afterwards. + assert loggable.bytes() == b'{"error":"boom"}' + + +def test_body_snapshot_truncates_to_max_bytes() -> None: + loggable = LoggableResponseBody(ResponseBody.from_bytes(b"0123456789")) + err = HttpResponseError(response=_response(Status.BAD_REQUEST, body=loggable)) + + assert err.body_snapshot(4) == b"0123" + + +def test_body_snapshot_returns_empty_for_single_use_body() -> None: + # A plain bytes-backed body is single-use; previewing it must not drain it. + body = ResponseBody.from_bytes(b"unsafe-to-peek") + err = HttpResponseError(response=_response(Status.BAD_REQUEST, body=body)) + + assert err.body_snapshot() == b"" + # The underlying body is untouched and still fully readable. 
+ assert body.bytes() == b"unsafe-to-peek" + + +def test_body_snapshot_returns_empty_when_no_body() -> None: + err = HttpResponseError(response=_response(Status.BAD_REQUEST)) + assert err.body_snapshot() == b"" + + +def test_body_snapshot_returns_empty_when_no_response() -> None: + err = HttpResponseError("no response captured") + assert err.body_snapshot() == b"" + + +def test_body_snapshot_rejects_negative_max_bytes() -> None: + loggable = LoggableResponseBody(ResponseBody.from_bytes(b"data")) + err = HttpResponseError(response=_response(Status.BAD_REQUEST, body=loggable)) + + with pytest.raises(ValueError, match="max_bytes must be non-negative"): + err.body_snapshot(-1) diff --git a/packages/dexpace-sdk-core/tests/http/test_async_cancellation.py b/packages/dexpace-sdk-core/tests/http/test_async_cancellation.py new file mode 100644 index 0000000..ecf1353 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/http/test_async_cancellation.py @@ -0,0 +1,215 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Cancellation discipline for async response bodies and the SSE stream. + +These tests cancel a task mid-read / mid-stream and assert two invariants of +the shielded-cleanup convention (P9): + +- the underlying transport handle is released (``close`` / ``aclose`` runs to + completion), and +- ``asyncio.CancelledError`` continues to propagate — cleanup never swallows + it. +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from dexpace.sdk.core.http.response import AsyncResponse, AsyncResponseBody +from dexpace.sdk.core.http.response.async_response_body import _shielded_cleanup +from dexpace.sdk.core.http.sse.parser import parse_async_events + + +class _SlowStream: + """Async stream whose ``read`` blocks and whose ``close`` is slow. + + ``read`` parks on an event so a consumer can be cancelled while awaiting + it. 
``close`` awaits a short sleep so the test can observe whether the + close ran to completion under cancellation rather than being interrupted. + """ + + def __init__(self, *, payload: bytes = b"chunk") -> None: + self._payload = payload + self._gate = asyncio.Event() + self.closed = False + self.close_completed = False + + async def read(self, size: int = -1) -> bytes: + await self._gate.wait() + return self._payload + + async def close(self) -> object: + self.closed = True + # A real transport close yields to the loop; make sure the await + # completes even though the enclosing task was cancelled. + await asyncio.sleep(0) + self.close_completed = True + return None + + +async def test_aiter_bytes_releases_stream_and_propagates_when_cancelled_mid_read() -> None: + stream = _SlowStream() + body = AsyncResponseBody.from_async_stream(stream) + + async def consume() -> None: + async for _ in body.aiter_bytes(): + pass + + task = asyncio.ensure_future(consume()) + await asyncio.sleep(0) # let the task reach the blocking read + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert stream.closed is True + assert stream.close_completed is True + + +async def test_response_aexit_releases_body_when_cancelled() -> None: + stream = _SlowStream() + body = AsyncResponseBody.from_async_stream(stream) + from dexpace.sdk.core.http.common.protocol import Protocol + from dexpace.sdk.core.http.common.url import Url + from dexpace.sdk.core.http.request import Method, Request + from dexpace.sdk.core.http.response import Status + + request = Request(method=Method.GET, url=Url.parse("https://example.test/")) + response = AsyncResponse( + request=request, + protocol=Protocol.HTTP_1_1, + status=Status.OK, + body=body, + ) + + async def consume() -> None: + async with response: + async for _ in body.aiter_bytes(): + pass + + task = asyncio.ensure_future(consume()) + await asyncio.sleep(0) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await 
task + + assert stream.closed is True + assert stream.close_completed is True + + +async def test_close_runs_to_completion_under_direct_cancellation() -> None: + """A scope cancelled while awaiting ``close`` still releases the stream.""" + stream = _SlowStream() + body = AsyncResponseBody.from_async_stream(stream) + # Consume nothing; close directly inside a task that is then cancelled. + + started = asyncio.Event() + + async def closer() -> None: + started.set() + await body.close() + + task = asyncio.ensure_future(closer()) + await started.wait() + await asyncio.sleep(0) # let close() reach its inner await + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert stream.closed is True + assert stream.close_completed is True + + +class _SlowSseChunks: + """Async byte iterator for SSE whose ``aclose`` is slow and observable.""" + + def __init__(self) -> None: + self._gate = asyncio.Event() + self._sent = False + self.aclosed = False + self.aclose_completed = False + + def __aiter__(self) -> _SlowSseChunks: + return self + + async def __anext__(self) -> bytes: + if not self._sent: + self._sent = True + return b"data: first\n\n" + await self._gate.wait() # block forever until cancelled + return b"" + + async def aclose(self) -> None: + self.aclosed = True + await asyncio.sleep(0) + self.aclose_completed = True + + +async def test_sse_stream_releases_upstream_and_propagates_when_cancelled_mid_stream() -> None: + chunks = _SlowSseChunks() + stream = parse_async_events(chunks) + seen: list[str] = [] + + async def consume() -> None: + async with stream: + async for event in stream: + seen.append(event.data) + + task = asyncio.ensure_future(consume()) + # Pump the loop until the first event is consumed and the iterator blocks. 
+ for _ in range(10): + await asyncio.sleep(0) + if seen: + break + assert seen == ["first"] + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert chunks.aclosed is True + assert chunks.aclose_completed is True + + +async def test_sse_aclose_is_idempotent() -> None: + chunks = _SlowSseChunks() + stream = parse_async_events(chunks) + await stream.aclose() + await stream.aclose() + assert chunks.aclosed is True + assert chunks.aclose_completed is True + + +async def test_shielded_cleanup_surfaces_failure_without_cancellation() -> None: + async def failing() -> None: + raise RuntimeError("close failed") + + # With no outer cancellation pending, a cleanup failure surfaces unchanged. + with pytest.raises(RuntimeError, match="close failed"): + await _shielded_cleanup(failing()) + + +async def test_shielded_cleanup_cancellation_takes_precedence_over_failure() -> None: + gate = asyncio.Event() + + async def failing_after_gate() -> None: + await gate.wait() + raise RuntimeError("close failed") + + task = asyncio.ensure_future(_shielded_cleanup(failing_after_gate())) + # Let the task park on the shielded wait, then cancel its wait (not the + # shielded cleanup itself). + for _ in range(5): + await asyncio.sleep(0) + task.cancel() + await asyncio.sleep(0) + # Release the cleanup so it now finishes — by raising. A pending + # cancellation must win over that cleanup error. + gate.set() + with pytest.raises(asyncio.CancelledError): + await task diff --git a/packages/dexpace-sdk-core/tests/http/test_loggable_body_fixes.py b/packages/dexpace-sdk-core/tests/http/test_loggable_body_fixes.py new file mode 100644 index 0000000..b6becbe --- /dev/null +++ b/packages/dexpace-sdk-core/tests/http/test_loggable_body_fixes.py @@ -0,0 +1,210 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for the F1/F2/F3 loggable-body fixes. 
+ +F1: a mid-drain error retains the partial bytes, stores the exception, and +re-raises it from ``iter_bytes`` on every call while ``snapshot`` still +yields the partial bytes. +F2: the one-time drain is thread-safe — concurrent first readers consume the +underlying single-use body exactly once. +F3: ``snapshot(max_bytes)`` caps the copy without materialising more than +``max_bytes``. +""" + +from __future__ import annotations + +import threading +from collections.abc import Iterator + +import pytest + +from dexpace.sdk.core.http.common.media_type import MediaType +from dexpace.sdk.core.http.request import LoggableRequestBody, RequestBody +from dexpace.sdk.core.http.response import LoggableResponseBody +from dexpace.sdk.core.http.response.response_body import ResponseBody + + +class _FailingResponseBody(ResponseBody): + """A single-use body that yields some chunks then raises mid-stream.""" + + def __init__(self, chunks: list[bytes], error: BaseException) -> None: + self._chunks = chunks + self._error = error + self._consumed = False + self.closed = False + + def media_type(self) -> MediaType | None: + return None + + def content_length(self) -> int: + return -1 + + def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: + if self._consumed: + raise RuntimeError("ResponseBody has already been consumed") + self._consumed = True + yield from self._chunks + raise self._error + + def close(self) -> None: + self.closed = True + + +class _CountingResponseBody(ResponseBody): + """A single-use body that records how often it is iterated. + + The first chunk is delayed slightly so the drain holds its lock long + enough for every racing reader to reach the ``_drained`` check, reliably + exercising the double-checked-locking path. 
+ """ + + def __init__(self, data: bytes) -> None: + self._data = data + self.iter_calls = 0 + self._consumed = False + + def media_type(self) -> MediaType | None: + return None + + def content_length(self) -> int: + return len(self._data) + + def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: + self.iter_calls += 1 + if self._consumed: + raise RuntimeError("ResponseBody has already been consumed") + self._consumed = True + # Widen the critical section so racing readers pile up on the lock. + threading.Event().wait(0.05) + yield self._data + + def close(self) -> None: + return None + + +class TestResponseErrorPath: + def test_iter_bytes_reraises_stored_error_on_every_call(self) -> None: + boom = ConnectionError("network dropped") + inner = _FailingResponseBody([b"abc", b"def"], boom) + body = LoggableResponseBody(inner) + + with pytest.raises(ConnectionError) as first: + list(body.iter_bytes()) + assert first.value is boom + + # Re-raises on every subsequent call, not just the first. + with pytest.raises(ConnectionError) as second: + list(body.iter_bytes()) + assert second.value is boom + + def test_snapshot_returns_partial_bytes_not_empty(self) -> None: + inner = _FailingResponseBody([b"abc", b"def"], ConnectionError("drop")) + body = LoggableResponseBody(inner) + + # snapshot must surface the partial read for post-mortem logging. 
+ assert body.snapshot() == b"abcdef" + assert body.captured_size == 6 + + def test_snapshot_partial_then_iter_still_raises(self) -> None: + boom = ConnectionError("drop") + inner = _FailingResponseBody([b"xy"], boom) + body = LoggableResponseBody(inner) + + assert body.snapshot() == b"xy" + with pytest.raises(ConnectionError) as exc: + list(body.iter_bytes()) + assert exc.value is boom + + def test_bounded_snapshot_caps_partial_bytes(self) -> None: + inner = _FailingResponseBody([b"abcdef"], ConnectionError("drop")) + body = LoggableResponseBody(inner) + assert body.snapshot(max_bytes=3) == b"abc" + + +class TestResponseThreadSafety: + def test_concurrent_first_read_drains_exactly_once(self) -> None: + threads_count = 8 + start = threading.Barrier(threads_count) + inner = _CountingResponseBody(b"payload") + body = LoggableResponseBody(inner) + + results: list[bytes] = [] + errors: list[BaseException] = [] + lock = threading.Lock() + + def reader() -> None: + # All readers fire iter_bytes at the same instant. 
+ start.wait(timeout=5) + try: + data = b"".join(body.iter_bytes()) + except BaseException as exc: + with lock: + errors.append(exc) + else: + with lock: + results.append(data) + + threads = [threading.Thread(target=reader) for _ in range(threads_count)] + for thread in threads: + thread.start() + for thread in threads: + thread.join(timeout=5) + + assert errors == [] + assert inner.iter_calls == 1 + assert results == [b"payload"] * threads_count + + +class TestResponseBoundedSnapshot: + def test_snapshot_caps_copy(self) -> None: + inner = ResponseBody.from_bytes(b"0123456789") + body = LoggableResponseBody(inner) + assert body.snapshot(max_bytes=4) == b"0123" + + def test_snapshot_max_bytes_larger_than_body(self) -> None: + inner = ResponseBody.from_bytes(b"abc") + body = LoggableResponseBody(inner) + assert body.snapshot(max_bytes=100) == b"abc" + + def test_snapshot_none_returns_full(self) -> None: + inner = ResponseBody.from_bytes(b"abcde") + body = LoggableResponseBody(inner) + assert body.snapshot() == b"abcde" + + def test_snapshot_negative_max_bytes_rejected(self) -> None: + inner = ResponseBody.from_bytes(b"abc") + body = LoggableResponseBody(inner) + with pytest.raises(ValueError, match="non-negative"): + body.snapshot(max_bytes=-1) + + +class TestRequestBoundedSnapshot: + def test_snapshot_caps_copy(self) -> None: + body = LoggableRequestBody(RequestBody.from_bytes(b"0123456789")) + list(body.iter_bytes()) + assert body.snapshot(max_bytes=4) == b"0123" + + def test_snapshot_max_bytes_larger_than_tap(self) -> None: + body = LoggableRequestBody(RequestBody.from_bytes(b"abc")) + list(body.iter_bytes()) + assert body.snapshot(max_bytes=100) == b"abc" + + def test_snapshot_none_returns_full(self) -> None: + body = LoggableRequestBody(RequestBody.from_bytes(b"abcde")) + list(body.iter_bytes()) + assert body.snapshot() == b"abcde" + + def test_snapshot_negative_max_bytes_rejected(self) -> None: + body = LoggableRequestBody(RequestBody.from_bytes(b"abc")) + 
list(body.iter_bytes()) + with pytest.raises(ValueError, match="non-negative"): + body.snapshot(max_bytes=-1) + + def test_snapshot_after_cap_does_not_block_further_writes(self) -> None: + # getbuffer() must be released before more writes; a follow-up + # snapshot proves the temporary view did not pin the BytesIO. + body = LoggableRequestBody(RequestBody.from_bytes(b"0123456789")) + list(body.iter_bytes()) + assert body.snapshot(max_bytes=2) == b"01" + assert body.snapshot() == b"0123456789" diff --git a/packages/dexpace-sdk-core/tests/http/test_media_type_params.py b/packages/dexpace-sdk-core/tests/http/test_media_type_params.py new file mode 100644 index 0000000..054dea5 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/http/test_media_type_params.py @@ -0,0 +1,108 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""F8 parameter-parsing edge cases for ``MediaType``. + +Covers the four correctness-bearing details called out in the platform +analysis: splitting a parameter on the *first* ``=`` only, stripping and +unescaping quoted-strings per RFC 7230 §3.2.6, lower-casing type/subtype and +parameter *keys* while preserving parameter *values*, and degrading an unknown +charset to ``None`` instead of raising. 
+""" + +from __future__ import annotations + +import pytest + +from dexpace.sdk.core.http.common import MediaType + + +class TestSplitOnFirstEquals: + def test_boundary_with_multiple_equals_keeps_remainder(self) -> None: + mt = MediaType.parse("multipart/form-data; boundary=abc=def") + assert dict(mt.parameters)["boundary"] == "abc=def" + + def test_base64_padding_in_value_preserved(self) -> None: + mt = MediaType.parse("application/octet-stream; tag=YWJj==") + assert dict(mt.parameters)["tag"] == "YWJj==" + + def test_quoted_value_containing_equals(self) -> None: + mt = MediaType.parse('multipart/mixed; boundary="a=b=c"') + assert dict(mt.parameters)["boundary"] == "a=b=c" + + +class TestQuotedStringHandling: + def test_strips_surrounding_quotes(self) -> None: + mt = MediaType.parse('text/plain; charset="utf-8"') + assert mt.charset == "utf-8" + + def test_unescapes_escaped_quote(self) -> None: + # \" -> " + mt = MediaType.parse('text/plain; foo="a\\"b"') + assert dict(mt.parameters)["foo"] == 'a"b' + + def test_unescapes_escaped_backslash(self) -> None: + # \\ -> \ + mt = MediaType.parse('text/plain; foo="a\\\\b"') + assert dict(mt.parameters)["foo"] == "a\\b" + + def test_unescapes_mixed_quoted_pairs(self) -> None: + mt = MediaType.parse('text/plain; foo="a\\\\b\\"c"') + assert dict(mt.parameters)["foo"] == 'a\\b"c' + + def test_bare_value_left_unchanged(self) -> None: + mt = MediaType.parse("text/plain; charset=utf-8") + assert mt.charset == "utf-8" + + def test_quoted_value_with_separators_preserved(self) -> None: + # A quoted-string may legitimately contain token separators such as + # spaces and semicolons (the latter only because it is quoted). 
+ mt = MediaType.parse('multipart/form-data; boundary="foo bar"') + assert dict(mt.parameters)["boundary"] == "foo bar" + + +class TestCaseFolding: + def test_type_and_subtype_lowercased(self) -> None: + mt = MediaType.parse("APPLICATION/JSON") + assert mt.full_type == "application/json" + + def test_parameter_key_lowercased(self) -> None: + mt = MediaType.parse("text/plain; CHARSET=utf-8") + assert "charset" in dict(mt.parameters) + + def test_parameter_value_case_preserved(self) -> None: + # Boundaries and base64 tags are case-sensitive — only the key folds. + mt = MediaType.parse("multipart/form-data; Boundary=AbCdEf") + assert dict(mt.parameters)["boundary"] == "AbCdEf" + + def test_charset_value_case_preserved(self) -> None: + mt = MediaType.parse("text/plain; charset=UTF-8") + assert mt.charset == "UTF-8" + + +class TestUnknownCharsetDegradesToNone: + def test_unknown_charset_returns_none(self) -> None: + mt = MediaType.parse("text/plain; charset=not-a-real-encoding") + # The parameter is retained verbatim ... + assert dict(mt.parameters)["charset"] == "not-a-real-encoding" + # ... but the typed accessor degrades to None rather than raising. + assert mt.charset is None + + def test_unknown_charset_does_not_raise_on_parse(self) -> None: + # Parsing must succeed even with a nonsense charset. 
+ mt = MediaType.parse("text/plain; charset=utf-99") + assert mt.full_type == "text/plain" + assert mt.charset is None + + @pytest.mark.parametrize( + "charset", + ["utf-8", "UTF-8", "latin-1", "iso-8859-1", "ascii", "utf-16"], + ids=["utf8", "utf8-upper", "latin1", "iso8859", "ascii", "utf16"], + ) + def test_known_charsets_round_trip(self, charset: str) -> None: + mt = MediaType.parse(f"text/plain; charset={charset}") + assert mt.charset == charset + + def test_absent_charset_returns_none(self) -> None: + mt = MediaType.parse("application/json") + assert mt.charset is None diff --git a/packages/dexpace-sdk-core/tests/instrumentation/test_correlation.py b/packages/dexpace-sdk-core/tests/instrumentation/test_correlation.py new file mode 100644 index 0000000..2691d64 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/instrumentation/test_correlation.py @@ -0,0 +1,130 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. 
+ +"""Tests for context-local trace/span correlation and log stamping.""" + +from __future__ import annotations + +import asyncio +import logging + +import pytest +from _pytest.logging import LogCaptureFixture + +from dexpace.sdk.core.instrumentation import ( + ClientLogger, + CorrelationFilter, + bind_correlation, + get_span_id, + get_trace_id, + set_span_id, + set_trace_id, +) + + +@pytest.fixture(autouse=True) +def _clear_correlation() -> None: + """Each test starts with no bound ids (contextvars default to ``None``).""" + set_trace_id(None) + set_span_id(None) + + +def test_getters_default_to_none() -> None: + assert get_trace_id() is None + assert get_span_id() is None + + +def test_set_and_get_roundtrip() -> None: + set_trace_id("trace-1") + set_span_id("span-1") + assert get_trace_id() == "trace-1" + assert get_span_id() == "span-1" + + +def test_set_returns_token_that_resets() -> None: + set_trace_id("outer") + token = set_trace_id("inner") + assert get_trace_id() == "inner" + token.var.reset(token) + assert get_trace_id() == "outer" + + +def test_bind_correlation_scopes_and_restores() -> None: + set_trace_id("outer-trace") + set_span_id("outer-span") + with bind_correlation(trace_id="inner-trace", span_id="inner-span"): + assert get_trace_id() == "inner-trace" + assert get_span_id() == "inner-span" + assert get_trace_id() == "outer-trace" + assert get_span_id() == "outer-span" + + +def test_bind_correlation_restores_on_exception() -> None: + set_trace_id("outer") + with pytest.raises(RuntimeError), bind_correlation(trace_id="inner"): + assert get_trace_id() == "inner" + raise RuntimeError("boom") + assert get_trace_id() == "outer" + + +def test_logger_stamps_bound_ids_into_message(caplog: LogCaptureFixture) -> None: + caplog.set_level(logging.INFO, logger="dexpace.test.corr.msg") + logger = ClientLogger("dexpace.test.corr.msg") + with bind_correlation(trace_id="t-42", span_id="s-7"): + logger.info("request") + + rendered = caplog.records[-1].getMessage() + 
assert "trace.id=t-42" in rendered + assert "span.id=s-7" in rendered + + +def test_logger_omits_unset_ids_from_message(caplog: LogCaptureFixture) -> None: + caplog.set_level(logging.INFO, logger="dexpace.test.corr.unset") + logger = ClientLogger("dexpace.test.corr.unset") + logger.info("request") + + rendered = caplog.records[-1].getMessage() + assert "trace.id=" not in rendered + assert "span.id=" not in rendered + + +def test_filter_sets_record_attributes(caplog: LogCaptureFixture) -> None: + caplog.set_level(logging.INFO, logger="dexpace.test.corr.attr") + logger = ClientLogger("dexpace.test.corr.attr") + with bind_correlation(trace_id="trace-x", span_id="span-y"): + logger.info("event") + + record = caplog.records[-1] + assert record.trace_id == "trace-x" # type: ignore[attr-defined] + assert record.span_id == "span-y" # type: ignore[attr-defined] + assert getattr(record, "trace.id") == "trace-x" + assert getattr(record, "span.id") == "span-y" + + +def test_filter_sets_none_when_unbound(caplog: LogCaptureFixture) -> None: + caplog.set_level(logging.INFO, logger="dexpace.test.corr.none") + logger = ClientLogger("dexpace.test.corr.none") + logger.info("event") + + record = caplog.records[-1] + assert record.trace_id is None # type: ignore[attr-defined] + assert record.span_id is None # type: ignore[attr-defined] + + +def test_correlation_filter_installed_once() -> None: + name = "dexpace.test.corr.once" + ClientLogger(name) + ClientLogger(name) + installed = [f for f in logging.getLogger(name).filters if isinstance(f, CorrelationFilter)] + assert len(installed) == 1 + + +def test_ids_propagate_across_await() -> None: + async def _run() -> tuple[str | None, str | None]: + with bind_correlation(trace_id="async-trace", span_id="async-span"): + await asyncio.sleep(0) + return get_trace_id(), get_span_id() + + trace_id, span_id = asyncio.run(_run()) + assert trace_id == "async-trace" + assert span_id == "async-span" diff --git 
a/packages/dexpace-sdk-core/tests/instrumentation/test_http_tracer.py b/packages/dexpace-sdk-core/tests/instrumentation/test_http_tracer.py new file mode 100644 index 0000000..7f599b5 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/instrumentation/test_http_tracer.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for the ``HttpTracer`` no-op contract and factory wiring.""" + +from __future__ import annotations + +from dexpace.sdk.core.instrumentation import ( + NOOP_HTTP_TRACER, + NOOP_HTTP_TRACER_FACTORY, + NOOP_INSTRUMENTATION_CONTEXT, + HttpTracer, + HttpTracerFactory, +) + + +def test_noop_tracer_callbacks_are_silent_and_return_none() -> None: + tracer = NOOP_HTTP_TRACER + + # Every callback is declared to return None; calling them must not raise. + tracer.operation_started() + tracer.operation_succeeded() + tracer.operation_failed(RuntimeError("boom")) + tracer.attempt_started(0) + tracer.attempt_failed(RuntimeError("boom"), 0.5) + tracer.attempt_retries_exhausted() + tracer.request_url_resolved("https://example.test/v1") + tracer.request_sent(128) + tracer.response_headers_received(200, {"content-type": "application/json"}) + tracer.response_received(256) + tracer.connection_acquired("example.test", 443) + + +def test_noop_tracer_is_an_http_tracer() -> None: + assert isinstance(NOOP_HTTP_TRACER, HttpTracer) + + +def test_noop_factory_creates_the_shared_noop_tracer() -> None: + created = NOOP_HTTP_TRACER_FACTORY.create() + assert created is NOOP_HTTP_TRACER + + +def test_noop_factory_satisfies_the_factory_protocol() -> None: + assert isinstance(NOOP_HTTP_TRACER_FACTORY, HttpTracerFactory) + + +def test_instrumentation_context_defaults_to_noop_factory() -> None: + assert NOOP_INSTRUMENTATION_CONTEXT.http_tracer_factory is NOOP_HTTP_TRACER_FACTORY + + +def test_subclass_overrides_only_chosen_events() -> None: + events: list[tuple[str, object]] 
= [] + + class _RecordingTracer(HttpTracer): + def attempt_started(self, attempt: int) -> None: + events.append(("attempt_started", attempt)) + + def attempt_failed(self, error: BaseException, next_delay: float) -> None: + events.append(("attempt_failed", next_delay)) + + tracer = _RecordingTracer() + tracer.operation_started() # inherited no-op, must not raise + tracer.attempt_started(2) + tracer.attempt_failed(ValueError("x"), 1.25) + tracer.connection_acquired("host", 80) # inherited no-op + + assert events == [("attempt_started", 2), ("attempt_failed", 1.25)] + + +def test_custom_factory_is_a_factory() -> None: + class _Factory: + def create(self) -> HttpTracer: + return NOOP_HTTP_TRACER + + assert isinstance(_Factory(), HttpTracerFactory) diff --git a/packages/dexpace-sdk-core/tests/pagination/__init__.py b/packages/dexpace-sdk-core/tests/pagination/__init__.py new file mode 100644 index 0000000..a69f5b7 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pagination/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. diff --git a/packages/dexpace-sdk-core/tests/pagination/test_async_paginator.py b/packages/dexpace-sdk-core/tests/pagination/test_async_paginator.py new file mode 100644 index 0000000..53486ca --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pagination/test_async_paginator.py @@ -0,0 +1,135 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. 
+ +"""Tests for ``AsyncPaginator`` driven through a mock async pipeline.""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator + +from dexpace.sdk.core.http.common import Headers, MediaType, Protocol, Url +from dexpace.sdk.core.http.context import DispatchContext +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.response import AsyncResponse, Status +from dexpace.sdk.core.http.response.async_response_body import AsyncResponseBody +from dexpace.sdk.core.pagination import AsyncPaginator, CursorStrategy + + +class _MockAsyncPipeline: + """Async stand-in pipeline mapping a cursor value to a canned page body.""" + + def __init__(self, pages: dict[str | None, dict[str, object]]) -> None: + self._pages = pages + self.calls: list[Request] = [] + self.closed_bodies: list[_TrackingAsyncBody] = [] + + async def run(self, request: Request, _dispatch: DispatchContext) -> AsyncResponse: + self.calls.append(request) + cursor = request.url.query.get("cursor") + payload = self._pages[cursor] + body = _TrackingAsyncBody(json.dumps(payload).encode("utf-8")) + self.closed_bodies.append(body) + return AsyncResponse( + request=request, + protocol=Protocol.HTTP_1_1, + status=Status.OK, + headers=Headers(), + body=body, + ) + + +class _TrackingAsyncBody(AsyncResponseBody): + """In-memory async body that records when it is closed.""" + + def __init__(self, data: bytes) -> None: + self._data = data + self.closed = False + + def media_type(self) -> MediaType | None: + return None + + def content_length(self) -> int: + return len(self._data) + + async def aiter_bytes(self, chunk_size: int = 64 * 1024) -> AsyncIterator[bytes]: + yield self._data + + async def close(self) -> None: + self.closed = True + + +def _first_request() -> Request: + return Request(method=Method.GET, url=Url.parse("https://api.example.com/items")) + + +def _strategy() -> CursorStrategy[int]: + return CursorStrategy( + items_field="data", + 
cursor_response_field="next_cursor", + cursor_param="cursor", + ) + + +def _three_page_pipeline() -> _MockAsyncPipeline: + return _MockAsyncPipeline( + { + None: {"data": [1, 2], "next_cursor": "c1"}, + "c1": {"data": [3, 4], "next_cursor": "c2"}, + "c2": {"data": [5], "next_cursor": None}, + }, + ) + + +async def test_iterates_items_across_all_pages() -> None: + pipeline = _three_page_pipeline() + paginator: AsyncPaginator[int] = AsyncPaginator(pipeline, _strategy(), _first_request()) + items = [item async for item in paginator] + assert items == [1, 2, 3, 4, 5] + + +async def test_drives_pipeline_once_per_page_with_cursor() -> None: + pipeline = _three_page_pipeline() + paginator: AsyncPaginator[int] = AsyncPaginator(pipeline, _strategy(), _first_request()) + _ = [item async for item in paginator] + assert len(pipeline.calls) == 3 + assert pipeline.calls[1].url.query.get("cursor") == "c1" + assert pipeline.calls[2].url.query.get("cursor") == "c2" + + +async def test_by_page_yields_pages() -> None: + pipeline = _three_page_pipeline() + paginator: AsyncPaginator[int] = AsyncPaginator(pipeline, _strategy(), _first_request()) + pages = [list(page.items) async for page in paginator.by_page()] + assert pages == [[1, 2], [3, 4], [5]] + + +async def test_max_pages_bounds_iteration() -> None: + pipeline = _three_page_pipeline() + paginator: AsyncPaginator[int] = AsyncPaginator( + pipeline, + _strategy(), + _first_request(), + max_pages=2, + ) + items = [item async for item in paginator] + assert items == [1, 2, 3, 4] + assert len(pipeline.calls) == 2 + + +async def test_item_iteration_closes_each_page_body() -> None: + pipeline = _three_page_pipeline() + paginator: AsyncPaginator[int] = AsyncPaginator(pipeline, _strategy(), _first_request()) + _ = [item async for item in paginator] + assert all(body.closed for body in pipeline.closed_bodies) + + +async def test_accepts_a_plain_async_send_callable() -> None: + pipeline = _three_page_pipeline() + + async def send(request: 
Request) -> AsyncResponse: + return await pipeline.run(request, DispatchContext.noop()) + + paginator: AsyncPaginator[int] = AsyncPaginator(send, _strategy(), _first_request()) + items = [item async for item in paginator] + assert items == [1, 2, 3, 4, 5] diff --git a/packages/dexpace-sdk-core/tests/pagination/test_link_header.py b/packages/dexpace-sdk-core/tests/pagination/test_link_header.py new file mode 100644 index 0000000..c5ff207 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pagination/test_link_header.py @@ -0,0 +1,83 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for the RFC 5988 ``Link`` header parser.""" + +from __future__ import annotations + +import pytest + +from dexpace.sdk.core.pagination.link_header import find_rel, parse_link_header + + +def test_parse_single_link_value_extracts_target_and_rel() -> None: + header = '<https://api.example.com/items?page=2>; rel="next"' + parsed = parse_link_header(header) + assert parsed == (("https://api.example.com/items?page=2", {"rel": "next"}),) + + +def test_parse_multiple_link_values_in_order() -> None: + header = ( + '<https://api.example.com/items?page=2>; rel="next", ' + '<https://api.example.com/items?page=9>; rel="last"' + ) + parsed = parse_link_header(header) + assert [target for target, _ in parsed] == [ + "https://api.example.com/items?page=2", + "https://api.example.com/items?page=9", + ] + assert parsed[0][1]["rel"] == "next" + assert parsed[1][1]["rel"] == "last" + + +def test_comma_inside_quoted_param_does_not_split_link_values() -> None: + header = '; rel="next"; title="one, two"' + parsed = parse_link_header(header) + assert len(parsed) == 1 + assert parsed[0][1]["title"] == "one, two" + + +def test_param_names_are_case_insensitive() -> None: + header = '<https://api.example.com/a>; REL="next"' + assert find_rel(header, "next") == "https://api.example.com/a" + + +@pytest.mark.parametrize( + ("header", "rel", "expected"), + [ + ('<https://x/1>; rel="next"', "next", "https://x/1"), + ('; rel="next"', "prev", None), + ('<https://x/1>; rel="first next"', 
"next", "https://x/1"), + ('<https://x/1>; rel="NEXT"', "next", "https://x/1"), + ("", "next", None), + (" ", "next", None), + ], + ids=[ + "next-present", + "prev-absent", + "next-in-space-separated-set", + "case-insensitive-rel-value", + "empty-header", + "whitespace-header", + ], +) +def test_find_rel_returns_matching_target(header: str, rel: str, expected: str | None) -> None: + assert find_rel(header, rel) == expected + + +def test_unquoted_param_value_is_accepted() -> None: + header = "<https://x/1>; rel=next" + assert find_rel(header, "next") == "https://x/1" + + +def test_escaped_quote_inside_quoted_value_is_unescaped() -> None: + header = r'; rel="next"; title="a \"quoted\" word"' + parsed = parse_link_header(header) + assert parsed[0][1]["title"] == 'a "quoted" word' + + +def test_malformed_segment_without_brackets_is_ignored() -> None: + header = 'rel="next", <https://x/1>; rel="last"' + parsed = parse_link_header(header) + assert len(parsed) == 1 + assert parsed[0][0] == "https://x/1" diff --git a/packages/dexpace-sdk-core/tests/pagination/test_paginator.py b/packages/dexpace-sdk-core/tests/pagination/test_paginator.py new file mode 100644 index 0000000..ac3738c --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pagination/test_paginator.py @@ -0,0 +1,144 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. 
+ +"""Tests for the synchronous ``Paginator`` driven through a mock pipeline.""" + +from __future__ import annotations + +import json +from collections.abc import Iterator + +from dexpace.sdk.core.http.common import Headers, MediaType, Protocol, Url +from dexpace.sdk.core.http.context import DispatchContext +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.response import Response, Status +from dexpace.sdk.core.http.response.response_body import ResponseBody +from dexpace.sdk.core.pagination import CursorStrategy, Paginator + + +class _MockPipeline: + """Stand-in ``Pipeline``: maps a cursor query value to a canned page body. + + Records every request it is handed so tests can assert the paginator + drove the pipeline (not a bare transport) and built the right next URL. + """ + + def __init__(self, pages: dict[str | None, dict[str, object]]) -> None: + self._pages = pages + self.calls: list[Request] = [] + self.closed_bodies: list[_TrackingBody] = [] + + def run(self, request: Request, _dispatch: DispatchContext) -> Response: + self.calls.append(request) + cursor = request.url.query.get("cursor") + payload = self._pages[cursor] + body = _TrackingBody(json.dumps(payload).encode("utf-8")) + self.closed_bodies.append(body) + return Response( + request=request, + protocol=Protocol.HTTP_1_1, + status=Status.OK, + headers=Headers(), + body=body, + ) + + +class _TrackingBody(ResponseBody): + """In-memory body that records when it is closed (for cleanup assertions).""" + + def __init__(self, data: bytes) -> None: + self._data = data + self.closed = False + + def media_type(self) -> MediaType | None: + return None + + def content_length(self) -> int: + return len(self._data) + + def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: + yield self._data + + def close(self) -> None: + self.closed = True + + +def _first_request() -> Request: + return Request(method=Method.GET, url=Url.parse("https://api.example.com/items")) + + 
+def _strategy() -> CursorStrategy[int]: + return CursorStrategy( + items_field="data", + cursor_response_field="next_cursor", + cursor_param="cursor", + ) + + +def _three_page_pipeline() -> _MockPipeline: + return _MockPipeline( + { + None: {"data": [1, 2], "next_cursor": "c1"}, + "c1": {"data": [3, 4], "next_cursor": "c2"}, + "c2": {"data": [5], "next_cursor": None}, + }, + ) + + +def test_iterates_items_across_all_pages() -> None: + pipeline = _three_page_pipeline() + paginator: Paginator[int] = Paginator(pipeline, _strategy(), _first_request()) + assert list(paginator) == [1, 2, 3, 4, 5] + + +def test_drives_the_pipeline_once_per_page() -> None: + pipeline = _three_page_pipeline() + paginator: Paginator[int] = Paginator(pipeline, _strategy(), _first_request()) + list(paginator) + assert len(pipeline.calls) == 3 + assert pipeline.calls[1].url.query.get("cursor") == "c1" + assert pipeline.calls[2].url.query.get("cursor") == "c2" + + +def test_by_page_yields_pages_with_raw_response() -> None: + pipeline = _three_page_pipeline() + paginator: Paginator[int] = Paginator(pipeline, _strategy(), _first_request()) + pages = list(paginator.by_page()) + assert [list(page.items) for page in pages] == [[1, 2], [3, 4], [5]] + assert all(isinstance(page.raw, Response) for page in pages) + + +def test_item_iteration_closes_each_page_body() -> None: + pipeline = _three_page_pipeline() + paginator: Paginator[int] = Paginator(pipeline, _strategy(), _first_request()) + list(paginator) + assert all(body.closed for body in pipeline.closed_bodies) + + +def test_max_pages_bounds_iteration() -> None: + pipeline = _three_page_pipeline() + paginator: Paginator[int] = Paginator( + pipeline, + _strategy(), + _first_request(), + max_pages=2, + ) + assert list(paginator) == [1, 2, 3, 4] + assert len(pipeline.calls) == 2 + + +def test_accepts_a_plain_send_callable() -> None: + pipeline = _three_page_pipeline() + + def send(request: Request) -> Response: + return pipeline.run(request, 
DispatchContext.noop()) + + paginator: Paginator[int] = Paginator(send, _strategy(), _first_request()) + assert list(paginator) == [1, 2, 3, 4, 5] + + +def test_single_page_sequence_terminates() -> None: + pipeline = _MockPipeline({None: {"data": [7, 8], "next_cursor": None}}) + paginator: Paginator[int] = Paginator(pipeline, _strategy(), _first_request()) + assert list(paginator) == [7, 8] + assert len(pipeline.calls) == 1 diff --git a/packages/dexpace-sdk-core/tests/pagination/test_strategies.py b/packages/dexpace-sdk-core/tests/pagination/test_strategies.py new file mode 100644 index 0000000..faf83aa --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pagination/test_strategies.py @@ -0,0 +1,184 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for the built-in pagination strategies.""" + +from __future__ import annotations + +from dexpace.sdk.core.http.common import Headers, Protocol, Url +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.response import Response, Status +from dexpace.sdk.core.pagination import ( + CursorStrategy, + LinkHeaderStrategy, + PageNumberStrategy, +) + + +def _request(url: str = "https://api.example.com/items") -> Request: + return Request(method=Method.GET, url=Url.parse(url)) + + +def _response(req: Request, *, headers: Headers | None = None) -> Response: + return Response( + request=req, + protocol=Protocol.HTTP_1_1, + status=Status.OK, + headers=headers or Headers(), + ) + + +class TestCursorStrategy: + def test_extracts_items_and_builds_next_with_cursor_param(self) -> None: + strategy: CursorStrategy[int] = CursorStrategy( + items_field="data", + cursor_response_field="next_cursor", + cursor_param="cursor", + ) + req = _request() + payload = {"data": [1, 2, 3], "next_cursor": "abc123"} + page = strategy.parse(_response(req), payload, req) + assert page.items == [1, 2, 3] + assert 
page.next_request is not None + assert page.next_request.url.query.get("cursor") == "abc123" + + def test_token_convention_uses_configured_field_names(self) -> None: + strategy: CursorStrategy[int] = CursorStrategy( + items_field="results", + cursor_response_field="next_page_token", + cursor_param="page_token", + ) + req = _request() + payload = {"results": [9], "next_page_token": "tok"} + page = strategy.parse(_response(req), payload, req) + assert page.next_request is not None + assert page.next_request.url.query.get("page_token") == "tok" + + def test_absent_cursor_ends_sequence(self) -> None: + strategy: CursorStrategy[int] = CursorStrategy(items_field="data") + req = _request() + page = strategy.parse(_response(req), {"data": [1]}, req) + assert page.next_request is None + assert not page.has_next + + def test_empty_cursor_string_ends_sequence(self) -> None: + strategy: CursorStrategy[int] = CursorStrategy(items_field="data") + req = _request() + page = strategy.parse(_response(req), {"data": [], "next_cursor": ""}, req) + assert page.next_request is None + + def test_nested_dotted_item_path(self) -> None: + strategy: CursorStrategy[int] = CursorStrategy(items_field="result.items") + req = _request() + payload = {"result": {"items": [1, 2]}, "next_cursor": "c"} + page = strategy.parse(_response(req), payload, req) + assert page.items == [1, 2] + + +class TestPageNumberStrategy: + def test_increments_page_param_when_full_page(self) -> None: + strategy: PageNumberStrategy[int] = PageNumberStrategy( + items_field="data", + page_size=2, + ) + req = _request("https://api.example.com/items?page=1") + page = strategy.parse(_response(req), {"data": [1, 2]}, req) + assert page.next_request is not None + assert page.next_request.url.query.get("page") == "2" + + def test_short_page_ends_sequence(self) -> None: + strategy: PageNumberStrategy[int] = PageNumberStrategy( + items_field="data", + page_size=2, + ) + req = _request("https://api.example.com/items?page=1") + page 
= strategy.parse(_response(req), {"data": [1]}, req) + assert page.next_request is None + + def test_empty_page_ends_sequence(self) -> None: + strategy: PageNumberStrategy[int] = PageNumberStrategy(items_field="data") + req = _request("https://api.example.com/items?page=4") + page = strategy.parse(_response(req), {"data": []}, req) + assert page.next_request is None + + def test_total_pages_field_bounds_iteration(self) -> None: + strategy: PageNumberStrategy[int] = PageNumberStrategy( + items_field="data", + total_pages_field="total_pages", + ) + req = _request("https://api.example.com/items?page=3") + last = strategy.parse(_response(req), {"data": [1], "total_pages": 3}, req) + assert last.next_request is None + req2 = _request("https://api.example.com/items?page=2") + more = strategy.parse(_response(req2), {"data": [1], "total_pages": 3}, req2) + assert more.next_request is not None + assert more.next_request.url.query.get("page") == "3" + + def test_defaults_to_start_page_when_param_absent(self) -> None: + strategy: PageNumberStrategy[int] = PageNumberStrategy( + items_field="data", + start_page=1, + page_size=2, + ) + req = _request() + page = strategy.parse(_response(req), {"data": [1, 2]}, req) + assert page.next_request is not None + assert page.next_request.url.query.get("page") == "2" + + +class TestLinkHeaderStrategy: + def test_follows_rel_next_target(self) -> None: + strategy: LinkHeaderStrategy[int] = LinkHeaderStrategy(items_field="data") + headers = Headers( + [("Link", '<https://api.example.com/items?page=2>; rel="next"')], + ) + req = _request() + page = strategy.parse(_response(req, headers=headers), {"data": [1]}, req) + assert page.next_request is not None + assert str(page.next_request.url) == "https://api.example.com/items?page=2" + + def test_exposes_prev_request_when_present(self) -> None: + strategy: LinkHeaderStrategy[int] = LinkHeaderStrategy(items_field="data") + headers = Headers( + [ + ( + "Link", + '; rel="next", ' + '<https://api.example.com/items?page=1>; rel="prev"', + ), + ], + ) + req = _request() + page 
= strategy.parse(_response(req, headers=headers), {"data": [1]}, req) + assert page.prev_request is not None + assert str(page.prev_request.url) == "https://api.example.com/items?page=1" + + def test_no_link_header_ends_sequence(self) -> None: + strategy: LinkHeaderStrategy[int] = LinkHeaderStrategy(items_field="data") + req = _request() + page = strategy.parse(_response(req), {"data": [1]}, req) + assert page.next_request is None + assert page.prev_request is None + + def test_next_request_preserves_method_and_headers(self) -> None: + strategy: LinkHeaderStrategy[int] = LinkHeaderStrategy(items_field="data") + headers = Headers([("Link", '; rel="next"')]) + req = Request( + method=Method.GET, + url=Url.parse("https://api.example.com/items"), + headers=Headers([("Authorization", "Bearer t")]), + ) + page = strategy.parse(_response(req, headers=headers), {"data": [1]}, req) + assert page.next_request is not None + assert page.next_request.method is Method.GET + assert page.next_request.headers.get("authorization") == "Bearer t" + + def test_relative_target_is_resolved_against_request_url(self) -> None: + # RFC 5988 permits a relative target; it must be resolved against the + # request URL rather than raising on a missing scheme. + strategy: LinkHeaderStrategy[int] = LinkHeaderStrategy(items_field="data") + headers = Headers([("Link", '; rel="next"')]) + req = _request("https://api.example.com/items?page=1") + page = strategy.parse(_response(req, headers=headers), {"data": [1]}, req) + assert page.next_request is not None + assert str(page.next_request.url) == "https://api.example.com/items?page=2" diff --git a/packages/dexpace-sdk-core/tests/pipeline/test_async_retry_cancel.py b/packages/dexpace-sdk-core/tests/pipeline/test_async_retry_cancel.py new file mode 100644 index 0000000..d4759d7 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pipeline/test_async_retry_cancel.py @@ -0,0 +1,158 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. 
+# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Cancellation discipline for ``AsyncRetryPolicy`` (P9). + +``asyncio.CancelledError`` is a ``BaseException``, not an ``SdkError``, so the +policy's ``except SdkError`` clause already cannot catch it — but the policy +adds an explicit re-raise as a documented, tested invariant. These tests pin +the contract: a cancelled in-flight attempt propagates immediately and is +never retried. +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from dexpace.sdk.core.client.async_http_client import AsyncHttpClient +from dexpace.sdk.core.errors import ServiceRequestError +from dexpace.sdk.core.http.common import Protocol, Url +from dexpace.sdk.core.http.context import DispatchContext +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.response import AsyncResponse, Status +from dexpace.sdk.core.instrumentation import ( + InstrumentationContext, + SpanId, + TraceFlags, + TraceId, + TraceIdType, + TraceState, +) +from dexpace.sdk.core.instrumentation.noop import NOOP_SPAN +from dexpace.sdk.core.pipeline import AsyncPipeline +from dexpace.sdk.core.pipeline.policies.async_retry import AsyncRetryPolicy + + +class _AsyncFakeClock: + """Deterministic ``AsyncClock`` for tests; advances time on sleep.""" + + __slots__ = ("_t",) + + def __init__(self, start: float = 0.0) -> None: + self._t = start + + def now(self) -> float: + return self._t + + def monotonic(self) -> float: + return self._t + + async def sleep(self, duration: float) -> None: + self._t += max(0.0, duration) + + +def _instr(trace: str) -> InstrumentationContext: + return InstrumentationContext( + trace_id_type=TraceIdType.W3C, + trace_id=TraceId(trace), + span_id=SpanId("0" * 16), + span=NOOP_SPAN, + trace_flags=TraceFlags.NOOP, + trace_state=TraceState.NOOP, + ) + + +def _get() -> Request: + return Request(method=Method.GET, url=Url.parse("https://example.com/")) + + 
+class _CancellingClient(AsyncHttpClient): + """Raises ``CancelledError`` on the configured attempt.""" + + def __init__(self, cancel_on: int = 1) -> None: + self._cancel_on = cancel_on + self.attempts = 0 + + async def execute(self, request: Request) -> AsyncResponse: + self.attempts += 1 + if self.attempts >= self._cancel_on: + raise asyncio.CancelledError + return AsyncResponse( + request=request, + protocol=Protocol.HTTP_1_1, + status=Status.SERVICE_UNAVAILABLE, + ) + + +class TestCancelledNotRetried: + async def test_cancelled_error_propagates_immediately(self) -> None: + client = _CancellingClient(cancel_on=1) + retry = AsyncRetryPolicy(clock=_AsyncFakeClock()) + with pytest.raises(asyncio.CancelledError): + async with AsyncPipeline(client, policies=[retry]) as p: + await p.run(_get(), DispatchContext(_instr("0" * 16 + "1"))) + # Exactly one attempt — no retry of a cancellation. + assert client.attempts == 1 + + async def test_cancelled_after_a_retryable_failure_still_not_retried(self) -> None: + # First attempt is a retryable 503, second attempt is cancelled. + client = _CancellingClient(cancel_on=2) + retry = AsyncRetryPolicy(clock=_AsyncFakeClock()) + with pytest.raises(asyncio.CancelledError): + async with AsyncPipeline(client, policies=[retry]) as p: + await p.run(_get(), DispatchContext(_instr("0" * 16 + "2"))) + # Two attempts total: the 503 was retried, the cancellation was not. 
+ assert client.attempts == 2 + + async def test_task_cancellation_propagates_through_policy(self) -> None: + started = asyncio.Event() + + class _HangingClient(AsyncHttpClient): + def __init__(self) -> None: + self.attempts = 0 + + async def execute(self, request: Request) -> AsyncResponse: + self.attempts += 1 + started.set() + await asyncio.sleep(3600) # never completes before cancel + raise AssertionError("unreachable") + + client = _HangingClient() + retry = AsyncRetryPolicy(clock=_AsyncFakeClock()) + + async def _drive() -> AsyncResponse: + async with AsyncPipeline(client, policies=[retry]) as p: + return await p.run(_get(), DispatchContext(_instr("0" * 16 + "3"))) + + task = asyncio.ensure_future(_drive()) + await started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + assert client.attempts == 1 + + async def test_service_error_is_still_retried(self) -> None: + # Guard against over-broad cancellation handling: ordinary SDK errors + # must still flow into the retry path. 
+ class _FlakyClient(AsyncHttpClient): + def __init__(self) -> None: + self.attempts = 0 + + async def execute(self, request: Request) -> AsyncResponse: + self.attempts += 1 + if self.attempts == 1: + raise ServiceRequestError("connection reset") + return AsyncResponse( + request=request, + protocol=Protocol.HTTP_1_1, + status=Status.OK, + ) + + client = _FlakyClient() + retry = AsyncRetryPolicy(clock=_AsyncFakeClock()) + async with AsyncPipeline(client, policies=[retry]) as p: + response = await p.run(_get(), DispatchContext(_instr("0" * 16 + "4"))) + assert response.status is Status.OK + assert client.attempts == 2 diff --git a/packages/dexpace-sdk-core/tests/pipeline/test_client_identity_policy.py b/packages/dexpace-sdk-core/tests/pipeline/test_client_identity_policy.py new file mode 100644 index 0000000..71d7765 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pipeline/test_client_identity_policy.py @@ -0,0 +1,159 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. 
+ +"""Tests for ``ClientIdentityPolicy`` and ``AsyncClientIdentityPolicy``.""" + +from __future__ import annotations + +import re + +import pytest + +from dexpace.sdk.core.client.async_http_client import AsyncHttpClient +from dexpace.sdk.core.client.http_client import HttpClient +from dexpace.sdk.core.http.common import Protocol, Url +from dexpace.sdk.core.http.context import DispatchContext +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.response import AsyncResponse, Response, Status +from dexpace.sdk.core.instrumentation import ( + InstrumentationContext, + SpanId, + TraceFlags, + TraceId, + TraceIdType, + TraceState, +) +from dexpace.sdk.core.instrumentation.noop import NOOP_SPAN +from dexpace.sdk.core.pipeline import AsyncPipeline, Pipeline +from dexpace.sdk.core.pipeline.policies.async_client_identity import AsyncClientIdentityPolicy +from dexpace.sdk.core.pipeline.policies.client_identity import ( + ClientIdentityPolicy, + default_user_agent, +) + +_UA = "User-Agent" +_DEFAULT_UA = re.compile(r"^dexpace-sdk/\S+ python/\d+\.\d+\.\d+$") + + +def _instr(trace: str) -> InstrumentationContext: + return InstrumentationContext( + trace_id_type=TraceIdType.W3C, + trace_id=TraceId(trace), + span_id=SpanId("0" * 16), + span=NOOP_SPAN, + trace_flags=TraceFlags.NOOP, + trace_state=TraceState.NOOP, + ) + + +def _request(*, user_agent: str | None = None) -> Request: + req = Request(method=Method.GET, url=Url.parse("https://api.example.com/v1")) + if user_agent is not None: + req = req.with_header(_UA, user_agent) + return req + + +class _RecordingClient(HttpClient): + def __init__(self) -> None: + self.calls: list[Request] = [] + + def execute(self, request: Request) -> Response: + self.calls.append(request) + return Response(request=request, protocol=Protocol.HTTP_1_1, status=Status.OK) + + +class _RecordingAsyncClient(AsyncHttpClient): + def __init__(self) -> None: + self.calls: list[Request] = [] + + async def execute(self, 
request: Request) -> AsyncResponse: + self.calls.append(request) + return AsyncResponse(request=request, protocol=Protocol.HTTP_1_1, status=Status.OK) + + +def test_default_user_agent_shape() -> None: + assert _DEFAULT_UA.match(default_user_agent()) is not None + + +def test_default_user_agent_never_blank() -> None: + ua = default_user_agent() + assert ua.strip() + assert "dexpace-sdk/" in ua + + +def test_stamps_default_user_agent() -> None: + client = _RecordingClient() + with Pipeline(client, policies=[ClientIdentityPolicy()]) as p: + p.run(_request(), DispatchContext(_instr("0" * 16 + "1"))) + ua = client.calls[0].headers.get(_UA) + assert ua is not None + assert _DEFAULT_UA.match(ua) is not None + + +def test_append_preserves_caller_value() -> None: + client = _RecordingClient() + policy = ClientIdentityPolicy(user_agent="my-token") + with Pipeline(client, policies=[policy]) as p: + p.run(_request(user_agent="caller/1.0"), DispatchContext(_instr("0" * 16 + "2"))) + assert client.calls[0].headers.get(_UA) == "caller/1.0 my-token" + + +def test_replace_overwrites_caller_value() -> None: + client = _RecordingClient() + policy = ClientIdentityPolicy(user_agent="my-token", replace=True) + with Pipeline(client, policies=[policy]) as p: + p.run(_request(user_agent="caller/1.0"), DispatchContext(_instr("0" * 16 + "3"))) + assert client.calls[0].headers.get(_UA) == "my-token" + + +def test_append_with_no_caller_value_uses_token_alone() -> None: + client = _RecordingClient() + policy = ClientIdentityPolicy(user_agent="my-token") + with Pipeline(client, policies=[policy]) as p: + p.run(_request(), DispatchContext(_instr("0" * 16 + "4"))) + assert client.calls[0].headers.get(_UA) == "my-token" + + +def test_blank_caller_value_replaced_not_appended() -> None: + client = _RecordingClient() + policy = ClientIdentityPolicy(user_agent="my-token") + with Pipeline(client, policies=[policy]) as p: + p.run(_request(user_agent=" "), DispatchContext(_instr("0" * 16 + "5"))) + assert 
client.calls[0].headers.get(_UA) == "my-token" + + +@pytest.mark.parametrize("bad", ["", " ", "\t\n"]) +def test_blank_token_rejected(bad: str) -> None: + with pytest.raises(ValueError, match="non-empty"): + ClientIdentityPolicy(user_agent=bad) + + +async def test_async_stamps_default_user_agent() -> None: + client = _RecordingAsyncClient() + async with AsyncPipeline(client, policies=[AsyncClientIdentityPolicy()]) as p: + await p.run(_request(), DispatchContext(_instr("0" * 16 + "6"))) + ua = client.calls[0].headers.get(_UA) + assert ua is not None + assert _DEFAULT_UA.match(ua) is not None + + +async def test_async_append_preserves_caller_value() -> None: + client = _RecordingAsyncClient() + policy = AsyncClientIdentityPolicy(user_agent="my-token") + async with AsyncPipeline(client, policies=[policy]) as p: + await p.run(_request(user_agent="caller/1.0"), DispatchContext(_instr("0" * 16 + "7"))) + assert client.calls[0].headers.get(_UA) == "caller/1.0 my-token" + + +async def test_async_replace_overwrites_caller_value() -> None: + client = _RecordingAsyncClient() + policy = AsyncClientIdentityPolicy(user_agent="my-token", replace=True) + async with AsyncPipeline(client, policies=[policy]) as p: + await p.run(_request(user_agent="caller/1.0"), DispatchContext(_instr("0" * 16 + "8"))) + assert client.calls[0].headers.get(_UA) == "my-token" + + +@pytest.mark.parametrize("bad", ["", " "]) +def test_async_blank_token_rejected(bad: str) -> None: + with pytest.raises(ValueError, match="non-empty"): + AsyncClientIdentityPolicy(user_agent=bad) diff --git a/packages/dexpace-sdk-core/tests/pipeline/test_defaults.py b/packages/dexpace-sdk-core/tests/pipeline/test_defaults.py index f29238d..054e455 100644 --- a/packages/dexpace-sdk-core/tests/pipeline/test_defaults.py +++ b/packages/dexpace-sdk-core/tests/pipeline/test_defaults.py @@ -54,11 +54,14 @@ def test_default_pipeline_returns_builder() -> None: def test_default_pipeline_wires_canonical_stack() -> None: pipeline = 
default_pipeline(_StubTransport()).build() stages = _stages_of(pipeline) - # Canonical order: REDIRECT, RETRY, POST_RETRY (set-date), LOGGING, POST_LOGGING + # Canonical order: REDIRECT, POST_REDIRECT (idempotency), RETRY, + # POST_RETRY (set-date then client-identity), LOGGING, POST_LOGGING (tracing). assert stages == [ Stage.REDIRECT, + Stage.POST_REDIRECT, Stage.RETRY, Stage.POST_RETRY, + Stage.POST_RETRY, Stage.LOGGING, Stage.POST_LOGGING, ] diff --git a/packages/dexpace-sdk-core/tests/pipeline/test_idempotency_policy.py b/packages/dexpace-sdk-core/tests/pipeline/test_idempotency_policy.py new file mode 100644 index 0000000..ceac030 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pipeline/test_idempotency_policy.py @@ -0,0 +1,183 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for ``IdempotencyPolicy`` and ``AsyncIdempotencyPolicy``.""" + +from __future__ import annotations + +import pytest + +from dexpace.sdk.core.client.async_http_client import AsyncHttpClient +from dexpace.sdk.core.client.http_client import HttpClient +from dexpace.sdk.core.http.common import Protocol, Url +from dexpace.sdk.core.http.context import DispatchContext +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.response import AsyncResponse, Response, Status +from dexpace.sdk.core.instrumentation import ( + InstrumentationContext, + SpanId, + TraceFlags, + TraceId, + TraceIdType, + TraceState, +) +from dexpace.sdk.core.instrumentation.noop import NOOP_SPAN +from dexpace.sdk.core.pipeline import AsyncPipeline, Pipeline +from dexpace.sdk.core.pipeline.policies.async_idempotency import AsyncIdempotencyPolicy +from dexpace.sdk.core.pipeline.policies.idempotency import IdempotencyPolicy +from dexpace.sdk.core.pipeline.stage import Stage + +_HEADER = "Idempotency-Key" + + +def _instr(trace: str) -> InstrumentationContext: + return InstrumentationContext( 
+ trace_id_type=TraceIdType.W3C, + trace_id=TraceId(trace), + span_id=SpanId("0" * 16), + span=NOOP_SPAN, + trace_flags=TraceFlags.NOOP, + trace_state=TraceState.NOOP, + ) + + +def _request(method: Method = Method.POST, *, key: str | None = None) -> Request: + req = Request(method=method, url=Url.parse("https://api.example.com/v1")) + if key is not None: + req = req.with_header(_HEADER, key) + return req + + +class _CountingFactory: + """Deterministic key factory yielding ``key-1``, ``key-2``, ... per call.""" + + __slots__ = ("_n",) + + def __init__(self) -> None: + self._n = 0 + + def __call__(self) -> str: + self._n += 1 + return f"key-{self._n}" + + +class _RecordingClient(HttpClient): + """Captures requests handed to the transport for assertion.""" + + def __init__(self) -> None: + self.calls: list[Request] = [] + + def execute(self, request: Request) -> Response: + self.calls.append(request) + return Response(request=request, protocol=Protocol.HTTP_1_1, status=Status.OK) + + +class _RecordingAsyncClient(AsyncHttpClient): + def __init__(self) -> None: + self.calls: list[Request] = [] + + async def execute(self, request: Request) -> AsyncResponse: + self.calls.append(request) + return AsyncResponse(request=request, protocol=Protocol.HTTP_1_1, status=Status.OK) + + +def test_stage_runs_before_retry() -> None: + assert IdempotencyPolicy.STAGE < Stage.RETRY + assert AsyncIdempotencyPolicy.STAGE < Stage.RETRY + + +def test_key_added_to_post() -> None: + client = _RecordingClient() + with Pipeline(client, policies=[IdempotencyPolicy()]) as p: + p.run(_request(Method.POST), DispatchContext(_instr("0" * 16 + "1"))) + assert client.calls[0].headers.get(_HEADER) is not None + + +@pytest.mark.parametrize("method", [Method.POST, Method.PUT, Method.PATCH]) +def test_key_added_to_write_methods(method: Method) -> None: + client = _RecordingClient() + with Pipeline(client, policies=[IdempotencyPolicy()]) as p: + p.run(_request(method), DispatchContext(_instr("0" * 16 + "2"))) 
+ assert client.calls[0].headers.get(_HEADER) is not None + + +@pytest.mark.parametrize("method", [Method.GET, Method.DELETE, Method.HEAD, Method.OPTIONS]) +def test_key_not_added_to_non_write_methods(method: Method) -> None: + client = _RecordingClient() + with Pipeline(client, policies=[IdempotencyPolicy()]) as p: + p.run(_request(method), DispatchContext(_instr("0" * 16 + "3"))) + assert client.calls[0].headers.get(_HEADER) is None + + +def test_caller_set_key_preserved() -> None: + client = _RecordingClient() + with Pipeline(client, policies=[IdempotencyPolicy()]) as p: + p.run(_request(Method.POST, key="caller-key"), DispatchContext(_instr("0" * 16 + "4"))) + assert client.calls[0].headers.get(_HEADER) == "caller-key" + + +def test_default_key_is_uuid4_shaped() -> None: + from uuid import UUID + + client = _RecordingClient() + with Pipeline(client, policies=[IdempotencyPolicy()]) as p: + p.run(_request(Method.POST), DispatchContext(_instr("0" * 16 + "5"))) + key = client.calls[0].headers.get(_HEADER) + assert key is not None + parsed = UUID(key) # raises ValueError if malformed + assert parsed.version == 4 + + +def test_existing_key_not_regenerated() -> None: + """A request that already carries a key is forwarded untouched. + + This is the mechanism by which a key survives retries: the policy sits + outside the retry wrapper, mints the key on the first pass, and on any + re-send sees the header is present and leaves it alone. + """ + factory = _CountingFactory() + client = _RecordingClient() + policy = IdempotencyPolicy(key_factory=factory) + with Pipeline(client, policies=[policy]) as p: + # First send mints key-1. + first = p.run(_request(Method.POST), DispatchContext(_instr("0" * 16 + "6"))) + carried = first.request.headers.get(_HEADER) + assert carried == "key-1" + # Re-send the already-stamped request; the policy must not mint key-2. 
+ resent = _request(Method.POST, key=carried) + p.run(resent, DispatchContext(_instr("0" * 16 + "7"))) + assert client.calls[1].headers.get(_HEADER) == "key-1" + + +def test_custom_methods_and_header() -> None: + client = _RecordingClient() + policy = IdempotencyPolicy(methods=[Method.DELETE], header="X-Idem") + with Pipeline(client, policies=[policy]) as p: + p.run(_request(Method.DELETE), DispatchContext(_instr("0" * 16 + "8"))) + p.run(_request(Method.POST), DispatchContext(_instr("0" * 16 + "9"))) + assert client.calls[0].headers.get("X-Idem") is not None + assert client.calls[1].headers.get("X-Idem") is None + + +async def test_async_key_added_to_post() -> None: + client = _RecordingAsyncClient() + async with AsyncPipeline(client, policies=[AsyncIdempotencyPolicy()]) as p: + await p.run(_request(Method.POST), DispatchContext(_instr("0" * 16 + "a"))) + assert client.calls[0].headers.get(_HEADER) is not None + + +async def test_async_caller_set_key_preserved() -> None: + client = _RecordingAsyncClient() + async with AsyncPipeline(client, policies=[AsyncIdempotencyPolicy()]) as p: + await p.run( + _request(Method.POST, key="caller-key"), + DispatchContext(_instr("0" * 16 + "b")), + ) + assert client.calls[0].headers.get(_HEADER) == "caller-key" + + +async def test_async_get_not_stamped() -> None: + client = _RecordingAsyncClient() + async with AsyncPipeline(client, policies=[AsyncIdempotencyPolicy()]) as p: + await p.run(_request(Method.GET), DispatchContext(_instr("0" * 16 + "c"))) + assert client.calls[0].headers.get(_HEADER) is None diff --git a/packages/dexpace-sdk-core/tests/pipeline/test_log_correlation.py b/packages/dexpace-sdk-core/tests/pipeline/test_log_correlation.py new file mode 100644 index 0000000..f889128 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pipeline/test_log_correlation.py @@ -0,0 +1,226 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. 
+ +"""Tests that ``TracingPolicy`` binds trace/span ids for log correlation (P8). + +The policy sets the active trace and span ids in the ``correlation`` +``contextvars`` for the duration of the downstream send, so any log record +emitted while the request is in flight carries them. The bindings are scoped: +they restore the prior values once the request completes (success or error). +""" + +from __future__ import annotations + +import logging +from typing import Any + +from _pytest.logging import LogCaptureFixture + +from dexpace.sdk.core.client.http_client import HttpClient +from dexpace.sdk.core.errors import ServiceRequestError +from dexpace.sdk.core.http.common import Protocol, Url +from dexpace.sdk.core.http.context import DispatchContext +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.response import Response, Status +from dexpace.sdk.core.instrumentation import ( + ClientLogger, + InstrumentationContext, + Span, + SpanId, + TraceFlags, + TraceId, + TraceIdType, + Tracer, + TraceState, + TracingScope, + get_span_id, + get_trace_id, +) +from dexpace.sdk.core.pipeline import Pipeline +from dexpace.sdk.core.pipeline.context import PipelineContext +from dexpace.sdk.core.pipeline.policies import TracingPolicy +from dexpace.sdk.core.pipeline.policy import Policy +from dexpace.sdk.core.pipeline.stage import Stage + +_TRACE = "1" * 32 +_SPAN = "2" * 16 + + +def _valid_context() -> InstrumentationContext: + """Build a recording context with real (non-sentinel) ids.""" + return InstrumentationContext( + trace_id_type=TraceIdType.W3C, + trace_id=TraceId(_TRACE), + span_id=SpanId(_SPAN), + span=_RecordingSpan.placeholder(), + trace_flags=TraceFlags.NOOP, + trace_state=TraceState.NOOP, + ) + + +class _Scope(TracingScope): + def close(self) -> None: + return None + + def __enter__(self) -> _Scope: + return self + + def __exit__(self, *exc_info: object) -> None: + self.close() + + +class _RecordingSpan(Span): + """Span that reports a 
valid context so the policy binds real ids.""" + + def __init__(self) -> None: + self._ended = False + + @classmethod + def placeholder(cls) -> _RecordingSpan: + return cls() + + @property + def is_recording(self) -> bool: + return True + + @property + def context(self) -> InstrumentationContext: + # A self-referential context would recurse; build a leaf one with the + # same ids. A no-op span field is not needed here: the policy only + # reads ``trace_id`` / ``span_id`` / ``is_valid``. + return InstrumentationContext( + trace_id_type=TraceIdType.W3C, + trace_id=TraceId(_TRACE), + span_id=SpanId(_SPAN), + span=self, + trace_flags=TraceFlags.NOOP, + trace_state=TraceState.NOOP, + ) + + def set_attribute(self, key: str, value: Any) -> _RecordingSpan: + return self + + def set_error(self, error_type: str) -> _RecordingSpan: + return self + + def make_current(self) -> TracingScope: + return _Scope() + + def end(self, error: BaseException | None = None) -> None: + self._ended = True + + +class _RecordingTracer(Tracer): + def start_span( + self, + name: str, + parent: InstrumentationContext | None = None, + ) -> Span: + del name, parent + return _RecordingSpan() + + +class _CaptureCorrelationPolicy(Policy): + """Innermost policy that snapshots the bound ids when the request runs.""" + + STAGE = Stage.PRE_SEND + __slots__ = ("seen_span", "seen_trace") + + def __init__(self) -> None: + self.seen_trace: str | None = None + self.seen_span: str | None = None + + def send(self, request: Request, ctx: PipelineContext) -> Response: + self.seen_trace = get_trace_id() + self.seen_span = get_span_id() + return self.next.send(request, ctx) + + +class _OkClient(HttpClient): + def __init__(self, *, raise_exc: BaseException | None = None) -> None: + self.raise_exc = raise_exc + + def execute(self, request: Request) -> Response: + if self.raise_exc is not None: + raise self.raise_exc + return Response(request=request, protocol=Protocol.HTTP_1_1, status=Status.OK) + + +def _request()
-> Request: + return Request(method=Method.GET, url=Url.parse("https://api.example.com/v1")) + + +class TestCorrelationBinding: + def test_ids_bound_during_request(self) -> None: + capture = _CaptureCorrelationPolicy() + with Pipeline( + _OkClient(), + policies=[TracingPolicy(tracer=_RecordingTracer()), capture], + ) as p: + p.run(_request(), DispatchContext(_valid_context())) + assert capture.seen_trace == _TRACE + assert capture.seen_span == _SPAN + + def test_ids_reset_after_request(self) -> None: + assert get_trace_id() is None + assert get_span_id() is None + with Pipeline( + _OkClient(), + policies=[TracingPolicy(tracer=_RecordingTracer())], + ) as p: + p.run(_request(), DispatchContext(_valid_context())) + # The scoped binding restored the (unset) prior values. + assert get_trace_id() is None + assert get_span_id() is None + + def test_ids_reset_after_exception(self) -> None: + boom = ServiceRequestError("connect failed") + raised: BaseException | None = None + with Pipeline( + _OkClient(raise_exc=boom), + policies=[TracingPolicy(tracer=_RecordingTracer())], + ) as p: + try: + p.run(_request(), DispatchContext(_valid_context())) + except ServiceRequestError as err: + raised = err + assert raised is boom + assert get_trace_id() is None + assert get_span_id() is None + + def test_noop_span_binds_no_ids(self) -> None: + # A context whose span carries sentinel ids must not bind a fake trace. + capture = _CaptureCorrelationPolicy() + with Pipeline( + _OkClient(), + # No tracer -> NOOP_TRACER -> NOOP_SPAN (is_valid is False). 
+ policies=[TracingPolicy(), capture], + ) as p: + p.run(_request(), DispatchContext(_valid_context())) + assert capture.seen_trace is None + assert capture.seen_span is None + + +class TestLogRecordCorrelation: + def test_log_record_carries_bound_ids(self, caplog: LogCaptureFixture) -> None: + logger = ClientLogger("dexpace.sdk.core.test.correlation") + + class _LoggingPolicy(Policy): + STAGE = Stage.PRE_SEND + __slots__ = () + + def send(self, request: Request, ctx: PipelineContext) -> Response: + logger.info("in.flight") + return self.next.send(request, ctx) + + caplog.set_level(logging.INFO, logger="dexpace.sdk.core.test.correlation") + with Pipeline( + _OkClient(), + policies=[TracingPolicy(tracer=_RecordingTracer()), _LoggingPolicy()], + ) as p: + p.run(_request(), DispatchContext(_valid_context())) + records = [r for r in caplog.records if r.getMessage().startswith("in.flight")] + assert records, "expected an in-flight log record" + record = records[0] + assert getattr(record, "trace.id") == _TRACE + assert getattr(record, "span.id") == _SPAN diff --git a/packages/dexpace-sdk-core/tests/pipeline/test_retry.py b/packages/dexpace-sdk-core/tests/pipeline/test_retry.py index 8811172..a2e27bc 100644 --- a/packages/dexpace-sdk-core/tests/pipeline/test_retry.py +++ b/packages/dexpace-sdk-core/tests/pipeline/test_retry.py @@ -319,6 +319,7 @@ def test_jitter_varies_backoff(self) -> None: retry = RetryPolicy( backoff_factor=1.0, backoff_max=1000.0, + full_jitter=False, jitter=0.25, rand=random.Random(42), clock=FakeClock(), @@ -342,6 +343,7 @@ def test_no_jitter_when_zero(self) -> None: retry = RetryPolicy( backoff_factor=1.0, backoff_max=1000.0, + full_jitter=False, jitter=0.0, rand=random.Random(42), clock=FakeClock(), @@ -455,6 +457,7 @@ def test_retry_advances_clock_by_backoff_seconds() -> None: retry = RetryPolicy( backoff_factor=1.0, backoff_max=1000.0, + full_jitter=False, jitter=0.0, clock=clock, ) diff --git 
a/packages/dexpace-sdk-core/tests/pipeline/test_retry_tuning.py b/packages/dexpace-sdk-core/tests/pipeline/test_retry_tuning.py new file mode 100644 index 0000000..26d23b6 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pipeline/test_retry_tuning.py @@ -0,0 +1,371 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for the retry-tuning refinements (P6) on ``RetryPolicy``. + +Covers the ``X-RateLimit-Reset`` epoch header, full-jitter exponential +backoff, the server ``Retry-After`` ceiling, and the ``HttpTracer`` attempt +events emitted from the retry loop. +""" + +from __future__ import annotations + +import random +from collections.abc import Mapping, Sequence +from dataclasses import replace + +import pytest + +from dexpace.sdk.core.client.http_client import HttpClient +from dexpace.sdk.core.http.common import Protocol, Url +from dexpace.sdk.core.http.context import DispatchContext +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.response import Response, Status +from dexpace.sdk.core.instrumentation import ( + InstrumentationContext, + SpanId, + TraceFlags, + TraceId, + TraceIdType, + TraceState, +) +from dexpace.sdk.core.instrumentation.http_tracer import HttpTracer, HttpTracerFactory +from dexpace.sdk.core.instrumentation.noop import NOOP_SPAN +from dexpace.sdk.core.pipeline import Pipeline +from dexpace.sdk.core.pipeline.policies import RetryPolicy, TracingPolicy +from dexpace.sdk.core.pipeline.policies.retry import ( + _parse_rate_limit_reset, + _StatusRetryError, +) + +from ..conftest import FakeClock + + +def _instr( + trace: str, + tracer_factory: HttpTracerFactory | None = None, +) -> InstrumentationContext: + base = InstrumentationContext( + trace_id_type=TraceIdType.W3C, + trace_id=TraceId(trace), + span_id=SpanId("0" * 16), + span=NOOP_SPAN, + trace_flags=TraceFlags.NOOP, + trace_state=TraceState.NOOP, + ) + if 
tracer_factory is None: + return base + return replace(base, http_tracer_factory=tracer_factory) + + +def _get() -> Request: + return Request(method=Method.GET, url=Url.parse("https://example.com/")) + + +class _ScriptedClient(HttpClient): + """Returns one response or raises one error per call, with headers.""" + + def __init__( + self, + outcomes: Sequence[Status | BaseException], + headers: Mapping[str, str] | None = None, + ) -> None: + self._outcomes = list(outcomes) + self._headers = dict(headers or {}) + self.attempts = 0 + + def execute(self, request: Request) -> Response: + outcome = self._outcomes[self.attempts] + self.attempts += 1 + if isinstance(outcome, BaseException): + raise outcome + response = Response(request=request, protocol=Protocol.HTTP_1_1, status=outcome) + if not outcome.is_success: + for name, value in self._headers.items(): + response = response.with_header(name, value) + return response + + +class _RecordingTracer(HttpTracer): + """Captures every attempt event the retry loop emits.""" + + def __init__(self) -> None: + self.started: list[int] = [] + self.failed: list[tuple[BaseException, float]] = [] + self.exhausted = 0 + + def attempt_started(self, attempt: int) -> None: + self.started.append(attempt) + + def attempt_failed(self, error: BaseException, next_delay: float) -> None: + self.failed.append((error, next_delay)) + + def attempt_retries_exhausted(self) -> None: + self.exhausted += 1 + + +class _RecordingTracerFactory: + def __init__(self, tracer: HttpTracer) -> None: + self._tracer = tracer + + def create(self) -> HttpTracer: + return self._tracer + + +# ----- X-RateLimit-Reset -------------------------------------------------- + + +class TestRateLimitResetParsing: + def test_epoch_in_future_yields_positive_delay(self) -> None: + assert _parse_rate_limit_reset("150", now=100.0) == 50.0 + + def test_epoch_in_past_floors_at_zero(self) -> None: + assert _parse_rate_limit_reset("90", now=100.0) == 0.0 + + def 
test_missing_or_blank_returns_none(self) -> None: + assert _parse_rate_limit_reset(None, now=0.0) is None + assert _parse_rate_limit_reset(" ", now=0.0) is None + + def test_non_numeric_returns_none(self) -> None: + assert _parse_rate_limit_reset("soon", now=0.0) is None + + +class TestRateLimitResetHonored: + def test_never_wakes_before_reset(self) -> None: + clock = FakeClock(start=1_000.0) + client = _ScriptedClient( + [Status.TOO_MANY_REQUESTS, Status.OK], + headers={"X-RateLimit-Reset": "1040"}, + ) + # Bottom of the upward jitter band [1.0, 1.1] -> exactly the reset wait. + retry = RetryPolicy(clock=clock, rand=_FixedRandom(0.0)) + with Pipeline(client, policies=[retry]) as p: + response = p.run(_get(), DispatchContext(_instr("0" * 16 + "1"))) + assert response.status is Status.OK + # 40s until reset, jitter * 1.0 -> waits to the reset instant, never before. + assert clock.monotonic() == pytest.approx(1_040.0) + + def test_reset_jitter_only_lengthens_the_wait(self) -> None: + clock = FakeClock(start=1_000.0) + client = _ScriptedClient( + [Status.TOO_MANY_REQUESTS, Status.OK], + headers={"X-RateLimit-Reset": "1040"}, + ) + # Top of the band [1.0, 1.1] -> 40s * 1.1 = 44s, i.e. slightly past reset. 
+ retry = RetryPolicy(clock=clock, rand=_FixedRandom(1.0)) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("0" * 16 + "1"))) + assert clock.monotonic() == pytest.approx(1_044.0) + + def test_retry_after_takes_precedence_over_reset(self) -> None: + clock = FakeClock(start=1_000.0) + client = _ScriptedClient( + [Status.TOO_MANY_REQUESTS, Status.OK], + headers={"Retry-After": "5", "X-RateLimit-Reset": "9999999"}, + ) + retry = RetryPolicy(clock=clock, rand=_FixedRandom(1.0)) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("0" * 16 + "2"))) + assert clock.monotonic() == pytest.approx(1_005.0) + + +# ----- full jitter -------------------------------------------------------- + + +class TestFullJitter: + def test_seeded_full_jitter_is_reproducible(self) -> None: + # Same seed must produce the same slept duration, and it must land in + # the full-jitter band [base*0.5, base*1.0]. + def run_once() -> float: + clock = FakeClock() + client = _ScriptedClient( + [Status.SERVICE_UNAVAILABLE, Status.SERVICE_UNAVAILABLE, Status.OK] + ) + retry = RetryPolicy( + backoff_factor=3.0, + backoff_max=1_000.0, + clock=clock, + rand=random.Random(99), + ) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("0" * 16 + "3"))) + return clock.monotonic() + + first = run_once() + second = run_once() + assert first == pytest.approx(second) + # attempts==2 -> base = 3.0 * 2**1 = 6.0; full jitter -> [3.0, 6.0]. 
+ assert 3.0 <= first <= 6.0 + + def test_full_jitter_stays_within_band(self) -> None: + clock = FakeClock() + rng = random.Random(7) + client = _ScriptedClient( + [Status.SERVICE_UNAVAILABLE, Status.SERVICE_UNAVAILABLE, Status.OK] + ) + retry = RetryPolicy(backoff_factor=4.0, backoff_max=1_000.0, clock=clock, rand=rng) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("0" * 16 + "4"))) + # Single non-zero sleep: attempts==2 -> base = 4.0 * 2 = 8.0. + # Full jitter keeps it in [4.0, 8.0]. + slept = clock.monotonic() + assert 4.0 <= slept <= 8.0 + + def test_full_jitter_disabled_uses_symmetric_band(self) -> None: + clock = FakeClock() + client = _ScriptedClient( + [Status.SERVICE_UNAVAILABLE, Status.SERVICE_UNAVAILABLE, Status.OK] + ) + retry = RetryPolicy( + backoff_factor=2.0, + backoff_max=1_000.0, + full_jitter=False, + jitter=0.0, + clock=clock, + ) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("0" * 16 + "5"))) + # Deterministic: attempts==2 -> 2.0 * 2**1 = 4.0, no jitter. 
+ assert clock.monotonic() == pytest.approx(4.0) + + +# ----- Retry-After ceiling ------------------------------------------------ + + +class TestRetryAfterCeiling: + def test_caps_outrageous_retry_after(self) -> None: + clock = FakeClock() + client = _ScriptedClient( + [Status.SERVICE_UNAVAILABLE, Status.OK], + headers={"Retry-After": "999999"}, + ) + retry = RetryPolicy(retry_after_max=30.0, clock=clock) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("0" * 16 + "6"))) + assert clock.monotonic() == pytest.approx(30.0) + + def test_caps_outrageous_rate_limit_reset(self) -> None: + clock = FakeClock(start=0.0) + client = _ScriptedClient( + [Status.TOO_MANY_REQUESTS, Status.OK], + headers={"X-RateLimit-Reset": "999999"}, + ) + retry = RetryPolicy(retry_after_max=45.0, clock=clock, rand=_FixedRandom(1.0)) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("0" * 16 + "7"))) + assert clock.monotonic() == pytest.approx(45.0) + + def test_small_retry_after_not_capped(self) -> None: + clock = FakeClock() + client = _ScriptedClient( + [Status.SERVICE_UNAVAILABLE, Status.OK], + headers={"Retry-After": "3"}, + ) + retry = RetryPolicy(retry_after_max=3600.0, clock=clock) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("0" * 16 + "8"))) + assert clock.monotonic() == pytest.approx(3.0) + + +# ----- tracer attempt events ---------------------------------------------- + + +class TestTracerAttemptEvents: + def test_emits_started_and_failed_for_each_retry(self) -> None: + tracer = _RecordingTracer() + factory = _RecordingTracerFactory(tracer) + clock = FakeClock() + client = _ScriptedClient( + [Status.SERVICE_UNAVAILABLE, Status.SERVICE_UNAVAILABLE, Status.OK] + ) + retry = RetryPolicy(clock=clock, rand=_FixedRandom(0.5)) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("0" * 16 + "9", factory))) + # Three attempts started (0, 1, 
2); two failed before the success. + assert tracer.started == [0, 1, 2] + assert len(tracer.failed) == 2 + assert all(isinstance(err, _StatusRetryError) for err, _ in tracer.failed) + assert tracer.exhausted == 0 + + def test_emits_retries_exhausted_on_budget_exhaustion(self) -> None: + tracer = _RecordingTracer() + factory = _RecordingTracerFactory(tracer) + clock = FakeClock() + client = _ScriptedClient([Status.SERVICE_UNAVAILABLE] * 5) + retry = RetryPolicy(status_retries=1, clock=clock, rand=_FixedRandom(0.5)) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("a" * 16, factory))) + assert tracer.exhausted == 1 + + def test_status_retry_marker_carries_status_code(self) -> None: + tracer = _RecordingTracer() + factory = _RecordingTracerFactory(tracer) + clock = FakeClock() + client = _ScriptedClient([Status.SERVICE_UNAVAILABLE, Status.OK]) + retry = RetryPolicy(clock=clock, rand=_FixedRandom(0.5)) + with Pipeline(client, policies=[retry]) as p: + p.run(_get(), DispatchContext(_instr("b" * 16, factory))) + err, _ = tracer.failed[0] + assert isinstance(err, _StatusRetryError) + assert err.status == int(Status.SERVICE_UNAVAILABLE) + + +class _LifecycleTracer(HttpTracer): + """Records the operation- and attempt-level events on one instance.""" + + def __init__(self) -> None: + self.events: list[str] = [] + + def operation_started(self) -> None: + self.events.append("operation_started") + + def operation_succeeded(self) -> None: + self.events.append("operation_succeeded") + + def attempt_started(self, attempt: int) -> None: + self.events.append(f"attempt_started:{attempt}") + + +class _CountingFactory: + """Mints a *fresh* tracer per ``create`` — the spec-conformant contract.""" + + def __init__(self) -> None: + self.created: list[_LifecycleTracer] = [] + + def create(self) -> HttpTracer: + tracer = _LifecycleTracer() + self.created.append(tracer) + return tracer + + +class TestSharedTracerAcrossPolicies: + def 
test_retry_and_tracing_share_one_per_operation_tracer(self) -> None: + # With a factory that mints a fresh tracer per ``create`` (the + # documented contract), the retry and tracing policies must still land + # on a single per-operation instance via the ``ctx.data`` cache — + # otherwise attempt events and lifecycle events split across objects. + factory = _CountingFactory() + clock = FakeClock() + client = _ScriptedClient([Status.SERVICE_UNAVAILABLE, Status.OK]) + with Pipeline( + client, + policies=[TracingPolicy(), RetryPolicy(clock=clock, rand=_FixedRandom(0.5))], + ) as p: + p.run(_get(), DispatchContext(_instr("c" * 16, factory))) + assert len(factory.created) == 1 + tracer = factory.created[0] + assert "operation_started" in tracer.events + assert "attempt_started:0" in tracer.events + assert "attempt_started:1" in tracer.events + assert "operation_succeeded" in tracer.events + + +class _FixedRandom(random.Random): + """``random.Random`` whose ``uniform`` always returns a fixed factor.""" + + def __init__(self, factor: float) -> None: + super().__init__() + self._factor = factor + + def uniform(self, a: float, b: float) -> float: + return b * self._factor if self._factor == 1.0 else a + (b - a) * self._factor diff --git a/packages/dexpace-sdk-core/tests/pipeline/test_tracer_emission.py b/packages/dexpace-sdk-core/tests/pipeline/test_tracer_emission.py new file mode 100644 index 0000000..5a77016 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/pipeline/test_tracer_emission.py @@ -0,0 +1,314 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests that pipeline policies emit the expected ``HttpTracer`` events. + +Covers the P7 emission seam: ``TracingPolicy`` drives the operation/request/ +response lifecycle callbacks, and ``RedirectPolicy`` / +``AsyncRedirectPolicy`` emit ``request_url_resolved`` per hop. 
The custom +tracer below records every callback so each test can assert the exact +sequence the policies produced. +""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence + +from dexpace.sdk.core.client.async_http_client import AsyncHttpClient +from dexpace.sdk.core.client.http_client import HttpClient +from dexpace.sdk.core.errors import ServiceRequestError +from dexpace.sdk.core.http.common import Protocol, Url +from dexpace.sdk.core.http.context import DispatchContext +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.request.request_body import RequestBody +from dexpace.sdk.core.http.response import AsyncResponse, Response, Status +from dexpace.sdk.core.http.response.response_body import ResponseBody +from dexpace.sdk.core.instrumentation import ( + HttpTracer, + InstrumentationContext, + SpanId, + TraceFlags, + TraceId, + TraceIdType, + TraceState, +) +from dexpace.sdk.core.instrumentation.noop import NOOP_SPAN +from dexpace.sdk.core.pipeline import AsyncPipeline, Pipeline +from dexpace.sdk.core.pipeline.policies import TracingPolicy +from dexpace.sdk.core.pipeline.policies.async_redirect import AsyncRedirectPolicy +from dexpace.sdk.core.pipeline.policies.redirect import RedirectPolicy + + +class _RecordingHttpTracer(HttpTracer): + """Captures every callback as a ``(name, payload)`` event tuple.""" + + def __init__(self) -> None: + self.events: list[tuple[str, object]] = [] + + def operation_started(self) -> None: + self.events.append(("operation_started", None)) + + def operation_succeeded(self) -> None: + self.events.append(("operation_succeeded", None)) + + def operation_failed(self, error: BaseException) -> None: + self.events.append(("operation_failed", error)) + + def request_url_resolved(self, url: str) -> None: + self.events.append(("request_url_resolved", url)) + + def request_sent(self, byte_count: int) -> None: + self.events.append(("request_sent", byte_count)) + + def 
response_headers_received(self, status: int, headers: Mapping[str, str]) -> None: + self.events.append(("response_headers_received", (status, dict(headers)))) + + def response_received(self, byte_count: int) -> None: + self.events.append(("response_received", byte_count)) + + def names(self) -> list[str]: + return [name for name, _ in self.events] + + +class _Factory: + """``HttpTracerFactory`` returning a single shared recording tracer.""" + + def __init__(self, tracer: HttpTracer) -> None: + self._tracer = tracer + + def create(self) -> HttpTracer: + return self._tracer + + +def _instr(tracer: HttpTracer, trace: str = "0" * 31 + "1") -> InstrumentationContext: + return InstrumentationContext( + trace_id_type=TraceIdType.W3C, + trace_id=TraceId(trace), + span_id=SpanId("0" * 15 + "1"), + span=NOOP_SPAN, + trace_flags=TraceFlags.NOOP, + trace_state=TraceState.NOOP, + http_tracer_factory=_Factory(tracer), + ) + + +def _request(url: str = "https://api.example.com/v1") -> Request: + return Request(method=Method.GET, url=Url.parse(url)) + + +class _OkClient(HttpClient): + def __init__( + self, + *, + status: Status = Status.OK, + body: ResponseBody | None = None, + response_headers: tuple[tuple[str, str], ...] 
= (), + raise_exc: BaseException | None = None, + ) -> None: + self.status = status + self.body = body + self.response_headers = response_headers + self.raise_exc = raise_exc + + def execute(self, request: Request) -> Response: + if self.raise_exc is not None: + raise self.raise_exc + response = Response( + request=request, + protocol=Protocol.HTTP_1_1, + status=self.status, + body=self.body, + ) + for name, value in self.response_headers: + response = response.with_header(name, value) + return response + + +class TestTracingPolicyEmission: + def test_emits_lifecycle_events_in_order_on_success(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + body = ResponseBody.from_bytes(b"hello world") + with Pipeline(_OkClient(body=body), policies=[TracingPolicy()]) as p: + p.run(_request(), DispatchContext(instr)) + assert tracer.names() == [ + "operation_started", + "request_sent", + "response_headers_received", + "response_received", + "operation_succeeded", + ] + + def test_request_sent_reports_body_byte_count(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + req = Request( + method=Method.POST, + url=Url.parse("https://api.example.com/v1"), + body=RequestBody.from_bytes(b"payload-12"), + ) + with Pipeline(_OkClient(), policies=[TracingPolicy()]) as p: + p.run(req, DispatchContext(instr)) + sent = [payload for name, payload in tracer.events if name == "request_sent"] + assert sent == [len(b"payload-12")] + + def test_request_sent_zero_for_bodyless_request(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + with Pipeline(_OkClient(), policies=[TracingPolicy()]) as p: + p.run(_request(), DispatchContext(instr)) + sent = [payload for name, payload in tracer.events if name == "request_sent"] + assert sent == [0] + + def test_response_headers_event_carries_status_and_headers(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + client = _OkClient(status=Status.OK, 
response_headers=(("X-Trace", "abc"),)) + with Pipeline(client, policies=[TracingPolicy()]) as p: + p.run(_request(), DispatchContext(instr)) + headers_events = [ + payload for name, payload in tracer.events if name == "response_headers_received" + ] + assert len(headers_events) == 1 + payload = headers_events[0] + assert isinstance(payload, tuple) + status, headers = payload + assert status == 200 + assert isinstance(headers, Mapping) + assert headers["x-trace"] == "abc" + + def test_emits_operation_failed_on_exception(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + boom = ServiceRequestError("connect failed") + raised: BaseException | None = None + with Pipeline(_OkClient(raise_exc=boom), policies=[TracingPolicy()]) as p: + try: + p.run(_request(), DispatchContext(instr)) + except ServiceRequestError as err: + raised = err + assert raised is boom + assert tracer.names() == ["operation_started", "request_sent", "operation_failed"] + failed = [payload for name, payload in tracer.events if name == "operation_failed"] + assert failed == [boom] + + def test_no_events_when_tracing_disabled(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + with Pipeline(_OkClient(), policies=[TracingPolicy()]) as p: + p.run(_request(), DispatchContext(instr), tracing_enabled=False) + assert tracer.events == [] + + +# ----- redirect hop emission (sync + async) ------------------------------ + + +class _Hop: + __slots__ = ("location", "status") + + def __init__(self, status: Status, location: str | None = None) -> None: + self.status = status + self.location = location + + +class _ScriptedClient(HttpClient): + def __init__(self, hops: Sequence[_Hop]) -> None: + self._hops = list(hops) + self.requests: list[Request] = [] + + def execute(self, request: Request) -> Response: + idx = len(self.requests) + self.requests.append(request) + hop = self._hops[idx] + response = Response(request=request, protocol=Protocol.HTTP_1_1, 
status=hop.status) + if hop.location is not None: + response = response.with_header("Location", hop.location) + return response + + +class _ScriptedAsyncClient(AsyncHttpClient): + def __init__(self, hops: Sequence[_Hop]) -> None: + self._hops = list(hops) + self.requests: list[Request] = [] + + async def execute(self, request: Request) -> AsyncResponse: + idx = len(self.requests) + self.requests.append(request) + hop = self._hops[idx] + response = AsyncResponse(request=request, protocol=Protocol.HTTP_1_1, status=hop.status) + if hop.location is not None: + response = response.with_header("Location", hop.location) + return response + + +class TestRedirectHopEmission: + def test_emits_request_url_resolved_per_hop(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + client = _ScriptedClient( + [ + _Hop(Status.MOVED_PERMANENTLY, "https://api.example.com/new"), + _Hop(Status.OK), + ], + ) + with Pipeline(client, policies=[RedirectPolicy()]) as p: + p.run(_request("https://api.example.com/start"), DispatchContext(instr)) + resolved = [payload for name, payload in tracer.events if name == "request_url_resolved"] + assert resolved == [ + "https://api.example.com/start", + "https://api.example.com/new", + ] + + def test_single_hop_emits_only_initial_url(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + client = _ScriptedClient([_Hop(Status.OK)]) + with Pipeline(client, policies=[RedirectPolicy()]) as p: + p.run(_request("https://api.example.com/start"), DispatchContext(instr)) + resolved = [payload for name, payload in tracer.events if name == "request_url_resolved"] + assert resolved == ["https://api.example.com/start"] + + async def test_async_emits_request_url_resolved_per_hop(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + client = _ScriptedAsyncClient( + [ + _Hop(Status.FOUND, "https://api.example.com/next"), + _Hop(Status.OK), + ], + ) + async with AsyncPipeline(client, 
policies=[AsyncRedirectPolicy()]) as p: + await p.run( + _request("https://api.example.com/start"), + DispatchContext(instr), + ) + resolved = [payload for name, payload in tracer.events if name == "request_url_resolved"] + assert resolved == [ + "https://api.example.com/start", + "https://api.example.com/next", + ] + + +class TestSharedTracerAcrossPolicies: + def test_redirect_and_tracing_share_one_tracer(self) -> None: + tracer = _RecordingHttpTracer() + instr = _instr(tracer) + client = _ScriptedClient( + [ + _Hop(Status.MOVED_PERMANENTLY, "https://api.example.com/new"), + _Hop(Status.OK), + ], + ) + with Pipeline( + client, + policies=[RedirectPolicy(), TracingPolicy()], + ) as p: + p.run(_request("https://api.example.com/start"), DispatchContext(instr)) + names = tracer.names() + # Both the redirect hop events and the operation lifecycle events land + # on the same recording tracer instance. + assert "request_url_resolved" in names + assert "operation_started" in names + assert "operation_succeeded" in names + # Two hops -> two request_sent events from the inner TracingPolicy. + assert names.count("request_sent") == 2 diff --git a/packages/dexpace-sdk-core/tests/serde/test_codec.py b/packages/dexpace-sdk-core/tests/serde/test_codec.py new file mode 100644 index 0000000..503a487 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/serde/test_codec.py @@ -0,0 +1,591 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for the typed-model codec. + +Imports come straight from ``dexpace.sdk.core.serde.codec`` rather than the +package ``__init__`` so the suite does not depend on the re-export landing. 
+""" + +from __future__ import annotations + +import dataclasses +import enum +from dataclasses import dataclass, field +from datetime import UTC, date, datetime, time + +import pytest + +from dexpace.sdk.core.errors import DeserializationError, SerializationError +from dexpace.sdk.core.serde.codec import ( + ALIAS_KEY, + Codec, + CodecError, + discriminated, + field_alias, + variant, +) +from dexpace.sdk.core.serde.tristate import ABSENT, NULL, Present, Tristate + + +class _Color(enum.Enum): + RED = "red" + GREEN = "green" + + +@dataclass(frozen=True, slots=True) +class _Inner: + x: int + + +@dataclass(frozen=True, slots=True) +class _Model: + name: str + created: datetime + color: _Color + inner: _Inner | None = None + tags: list[str] = field(default_factory=list) + nick: str = field(default="", metadata={ALIAS_KEY: "nick_name"}) + note: Tristate[str] = ABSENT + + +_BASE_DOC: dict[str, object] = { + "name": "alice", + "created": "2026-01-02T03:04:05Z", + "color": "red", + "inner": {"x": 7}, + "tags": ["p", "q"], + "nick_name": "al", +} + + +@pytest.fixture +def codec() -> Codec: + return Codec() + + +# --------------------------------------------------------------------------- # +# Plain dataclass decode # +# --------------------------------------------------------------------------- # + + +def test_decode_populates_all_declared_fields(codec: Codec) -> None: + model = codec.decode(_BASE_DOC, _Model) + assert model.name == "alice" + assert model.color is _Color.RED + assert model.inner == _Inner(7) + assert model.tags == ["p", "q"] + + +def test_decode_maps_aliased_field_from_wire_name(codec: Codec) -> None: + model = codec.decode(_BASE_DOC, _Model) + assert model.nick == "al" + + +def test_decode_parses_datetime_with_trailing_z(codec: Codec) -> None: + model = codec.decode(_BASE_DOC, _Model) + assert model.created == datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + + +def test_decode_applies_field_default_when_key_missing(codec: Codec) -> None: + doc = {k: v for k, v 
in _BASE_DOC.items() if k != "tags"} + model = codec.decode(doc, _Model) + assert model.tags == [] + + +def test_decode_raises_when_required_field_missing(codec: Codec) -> None: + doc = {"created": "2026-01-01T00:00:00", "color": "red"} + with pytest.raises(CodecError) as info: + codec.decode(doc, _Model) + assert "name" in str(info.value) + + +def test_decode_raises_when_dataclass_target_is_not_a_mapping(codec: Codec) -> None: + with pytest.raises(CodecError): + codec.decode(["not", "an", "object"], _Model) + + +# --------------------------------------------------------------------------- # +# Unknown-key tolerance # +# --------------------------------------------------------------------------- # + + +def test_decode_tolerates_unknown_keys_by_default(codec: Codec) -> None: + model = codec.decode({**_BASE_DOC, "future_field": 1}, _Model) + assert model.name == "alice" + + +def test_decode_rejects_unknown_keys_when_configured() -> None: + strict = Codec(tolerate_unknown=False) + with pytest.raises(CodecError) as info: + strict.decode({**_BASE_DOC, "future_field": 1}, _Model) + assert "future_field" in str(info.value) + + +# --------------------------------------------------------------------------- # +# Tristate fields # +# --------------------------------------------------------------------------- # + + +def test_decode_missing_tristate_key_uses_default_absent(codec: Codec) -> None: + model = codec.decode(_BASE_DOC, _Model) + assert model.note is ABSENT + + +def test_decode_present_tristate_value_wraps_in_present(codec: Codec) -> None: + model = codec.decode({**_BASE_DOC, "note": "hi"}, _Model) + assert model.note == Present("hi") + + +def test_decode_null_tristate_value_becomes_null(codec: Codec) -> None: + model = codec.decode({**_BASE_DOC, "note": None}, _Model) + assert model.note is NULL + + +def test_encode_omits_absent_tristate_field(codec: Codec) -> None: + model = codec.decode(_BASE_DOC, _Model) + assert "note" not in codec.encode(model) # type: 
ignore[operator] + + +def test_encode_writes_null_for_null_tristate_field(codec: Codec) -> None: + model = codec.decode({**_BASE_DOC, "note": None}, _Model) + encoded = codec.encode(model) + assert isinstance(encoded, dict) + assert encoded["note"] is None + + +def test_encode_writes_value_for_present_tristate_field(codec: Codec) -> None: + model = codec.decode({**_BASE_DOC, "note": "hi"}, _Model) + encoded = codec.encode(model) + assert isinstance(encoded, dict) + assert encoded["note"] == "hi" + + +def test_decode_tristate_recurses_into_inner_type() -> None: + @dataclass(frozen=True, slots=True) + class Wrapped: + when: Tristate[datetime] = ABSENT + + model = Codec().decode({"when": "2026-01-02T03:04:05Z"}, Wrapped) + assert model.when == Present(datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)) + + +# --------------------------------------------------------------------------- # +# Encode round-trip # +# --------------------------------------------------------------------------- # + + +def test_encode_uses_wire_name_for_aliased_field(codec: Codec) -> None: + model = codec.decode(_BASE_DOC, _Model) + encoded = codec.encode(model) + assert isinstance(encoded, dict) + assert encoded["nick_name"] == "al" + assert "nick" not in encoded + + +def test_encode_emits_iso_string_for_datetime(codec: Codec) -> None: + model = codec.decode(_BASE_DOC, _Model) + encoded = codec.encode(model) + assert isinstance(encoded, dict) + assert encoded["created"] == "2026-01-02T03:04:05+00:00" + + +def test_encode_emits_enum_value(codec: Codec) -> None: + model = codec.decode(_BASE_DOC, _Model) + encoded = codec.encode(model) + assert isinstance(encoded, dict) + assert encoded["color"] == "red" + + +def test_encode_optional_none_is_written_as_null(codec: Codec) -> None: + doc = {k: v for k, v in _BASE_DOC.items() if k != "inner"} + model = codec.decode(doc, _Model) + encoded = codec.encode(model) + assert isinstance(encoded, dict) + assert encoded["inner"] is None + + +def 
test_encode_rejects_non_documentable_value(codec: Codec) -> None: + with pytest.raises(SerializationError): + codec.encode(object()) + + +def test_encode_utf8_bytes_decodes_to_string(codec: Codec) -> None: + assert codec.encode(b"hello") == "hello" + + +def test_encode_non_utf8_bytes_raises_serialization_error(codec: Codec) -> None: + with pytest.raises(SerializationError): + codec.encode(b"\xff\xfe") + + +# --------------------------------------------------------------------------- # +# Containers and scalars # +# --------------------------------------------------------------------------- # + + +def test_decode_list_of_models(codec: Codec) -> None: + decoded = codec.decode([{"x": 1}, {"x": 2}], list[_Inner]) + assert decoded == [_Inner(1), _Inner(2)] + + +def test_decode_dict_of_models(codec: Codec) -> None: + decoded = codec.decode({"a": {"x": 1}, "b": {"x": 2}}, dict[str, _Inner]) + assert decoded == {"a": _Inner(1), "b": _Inner(2)} + + +def test_decode_homogeneous_tuple(codec: Codec) -> None: + assert codec.decode([1, 2, 3], tuple[int, ...]) == (1, 2, 3) + + +def test_decode_set(codec: Codec) -> None: + assert codec.decode([1, 2, 2, 3], set[int]) == {1, 2, 3} + + +def test_decode_frozenset(codec: Codec) -> None: + assert codec.decode([1, 2], frozenset[int]) == frozenset({1, 2}) + + +def test_decode_optional_none_passthrough(codec: Codec) -> None: + # Unions are decoded at runtime; the static type[T] signature can't model them. + assert codec.decode(None, str | None) is None # type: ignore[arg-type] + + +def test_decode_optional_decodes_inner(codec: Codec) -> None: + assert codec.decode({"x": 4}, _Inner | None) == _Inner(4) # type: ignore[arg-type] + + +def test_decode_scalars_pass_through_without_coercion(codec: Codec) -> None: + # Scalars pass through uncoerced: decoding "5" as int yields the original str. 
+ assert codec.decode("5", int) == "5" # type: ignore[comparison-overlap] + assert codec.decode(5, int) == 5 + assert codec.decode(True, bool) is True + + +def test_decode_object_target_is_passthrough(codec: Codec) -> None: + payload = {"arbitrary": [1, 2]} + assert codec.decode(payload, object) is payload + + +def test_decode_sequence_rejects_non_array(codec: Codec) -> None: + with pytest.raises(CodecError): + codec.decode("string-not-array", list[str]) + + +# --------------------------------------------------------------------------- # +# Date / time # +# --------------------------------------------------------------------------- # + + +def test_decode_date(codec: Codec) -> None: + assert codec.decode("2026-06-09", date) == date(2026, 6, 9) + + +def test_decode_time(codec: Codec) -> None: + assert codec.decode("12:30:00", time) == time(12, 30) + + +def test_decode_invalid_datetime_raises_with_path(codec: Codec) -> None: + with pytest.raises(CodecError) as info: + codec.decode({**_BASE_DOC, "created": "not-a-date"}, _Model) + assert "created" in str(info.value) + + +def test_decode_invalid_enum_raises(codec: Codec) -> None: + with pytest.raises(CodecError): + codec.decode({**_BASE_DOC, "color": "purple"}, _Model) + + +# --------------------------------------------------------------------------- # +# Discriminated unions # +# --------------------------------------------------------------------------- # + + +@discriminated("type") +class _Pay: + pass + + +@variant("card") +@dataclass(frozen=True, slots=True) +class _Card(_Pay): + last4: str + type: str = "card" + + +@variant("bank") +@dataclass(frozen=True, slots=True) +class _Bank(_Pay): + iban: str + type: str = "bank" + + +def test_decode_dispatches_to_variant_by_tag(codec: Codec) -> None: + decoded = codec.decode({"type": "card", "last4": "1234"}, _Pay) + assert isinstance(decoded, _Card) + assert decoded.last4 == "1234" + + +def test_decode_list_of_union_variants(codec: Codec) -> None: + decoded = codec.decode( 
+ [{"type": "bank", "iban": "X"}, {"type": "card", "last4": "9"}], + list[_Pay], + ) + assert isinstance(decoded[0], _Bank) + assert isinstance(decoded[1], _Card) + + +def test_decode_unknown_tag_raises_listing_known(codec: Codec) -> None: + with pytest.raises(CodecError) as info: + codec.decode({"type": "crypto"}, _Pay) + message = str(info.value) + assert "crypto" in message + assert "card" in message + + +def test_decode_missing_discriminator_raises(codec: Codec) -> None: + with pytest.raises(CodecError): + codec.decode({"last4": "1"}, _Pay) + + +def test_encode_variant_emits_discriminator_field(codec: Codec) -> None: + encoded = codec.encode(_Card(last4="1234")) + assert encoded == {"last4": "1234", "type": "card"} + + +def test_variant_duplicate_tag_raises() -> None: + with pytest.raises(ValueError, match="already registered"): + variant("card")(_Card) + + +def test_variant_without_discriminated_base_raises() -> None: + @dataclass(frozen=True, slots=True) + class Orphan: + v: int + + with pytest.raises(TypeError): + variant("z")(Orphan) + + +# --------------------------------------------------------------------------- # +# Error model # +# --------------------------------------------------------------------------- # + + +def test_codec_error_is_a_deserialization_error() -> None: + assert issubclass(CodecError, DeserializationError) + assert issubclass(CodecError, ValueError) + + +def test_codec_error_renders_nested_path(codec: Codec) -> None: + @dataclass(frozen=True, slots=True) + class Outer: + items: list[_Inner] + + with pytest.raises(CodecError) as info: + codec.decode({"items": [{"x": 1}, "oops"]}, Outer) + assert "items[1]" in str(info.value) + + +def test_codec_error_carries_path_tuple(codec: Codec) -> None: + err = CodecError("boom", path=("a", "[0]", "b"), target_name="X") + assert err.path == ("a", "[0]", "b") + assert "a[0].b" in str(err) + + +# --------------------------------------------------------------------------- # +# field_alias helper # +# 
--------------------------------------------------------------------------- # + + +def test_field_alias_sets_metadata() -> None: + @dataclass(frozen=True, slots=True) + class Aliased: + value: int = field_alias("v", default=3) # type: ignore[assignment] + + fields = {f.name: f for f in dataclasses.fields(Aliased)} + assert fields["value"].metadata[ALIAS_KEY] == "v" + + +def test_field_alias_default_is_used_when_key_absent(codec: Codec) -> None: + @dataclass(frozen=True, slots=True) + class Aliased: + value: int = field_alias("v", default=3) # type: ignore[assignment] + + assert codec.decode({}, Aliased).value == 3 + assert codec.decode({"v": 9}, Aliased).value == 9 + + +def test_field_alias_default_factory() -> None: + @dataclass(frozen=True, slots=True) + class Aliased: + # field_alias returns a dataclasses.Field, same as dataclasses.field(). + items: list[int] = field_alias("xs", default_factory=list) # type: ignore[assignment] # noqa: RUF009 + + decoded = Codec().decode({}, Aliased) + assert decoded.items == [] + + +# --------------------------------------------------------------------------- # +# Enum encoding (StrEnum / IntEnum collapse to scalar value) # +# --------------------------------------------------------------------------- # + + +class _StrFlavour(enum.StrEnum): + A = "aValue" + B = "bValue" + + +class _IntLevel(enum.IntEnum): + LOW = 1 + HIGH = 9 + + +def test_encode_str_enum_member_collapses_to_value(codec: Codec) -> None: + encoded = codec.encode(_StrFlavour.A) + assert encoded == "aValue" + assert type(encoded) is str + + +def test_encode_int_enum_member_collapses_to_value(codec: Codec) -> None: + encoded = codec.encode(_IntLevel.HIGH) + assert encoded == 9 + assert type(encoded) is int + + +def test_encode_str_enum_field_inside_dataclass(codec: Codec) -> None: + @dataclass(frozen=True, slots=True) + class Holder: + flavour: _StrFlavour + + encoded = codec.encode(Holder(_StrFlavour.B)) + assert isinstance(encoded, dict) + assert 
encoded["flavour"] == "bValue" + assert type(encoded["flavour"]) is str + + +def test_str_enum_round_trips(codec: Codec) -> None: + @dataclass(frozen=True, slots=True) + class Holder: + flavour: _StrFlavour + + model = Holder(_StrFlavour.A) + decoded = codec.decode(codec.encode(model), Holder) + assert decoded == model + + +# --------------------------------------------------------------------------- # +# Discriminated tag is exempt from strict unknown-key rejection # +# --------------------------------------------------------------------------- # + + +@discriminated("kind") +class _Shape: + pass + + +@variant("circle") +@dataclass(frozen=True, slots=True) +class _Circle(_Shape): + radius: int + + +def test_strict_codec_accepts_discriminator_without_matching_field() -> None: + strict = Codec(tolerate_unknown=False) + decoded = strict.decode({"kind": "circle", "radius": 3}, _Shape) + assert isinstance(decoded, _Circle) + assert decoded.radius == 3 + + +def test_strict_codec_still_rejects_genuine_unknown_in_variant() -> None: + strict = Codec(tolerate_unknown=False) + with pytest.raises(CodecError) as info: + strict.decode({"kind": "circle", "radius": 3, "stray": 1}, _Shape) + assert "stray" in str(info.value) + + +def test_tolerant_codec_dispatches_discriminator_normally(codec: Codec) -> None: + decoded = codec.decode({"kind": "circle", "radius": 5}, _Shape) + assert isinstance(decoded, _Circle) + + +# --------------------------------------------------------------------------- # +# Fixed-arity tuple length validation # +# --------------------------------------------------------------------------- # + + +def test_decode_fixed_arity_tuple_exact_length(codec: Codec) -> None: + assert codec.decode([1, "a"], tuple[int, str]) == (1, "a") + + +def test_decode_fixed_arity_tuple_too_short_raises(codec: Codec) -> None: + with pytest.raises(CodecError) as info: + codec.decode([1], tuple[int, str]) + assert "2" in str(info.value) + + +def 
test_decode_fixed_arity_tuple_too_long_raises(codec: Codec) -> None: + with pytest.raises(CodecError): + codec.decode([1, "a", "EXTRA", 9], tuple[int, str]) + + +def test_decode_homogeneous_tuple_still_unbounded(codec: Codec) -> None: + assert codec.decode([1, 2, 3, 4], tuple[int, ...]) == (1, 2, 3, 4) + + +# --------------------------------------------------------------------------- # +# dict key-type recovery # +# --------------------------------------------------------------------------- # + + +class _KeyKind(enum.Enum): + FIRST = "first" + SECOND = "second" + + +def test_decode_dict_recovers_enum_keys(codec: Codec) -> None: + # Enum key types are recovered exactly as enum value types are: the wire + # key string is mapped to its enum member. + decoded = codec.decode({"first": 10}, dict[_KeyKind, int]) + assert decoded == {_KeyKind.FIRST: 10} + assert all(isinstance(k, _KeyKind) for k in decoded) + + +def test_decode_dict_str_keys_pass_through(codec: Codec) -> None: + decoded = codec.decode({"a": 1}, dict[str, int]) + assert decoded == {"a": 1} + + +def test_decode_dict_scalar_keys_follow_value_no_coercion_rule(codec: Codec) -> None: + # The codec performs no scalar coercion (decode("5", int) == "5"); keys + # follow the same rule as values, so a wire ``"1"`` key stays a string. + decoded = codec.decode({"1": "a"}, dict[int, str]) + assert decoded == {"1": "a"} # type: ignore[comparison-overlap] + + +def test_decode_dict_recovers_enum_keys_and_model_values(codec: Codec) -> None: + decoded = codec.decode({"second": {"x": 5}}, dict[_KeyKind, _Inner]) + assert decoded == {_KeyKind.SECOND: _Inner(5)} + + +# --------------------------------------------------------------------------- # +# Union None-passthrough gating # +# --------------------------------------------------------------------------- # + + +def test_decode_non_optional_union_does_not_accept_none(codec: Codec) -> None: + # ``int | str`` has no ``NoneType`` arm; ``None`` must not be injected. 
+ assert codec.decode(None, int | str) is None # type: ignore[arg-type] + + +def test_decode_optional_union_accepts_none_when_arm_present(codec: Codec) -> None: + assert codec.decode(None, int | None) is None # type: ignore[arg-type] + + +def test_decode_optional_union_recovers_inner_dataclass(codec: Codec) -> None: + assert codec.decode({"x": 4}, _Inner | None) == _Inner(4) # type: ignore[arg-type] + + +def test_decode_multi_arm_union_passes_scalar_through(codec: Codec) -> None: + # A multi-arm union with no matching coercion passes the scalar through. + assert codec.decode("hello", int | str) == "hello" # type: ignore[arg-type, comparison-overlap] diff --git a/packages/dexpace-sdk-core/tests/serde/test_tristate.py b/packages/dexpace-sdk-core/tests/serde/test_tristate.py new file mode 100644 index 0000000..0fc6607 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/serde/test_tristate.py @@ -0,0 +1,203 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Tests for the three-valued ``Tristate`` optional type.""" + +from __future__ import annotations + +import copy +import pickle + +import pytest + +from dexpace.sdk.core.serde import ( + ABSENT, + NULL, + Present, + Tristate, + fold, + is_absent, + is_null, + is_present, + of_optional, + present, +) + + +def test_present_wraps_value() -> None: + wrapped = present("x") + assert isinstance(wrapped, Present) + assert wrapped.value == "x" + + +def test_present_preserves_falsy_values() -> None: + falsy_values: tuple[object, ...] 
= (0, "", [], False, 0.0) + for falsy in falsy_values: + wrapped = present(falsy) + assert isinstance(wrapped, Present) + assert wrapped.value == falsy + assert is_present(wrapped) + + +def test_present_is_frozen() -> None: + wrapped = present(1) + with pytest.raises(AttributeError): + wrapped.value = 2 # type: ignore[misc] # frozen-dataclass guard under test + + +def test_present_equality_and_hash() -> None: + assert present(1) == present(1) + assert present(1) != present(2) + assert hash(present(1)) == hash(Present(1)) + + +def test_absent_and_null_are_singletons() -> None: + assert ABSENT is ABSENT + assert NULL is NULL + assert ABSENT is not NULL # type: ignore[comparison-overlap] # distinct singletons + assert type(ABSENT)() is ABSENT + assert type(NULL)() is NULL + + +def test_sentinel_reprs() -> None: + assert repr(ABSENT) == "ABSENT" + assert repr(NULL) == "NULL" + + +def test_singletons_survive_copy() -> None: + assert copy.copy(ABSENT) is ABSENT + assert copy.deepcopy(ABSENT) is ABSENT + assert copy.copy(NULL) is NULL + assert copy.deepcopy(NULL) is NULL + + +def test_singletons_survive_pickle() -> None: + assert pickle.loads(pickle.dumps(ABSENT)) is ABSENT + assert pickle.loads(pickle.dumps(NULL)) is NULL + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (None, NULL), + (0, Present(0)), + ("", Present("")), + ("hello", Present("hello")), + ([], Present([])), + ], + ids=["none-to-null", "zero", "empty-str", "str", "empty-list"], +) +def test_of_optional_maps_none_to_null(value: object, expected: Tristate[object]) -> None: + assert of_optional(value) == expected + + +def test_of_optional_never_returns_absent() -> None: + # of_optional never yields ABSENT; the identity check can't statically overlap. 
+ assert of_optional(None) is not ABSENT # type: ignore[comparison-overlap] + assert of_optional("x") is not ABSENT # type: ignore[comparison-overlap] + + +def test_guards_are_mutually_exclusive() -> None: + cases: list[Tristate[int]] = [ABSENT, NULL, present(7)] + for state in cases: + flags = [is_absent(state), is_null(state), is_present(state)] + assert sum(flags) == 1 + + +def test_is_absent() -> None: + assert is_absent(ABSENT) + assert not is_absent(NULL) + assert not is_absent(present(1)) + + +def test_is_null() -> None: + assert is_null(NULL) + assert not is_null(ABSENT) + assert not is_null(present(1)) + + +def test_is_present() -> None: + assert is_present(present(1)) + assert not is_present(ABSENT) + assert not is_present(NULL) + + +def test_fold_dispatches_to_present() -> None: + result = fold( + present(10), + on_absent=lambda: "absent", + on_null=lambda: "null", + on_present=lambda v: f"present:{v}", + ) + assert result == "present:10" + + +def test_fold_dispatches_to_null() -> None: + result = fold( + NULL, + on_absent=lambda: "absent", + on_null=lambda: "null", + on_present=lambda v: f"present:{v}", + ) + assert result == "null" + + +def test_fold_dispatches_to_absent() -> None: + result = fold( + ABSENT, + on_absent=lambda: "absent", + on_null=lambda: "null", + on_present=lambda v: f"present:{v}", + ) + assert result == "absent" + + +def test_fold_runs_exactly_one_branch() -> None: + calls: list[str] = [] + fold( + present("v"), + on_absent=lambda: calls.append("absent"), + on_null=lambda: calls.append("null"), + on_present=lambda _v: calls.append("present"), + ) + assert calls == ["present"] + + +def test_fold_present_passes_falsy_value() -> None: + result = fold( + present(0), + on_absent=lambda: -1, + on_null=lambda: -2, + on_present=lambda v: v, + ) + assert result == 0 + + +def test_fold_is_exhaustive_over_all_inhabitants() -> None: + def describe(state: Tristate[int]) -> str: + return fold( + state, + on_absent=lambda: "absent", + 
on_null=lambda: "null", + on_present=lambda v: f"present:{v}", + ) + + assert describe(ABSENT) == "absent" + assert describe(NULL) == "null" + assert describe(present(3)) == "present:3" + + +def test_serialize_semantics_via_fold() -> None: + """ABSENT omits the key, NULL writes null, Present writes the value.""" + + def encode(field: str, state: Tristate[object]) -> dict[str, object]: + return fold( + state, + on_absent=lambda: {}, + on_null=lambda: {field: None}, + on_present=lambda v: {field: v}, + ) + + assert encode("name", ABSENT) == {} + assert encode("name", NULL) == {"name": None} + assert encode("name", present("Ada")) == {"name": "Ada"} diff --git a/packages/dexpace-sdk-core/tests/sse/test_sse_bom.py b/packages/dexpace-sdk-core/tests/sse/test_sse_bom.py new file mode 100644 index 0000000..24a342e --- /dev/null +++ b/packages/dexpace-sdk-core/tests/sse/test_sse_bom.py @@ -0,0 +1,102 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. 
+ +"""Tests for leading UTF-8 BOM handling in the WHATWG SSE parser (fix F4).""" + +from __future__ import annotations + +from collections.abc import AsyncIterator + +import pytest + +from dexpace.sdk.core.errors import StreamingError +from dexpace.sdk.core.http.sse import SseEvent, parse_async_events, parse_events +from dexpace.sdk.core.http.sse.parser import SseParser + +_BOM = b"\xef\xbb\xbf" + + +def _events(stream: bytes, chunk_size: int = 4096) -> list[SseEvent]: + return list(parse_events(_chunked(stream, chunk_size))) + + +def _chunked(data: bytes, size: int) -> list[bytes]: + return [data[i : i + size] for i in range(0, len(data), size)] + + +def test_leading_bom_stripped_so_first_field_parses() -> None: + stream = _BOM + b"data: hello\n\n" + assert _events(stream) == [SseEvent(data="hello")] + + +def test_leading_bom_strips_only_once() -> None: + # Only the first BOM is stripped; a second is ordinary content that + # corrupts the field name, so the event yields no data and is dropped. + stream = _BOM + _BOM + b"data: hi\n\n" + assert _events(stream) == [] + + +def test_bom_split_across_chunks_is_still_stripped() -> None: + stream = _BOM + b"data: split\n\n" + # Force every byte (and thus the BOM) to arrive one at a time. + assert _events(stream, chunk_size=1) == [SseEvent(data="split")] + + +def test_bom_split_two_then_rest() -> None: + parser = SseParser() + parser.feed(_BOM[:2]) + assert list(parser.drain()) == [] + parser.feed(_BOM[2:] + b"data: x\n\n") + assert list(parser.drain()) == [SseEvent(data="x")] + + +def test_no_bom_first_field_unaffected() -> None: + stream = b"data: plain\n\n" + assert _events(stream) == [SseEvent(data="plain")] + + +def test_feff_not_at_start_is_not_stripped() -> None: + # U+FEFF after the first event is content, never a BOM. 
+ stream = b"data: a\n\n" + "data: \ufeffb".encode() + b"\n\n" + events = _events(stream) + assert events[0] == SseEvent(data="a") + assert events[1] == SseEvent(data="\ufeffb") + + +def test_bom_only_stream_at_eos_emits_nothing() -> None: + assert _events(_BOM) == [] + + +def test_partial_bom_only_stream_at_eos_decodes_as_content() -> None: + # A lone first BOM byte that never completes is not a BOM; it is a + # truncated UTF-8 sequence and must surface as a mid-codepoint error. + parser = SseParser() + parser.feed(_BOM[:1]) + with pytest.raises(StreamingError): + list(parser.end()) + + +def test_bom_does_not_regress_max_line_cap() -> None: + parser = SseParser(max_line_bytes=16) + with pytest.raises(StreamingError): + parser.feed(_BOM + b"data: " + b"x" * 64) + + +def test_bom_does_not_regress_mid_codepoint_guard() -> None: + parser = SseParser() + # Leading BOM, then a truncated multi-byte codepoint at end-of-stream. + parser.feed(_BOM + b"data: \xe2\x82") # start of U+20AC, cut short + with pytest.raises(StreamingError): + list(parser.end()) + + +async def test_leading_bom_stripped_in_async_path() -> None: + async def producer() -> AsyncIterator[bytes]: + yield _BOM[:1] + yield _BOM[1:] + yield b"data: async\n\n" + + events: list[SseEvent] = [] + async for event in parse_async_events(producer()): + events.append(event) + assert events == [SseEvent(data="async")] diff --git a/packages/dexpace-sdk-core/tests/test_conformance_fixtures.py b/packages/dexpace-sdk-core/tests/test_conformance_fixtures.py new file mode 100644 index 0000000..44b319e --- /dev/null +++ b/packages/dexpace-sdk-core/tests/test_conformance_fixtures.py @@ -0,0 +1,330 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Behavioral conformance fixtures — one observable check per seam (Q2). 
+ +Each test pins the *observable* contract of one behavioral seam — not its +private state — so a quiet regression in any of these load-bearing behaviors +fails the build. Where a seam only makes sense end-to-end, the check drives a +mock ``HttpClient`` / ``Pipeline`` rather than poking at internals: + +* a single-use request body raises on a second read; +* ``RequestBody.to_replayable`` makes a single-use body retryable; +* ``Headers`` lookups are case-insensitive; +* the ``ContextStore`` evicts an entry on ``CallContext.close``; +* the retry policy honours a ``Retry-After`` header; +* the paginator iterates items across every page; +* the webhook verifier accepts a valid signature and rejects a tampered one. +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import io +import json +from collections.abc import Iterator + +import pytest + +from dexpace.sdk.core.client.http_client import HttpClient +from dexpace.sdk.core.http.common import Headers, MediaType, Protocol, Url +from dexpace.sdk.core.http.common.http_header_name import CONTENT_TYPE +from dexpace.sdk.core.http.context import ContextStore, DispatchContext +from dexpace.sdk.core.http.request import Method, Request +from dexpace.sdk.core.http.request.request_body import RequestBody +from dexpace.sdk.core.http.response import Response, Status +from dexpace.sdk.core.http.response.response_body import ResponseBody +from dexpace.sdk.core.http.webhooks import ( + InvalidWebhookSignatureError, + WebhookVerifier, +) +from dexpace.sdk.core.instrumentation import ( + InstrumentationContext, + SpanId, + TraceFlags, + TraceId, + TraceIdType, + TraceState, +) +from dexpace.sdk.core.instrumentation.noop import NOOP_SPAN +from dexpace.sdk.core.pagination import CursorStrategy, Paginator +from dexpace.sdk.core.pipeline import Pipeline +from dexpace.sdk.core.pipeline.policies import RetryPolicy + +from .conftest import FakeClock + + +def _instrumentation(trace_id_value: str) -> 
InstrumentationContext: + return InstrumentationContext( + trace_id_type=TraceIdType.W3C, + trace_id=TraceId(trace_id_value), + span_id=SpanId("0" * 16), + span=NOOP_SPAN, + trace_flags=TraceFlags.NOOP, + trace_state=TraceState.NOOP, + ) + + +# ----- single-use body raises on second read ------------------------------ + + +def test_single_use_stream_body_raises_on_second_read() -> None: + body = RequestBody.from_stream(io.BytesIO(b"payload")) + assert b"".join(body.iter_bytes()) == b"payload" + with pytest.raises(RuntimeError, match="already called"): + list(body.iter_bytes()) + + +def test_single_use_iter_body_raises_on_second_read() -> None: + body = RequestBody.from_iter([b"a", b"b", b"c"]) + assert b"".join(body.iter_bytes()) == b"abc" + with pytest.raises(RuntimeError, match="already called"): + list(body.iter_bytes()) + + +# ----- to_replayable makes a single-use body retryable -------------------- + + +def test_to_replayable_allows_repeated_reads_of_a_stream_body() -> None: + body = RequestBody.from_stream(io.BytesIO(b"retry-me")).to_replayable() + assert body.is_replayable() + # Two independent reads, modelling an initial send plus one retry. 
+ assert b"".join(body.iter_bytes()) == b"retry-me" + assert b"".join(body.iter_bytes()) == b"retry-me" + + +def test_retry_replays_a_single_use_body_across_attempts() -> None: + sent: list[bytes] = [] + + class _DrainingClient(HttpClient): + """Fails once (drains the body), then succeeds on the replayed body.""" + + def __init__(self) -> None: + self.attempts = 0 + + def execute(self, request: Request) -> Response: + self.attempts += 1 + assert request.body is not None + sent.append(b"".join(request.body.iter_bytes())) + status = Status.SERVICE_UNAVAILABLE if self.attempts == 1 else Status.OK + return Response(request=request, protocol=Protocol.HTTP_1_1, status=status) + + client = _DrainingClient() + request = Request( + method=Method.PUT, + url=Url.parse("https://api.example.com/things/1"), + body=RequestBody.from_iter([b"single-use-payload"]), + ) + retry = RetryPolicy(clock=FakeClock()) + with Pipeline(client, policies=[retry]) as pipeline: + response = pipeline.run(request, DispatchContext(_instrumentation("0" * 16 + "a"))) + + assert response.status is Status.OK + assert client.attempts == 2 + # The retry replayed the exact same bytes the first attempt drained — the + # observable proof that retry auto-buffered the single-use body. + assert sent == [b"single-use-payload", b"single-use-payload"] + + +# ----- Headers case-insensitive lookup ------------------------------------ + + +def test_headers_lookup_is_case_insensitive() -> None: + headers = Headers({"Content-Type": "application/json"}) + assert headers.get("content-type") == "application/json" + assert headers.get("CONTENT-TYPE") == "application/json" + assert headers.get("Content-Type") == "application/json" + assert "cOnTeNt-TyPe" in headers + # A typed header-name constant resolves identically to its string form. 
+ assert headers.get(CONTENT_TYPE) == "application/json" + + +def test_headers_canonicalises_name_on_iteration() -> None: + headers = Headers({"X-Custom-Header": "v"}) + assert tuple(headers) == ("x-custom-header",) + + +# ----- ContextStore evicts on close --------------------------------------- + + +def test_context_store_evicts_entry_on_close() -> None: + trace_id = "0" * 16 + "b" + instr = _instrumentation(trace_id) + dispatch = DispatchContext(instrumentation_context=instr) + + request_ctx = dispatch.to_request_context( + Request(method=Method.GET, url=Url.parse("https://example.com/")) + ) + assert ContextStore.get(trace_id) is request_ctx + + request_ctx.close() + assert ContextStore.get(trace_id) is None + + +def test_context_manager_exit_evicts_from_store() -> None: + trace_id = "0" * 16 + "c" + instr = _instrumentation(trace_id) + dispatch = DispatchContext(instrumentation_context=instr) + request_ctx = dispatch.to_request_context( + Request(method=Method.GET, url=Url.parse("https://example.com/")) + ) + with request_ctx: + assert ContextStore.get(trace_id) is request_ctx + assert ContextStore.get(trace_id) is None + + +# ----- retry honours Retry-After ------------------------------------------ + + +class _RetryAfterClient(HttpClient): + """Returns 503 carrying a ``Retry-After`` header, then 200.""" + + def __init__(self, retry_after: str) -> None: + self._retry_after = retry_after + self.attempts = 0 + + def execute(self, request: Request) -> Response: + self.attempts += 1 + if self.attempts == 1: + response = Response( + request=request, + protocol=Protocol.HTTP_1_1, + status=Status.SERVICE_UNAVAILABLE, + ) + return response.with_header("Retry-After", self._retry_after) + return Response(request=request, protocol=Protocol.HTTP_1_1, status=Status.OK) + + +def test_retry_sleeps_for_the_retry_after_delay() -> None: + clock = FakeClock(start=1_000.0) + client = _RetryAfterClient(retry_after="7") + retry = RetryPolicy(clock=clock) + request = 
Request(method=Method.GET, url=Url.parse("https://example.com/")) + with Pipeline(client, policies=[retry]) as pipeline: + response = pipeline.run(request, DispatchContext(_instrumentation("0" * 16 + "d"))) + + assert response.status is Status.OK + assert client.attempts == 2 + # The server asked for exactly 7 seconds; the policy slept precisely that + # long (not the computed backoff). + assert clock.monotonic() == pytest.approx(1_007.0) + + +def test_retry_caps_a_hostile_retry_after_header() -> None: + clock = FakeClock(start=0.0) + client = _RetryAfterClient(retry_after="999999") # ~11.5 days + retry = RetryPolicy(clock=clock, retry_after_max=30.0) + request = Request(method=Method.GET, url=Url.parse("https://example.com/")) + with Pipeline(client, policies=[retry]) as pipeline: + pipeline.run(request, DispatchContext(_instrumentation("0" * 16 + "e"))) + + # A multi-day header is clamped to the configured ceiling. + assert clock.monotonic() == pytest.approx(30.0) + + +# ----- pagination iterates items across pages ----------------------------- + + +class _InMemoryBody(ResponseBody): + """Response body backed by an in-memory ``bytes`` buffer.""" + + def __init__(self, data: bytes) -> None: + self._data = data + self.closed = False + + def media_type(self) -> MediaType | None: + return None + + def content_length(self) -> int: + return len(self._data) + + def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: + yield self._data + + def close(self) -> None: + self.closed = True + + +class _PagedClient(HttpClient): + """Maps the ``cursor`` query value to a canned JSON page body.""" + + def __init__(self, pages: dict[str | None, dict[str, object]]) -> None: + self._pages = pages + self.calls: list[Request] = [] + + def execute(self, request: Request) -> Response: + self.calls.append(request) + cursor = request.url.query.get("cursor") + payload = self._pages[cursor] + body = _InMemoryBody(json.dumps(payload).encode("utf-8")) + return Response( + 
request=request, + protocol=Protocol.HTTP_1_1, + status=Status.OK, + body=body, + ) + + +def test_paginator_iterates_items_across_all_pages() -> None: + client = _PagedClient( + { + None: {"data": [1, 2], "next_cursor": "c1"}, + "c1": {"data": [3, 4], "next_cursor": "c2"}, + "c2": {"data": [5], "next_cursor": None}, + } + ) + strategy: CursorStrategy[int] = CursorStrategy( + items_field="data", + cursor_response_field="next_cursor", + cursor_param="cursor", + ) + first = Request(method=Method.GET, url=Url.parse("https://api.example.com/items")) + paginator: Paginator[int] = Paginator(client.execute, strategy, first) + + assert list(paginator) == [1, 2, 3, 4, 5] + # One transport call per page; the cursor was threaded onto each follow-up. + assert len(client.calls) == 3 + assert client.calls[1].url.query.get("cursor") == "c1" + assert client.calls[2].url.query.get("cursor") == "c2" + + +# ----- webhook verifier accepts valid / rejects tampered ------------------ + +_WEBHOOK_RAW_KEY = b"conformance-webhook-signing-key-0" +_WEBHOOK_SECRET = "whsec_" + base64.b64encode(_WEBHOOK_RAW_KEY).decode("ascii") +_WEBHOOK_ID = "msg_conformance_0001" +_WEBHOOK_TIMESTAMP = "1700000000" +_WEBHOOK_BODY = b'{"event":"order.created","id":"ord_42"}' + + +def _sign_webhook(body: bytes) -> str: + content = f"{_WEBHOOK_ID}.{_WEBHOOK_TIMESTAMP}.".encode() + body + digest = hmac.new(_WEBHOOK_RAW_KEY, content, hashlib.sha256).digest() + return base64.b64encode(digest).decode("ascii") + + +def _webhook_headers(signature: str) -> dict[str, str]: + return { + "webhook-id": _WEBHOOK_ID, + "webhook-timestamp": _WEBHOOK_TIMESTAMP, + "webhook-signature": f"v1,{signature}", + } + + +def _webhook_verifier() -> WebhookVerifier: + return WebhookVerifier(_WEBHOOK_SECRET, clock=FakeClock(start=float(_WEBHOOK_TIMESTAMP))) + + +def test_webhook_verifier_accepts_a_valid_signature() -> None: + headers = _webhook_headers(_sign_webhook(_WEBHOOK_BODY)) + payload = _webhook_verifier().unwrap(headers, 
_WEBHOOK_BODY) + assert payload == {"event": "order.created", "id": "ord_42"} + + +def test_webhook_verifier_rejects_a_tampered_signature() -> None: + signature = _sign_webhook(_WEBHOOK_BODY) + tampered_body = _WEBHOOK_BODY.replace(b"ord_42", b"ord_99") + with pytest.raises(InvalidWebhookSignatureError, match="no matching signature"): + _webhook_verifier().verify(_webhook_headers(signature), tampered_body) diff --git a/packages/dexpace-sdk-core/tests/test_public_surface.py b/packages/dexpace-sdk-core/tests/test_public_surface.py new file mode 100644 index 0000000..417e9c2 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/test_public_surface.py @@ -0,0 +1,123 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Public-API surface baseline test (Q1). + +The project convention is a narrow, deliberate public API — an accurate +``__all__`` per package and stable Protocol/class signatures. Nothing +mechanically caught an accidental leak (a private helper sneaking into +``__all__``, a Protocol method signature changing, a ``with_*`` helper +disappearing) until this test. + +The live surface is rebuilt by static analysis (no imports, no execution) via +``tools/surface_snapshot.py`` and compared against the committed baseline at +``tools/surface_baseline.json``. Any unexpected change fails the build. + +Regenerating the baseline (only after an *intentional* public-API change, and +review the diff before committing it): + + python tools/surface_snapshot.py --write + +A second test guards the baseline itself: it must be valid JSON whose top-level +keys are exactly the five distributions, so a corrupt or truncated baseline can +never silently disable the gate. 
+""" + +from __future__ import annotations + +import importlib.util +import json +from pathlib import Path +from types import ModuleType + +import pytest + +# ``tools`` is not an installed package, so load the snapshot module by path. +_REPO_ROOT = Path(__file__).resolve().parents[3] +_TOOL_PATH = _REPO_ROOT / "tools" / "surface_snapshot.py" +_BASELINE_PATH = _REPO_ROOT / "tools" / "surface_baseline.json" + +_EXPECTED_DISTRIBUTIONS = frozenset( + { + "dexpace-sdk-core", + "dexpace-sdk-http-stdlib", + "dexpace-sdk-http-httpx", + "dexpace-sdk-http-aiohttp", + "dexpace-sdk-http-requests", + } +) + + +def _load_snapshot_tool() -> ModuleType: + """Import ``tools/surface_snapshot.py`` by file path. + + Returns: + The loaded ``surface_snapshot`` module. + + Raises: + ImportError: If the module spec cannot be created from the tool path. + """ + spec = importlib.util.spec_from_file_location("_surface_snapshot", _TOOL_PATH) + if spec is None or spec.loader is None: + raise ImportError(f"cannot load surface snapshot tool from {_TOOL_PATH}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +@pytest.fixture(scope="module") +def snapshot_tool() -> ModuleType: + """Provide the loaded ``surface_snapshot`` tool module once per module.""" + return _load_snapshot_tool() + + +def _load_baseline() -> dict[str, dict[str, object]]: + """Read and parse the committed baseline JSON.""" + parsed: dict[str, dict[str, object]] = json.loads(_BASELINE_PATH.read_text(encoding="utf-8")) + return parsed + + +def test_live_surface_matches_committed_baseline(snapshot_tool: ModuleType) -> None: + live = snapshot_tool.build_surface(_REPO_ROOT) + baseline = _load_baseline() + assert live == baseline, ( + "Public API surface drifted from the committed baseline. 
If this change " + "is intentional, review the diff and regenerate with: " + "python tools/surface_snapshot.py --write" + ) + + +def test_baseline_is_canonical_and_round_trips(snapshot_tool: ModuleType) -> None: + # The committed file must match the tool's own canonical rendering exactly, + # so a hand-edited or non-canonical baseline (different key order, missing + # trailing newline) is rejected rather than silently accepted. + baseline = _load_baseline() + rendered = snapshot_tool.render(baseline) + on_disk = _BASELINE_PATH.read_text(encoding="utf-8") + assert rendered == on_disk, ( + "Baseline JSON is not in canonical form. Regenerate it with: " + "python tools/surface_snapshot.py --write" + ) + + +def test_baseline_covers_exactly_the_five_distributions() -> None: + baseline = _load_baseline() + assert set(baseline) == _EXPECTED_DISTRIBUTIONS + + +def test_every_distribution_has_exports_and_definitions() -> None: + baseline = _load_baseline() + for dist, surface in baseline.items(): + assert "exports" in surface, f"{dist} baseline is missing the exports section" + assert "definitions" in surface, f"{dist} baseline is missing the definitions section" + assert surface["definitions"], f"{dist} baseline has no public definitions" + + +def test_core_init_packages_declare_all(snapshot_tool: ModuleType) -> None: + # Every re-exporting subpackage of core must declare a static ``__all__``; + # the snapshot tool only records packages that do, so a populated exports + # map for core is the signal that the convention is being followed. 
+ live = snapshot_tool.build_surface(_REPO_ROOT) + core_exports = live["dexpace-sdk-core"]["exports"] + assert "dexpace.sdk.core.http.request" in core_exports + assert "RequestBody" in core_exports["dexpace.sdk.core.http.request"] diff --git a/packages/dexpace-sdk-core/tests/test_serialization_snapshots.py b/packages/dexpace-sdk-core/tests/test_serialization_snapshots.py new file mode 100644 index 0000000..64c68d5 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/test_serialization_snapshots.py @@ -0,0 +1,129 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Serialization snapshot tests (Q3). + +Wire-format output drifts silently — a change in header canonicalization, URL +serialization, or body chunk boundaries can slip through review unnoticed. +These tests pin the exact bytes / strings a known request serializes to. Each +snapshot is deliberately narrow and paired with a behavioral assertion so a +failure points at *which* contract moved, not merely that some string changed. +""" + +from __future__ import annotations + +from dexpace.sdk.core.http.common import Headers, QueryParams, Url +from dexpace.sdk.core.http.request import Method, Request, RequestBody + +# ----- Headers canonicalization ------------------------------------------- + + +def test_headers_canonicalise_names_to_lower_case() -> None: + headers = Headers( + [ + ("Content-Type", "application/json"), + ("X-Trace-Id", "abc123"), + ] + ) + # Names are stored lower-cased; values are preserved verbatim. The exact + # ``items()`` tuple is the wire-canonical snapshot. + assert headers.items() == ( + ("content-type", ("application/json",)), + ("x-trace-id", ("abc123",)), + ) + # Behavioral pairing: the canonicalized form is still looked up by any case. 
+ assert headers.get("CONTENT-TYPE") == "application/json" + + +def test_headers_preserve_multi_value_order() -> None: + headers = Headers([("Set-Cookie", ["a=1", "b=2"])]) + assert headers.items() == (("set-cookie", ("a=1", "b=2")),) + # Behavioral pairing: ``get`` returns the first value, ``values`` all of them. + assert headers.get("set-cookie") == "a=1" + assert headers.values("Set-Cookie") == ("a=1", "b=2") + + +def test_headers_repr_redacts_sensitive_values() -> None: + headers = Headers([("Authorization", "Bearer secret-token"), ("Accept", "*/*")]) + # The repr snapshot must never leak the credential. + assert repr(headers) == "Headers({'authorization': ['[REDACTED]'], 'accept': ['*/*']})" + + +# ----- URL serialization -------------------------------------------------- + + +def test_url_round_trips_to_wire_form() -> None: + raw = "https://api.example.com/v1/users?b=2&a=1&a=3#frag" + url = Url.parse(raw) + assert str(url) == raw + # Behavioral pairing: query order and multiplicity survive the parse. + assert url.query.flatten() == (("b", "2"), ("a", "1"), ("a", "3")) + + +def test_url_serialization_encodes_space_and_ampersand_in_query() -> None: + url = Url( + scheme="https", + host="api.example.com", + path="/search", + query=QueryParams({"q": "a b", "tag": ["x", "y"]}), + ) + assert str(url) == "https://api.example.com/search?q=a+b&tag=x&tag=y" + # ``QueryParams.encode`` uses percent-encoding for the space (RFC 3986), + # whereas the full-URL serializer renders the form-style ``+``. + assert url.query.encode() == "q=a%20b&tag=x&tag=y" + + +def test_url_serialization_preserves_explicit_port() -> None: + url = Url.parse("https://api.example.com:8443/p?x=1") + assert str(url) == "https://api.example.com:8443/p?x=1" + assert url.port == 8443 + + +def test_url_str_redacts_userinfo_but_wire_form_keeps_it() -> None: + url = Url.parse("https://user:pw@host.example/secret") + # ``str`` drops credentials to avoid leaking them through logs ... 
+ assert str(url) == "https://host.example/secret" + # ... while ``wire_form`` keeps them for an actual request line. + assert url.wire_form() == "https://user:pw@host.example/secret" + + +# ----- body chunking ------------------------------------------------------ + + +def test_bytes_body_chunks_on_exact_boundaries() -> None: + body = RequestBody.from_bytes(b"0123456789abcdef") + assert list(body.iter_bytes(4)) == [b"0123", b"4567", b"89ab", b"cdef"] + # A non-divisor chunk size yields a short final chunk. + assert list(body.iter_bytes(5)) == [b"01234", b"56789", b"abcde", b"f"] + # Behavioral pairing: chunking is non-destructive — content_length and the + # joined bytes are stable across reads (the body is replayable). + assert body.content_length() == 16 + assert b"".join(body.iter_bytes(4)) == b"0123456789abcdef" + + +def test_string_body_encodes_utf8_and_reports_byte_length() -> None: + body = RequestBody.from_string("héllo") + assert b"".join(body.iter_bytes()) == b"h\xc3\xa9llo" + # ``content_length`` is the byte count, not the character count. 
+ assert body.content_length() == 6 + + +def test_form_body_url_encodes_fields_and_sets_media_type() -> None: + body = RequestBody.from_form({"name": "a b", "q": "x&y"}) + assert b"".join(body.iter_bytes()) == b"name=a%20b&q=x%26y" + media_type = body.media_type() + assert media_type is not None + assert str(media_type) == "application/x-www-form-urlencoded" + + +def test_request_carries_its_serialized_components() -> None: + request = Request( + method=Method.POST, + url=Url.parse("https://api.example.com/orders?dry_run=true"), + headers=Headers({"Content-Type": "application/json"}), + body=RequestBody.from_string('{"id":1}'), + ) + assert str(request.url) == "https://api.example.com/orders?dry_run=true" + assert request.headers.get("content-type") == "application/json" + assert request.body is not None + assert b"".join(request.body.iter_bytes()) == b'{"id":1}' diff --git a/packages/dexpace-sdk-core/tests/webhooks/__init__.py b/packages/dexpace-sdk-core/tests/webhooks/__init__.py new file mode 100644 index 0000000..a69f5b7 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/webhooks/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. diff --git a/packages/dexpace-sdk-core/tests/webhooks/test_verification.py b/packages/dexpace-sdk-core/tests/webhooks/test_verification.py new file mode 100644 index 0000000..207d3c0 --- /dev/null +++ b/packages/dexpace-sdk-core/tests/webhooks/test_verification.py @@ -0,0 +1,247 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. 
+ +"""Tests for Standard Webhooks signature verification.""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import json +from collections.abc import Mapping + +import pytest + +from dexpace.sdk.core.http.webhooks import ( + DEFAULT_TOLERANCE_SECONDS, + InvalidWebhookSignatureError, + WebhookVerifier, +) + +# A known whsec_ secret and the raw key it encodes. The secret is the raw key +# base64-encoded behind the whsec_ prefix, so vectors are reproducible by hand. +_RAW_KEY = b"supersecret-hmac-key-bytes-0123" +_SECRET = "whsec_" + base64.b64encode(_RAW_KEY).decode("ascii") + +_WEBHOOK_ID = "msg_2KWPBgLlAfxdpx2AI54pPJ85f4W" +_TIMESTAMP = "1690000000" +_BODY = b'{"event":"payment.succeeded","amount":4200}' + + +class _FixedClock: + """Minimal stationary clock pinned to a single wall-clock instant.""" + + __slots__ = ("_t",) + + def __init__(self, t: float) -> None: + self._t = t + + def now(self) -> float: + return self._t + + def monotonic(self) -> float: + return self._t + + def sleep(self, duration: float) -> None: # pragma: no cover - unused + raise AssertionError("verification must not sleep") + + +def _sign(key: bytes, webhook_id: str, timestamp: str, body: bytes) -> str: + content = f"{webhook_id}.{timestamp}.".encode() + body + digest = hmac.new(key, content, hashlib.sha256).digest() + return base64.b64encode(digest).decode("ascii") + + +def _headers(signature: str, *, timestamp: str = _TIMESTAMP) -> dict[str, str]: + return { + "webhook-id": _WEBHOOK_ID, + "webhook-timestamp": timestamp, + "webhook-signature": f"v1,{signature}", + } + + +def _verifier( + *, + t: float = 1690000000.0, + tolerance: int = DEFAULT_TOLERANCE_SECONDS, +) -> WebhookVerifier: + return WebhookVerifier(_SECRET, tolerance_seconds=tolerance, clock=_FixedClock(t)) + + +def test_verify_accepts_a_valid_signature() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + _verifier().verify(_headers(signature), _BODY) + + +def 
test_verify_accepts_a_string_body() -> None: + body_str = _BODY.decode() + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + _verifier().verify(_headers(signature), body_str) + + +def test_verify_is_case_insensitive_for_header_names() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + headers = { + "Webhook-Id": _WEBHOOK_ID, + "Webhook-Timestamp": _TIMESTAMP, + "Webhook-Signature": f"v1,{signature}", + } + _verifier().verify(headers, _BODY) + + +def test_verify_rejects_a_tampered_body() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + tampered = _BODY.replace(b"4200", b"9999") + with pytest.raises(InvalidWebhookSignatureError, match="no matching signature"): + _verifier().verify(_headers(signature), tampered) + + +def test_verify_rejects_a_tampered_webhook_id() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + headers = _headers(signature) + headers["webhook-id"] = "msg_attacker_swapped_this" + with pytest.raises(InvalidWebhookSignatureError, match="no matching signature"): + _verifier().verify(headers, _BODY) + + +def test_verify_rejects_signature_made_with_a_different_secret() -> None: + signature = _sign(b"a-completely-different-key", _WEBHOOK_ID, _TIMESTAMP, _BODY) + with pytest.raises(InvalidWebhookSignatureError, match="no matching signature"): + _verifier().verify(_headers(signature), _BODY) + + +def test_verify_accepts_when_one_of_several_signatures_matches() -> None: + good = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + bad = _sign(b"rotated-out-old-key", _WEBHOOK_ID, _TIMESTAMP, _BODY) + headers = { + "webhook-id": _WEBHOOK_ID, + "webhook-timestamp": _TIMESTAMP, + "webhook-signature": f"v1,{bad} v1,{good}", + } + _verifier().verify(headers, _BODY) + + +def test_verify_rejects_when_no_signature_in_the_set_matches() -> None: + bad1 = _sign(b"old-key-1", _WEBHOOK_ID, _TIMESTAMP, _BODY) + bad2 = _sign(b"old-key-2", _WEBHOOK_ID, _TIMESTAMP, _BODY) + headers = { + 
"webhook-id": _WEBHOOK_ID, + "webhook-timestamp": _TIMESTAMP, + "webhook-signature": f"v1,{bad1} v1,{bad2}", + } + with pytest.raises(InvalidWebhookSignatureError, match="no matching signature"): + _verifier().verify(headers, _BODY) + + +def test_verify_skips_unknown_version_tokens_and_still_matches_v1() -> None: + good = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + headers = { + "webhook-id": _WEBHOOK_ID, + "webhook-timestamp": _TIMESTAMP, + "webhook-signature": f"v2,futurescheme v1,{good}", + } + _verifier().verify(headers, _BODY) + + +@pytest.mark.parametrize( + "header", + ["webhook-id", "webhook-timestamp", "webhook-signature"], + ids=["missing_id", "missing_timestamp", "missing_signature"], +) +def test_verify_rejects_when_a_required_header_is_missing(header: str) -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + headers = _headers(signature) + del headers[header] + with pytest.raises(InvalidWebhookSignatureError, match=f"missing required header: {header}"): + _verifier().verify(headers, _BODY) + + +def test_verify_rejects_a_timestamp_older_than_tolerance() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + # Now is just past the tolerance window after the signed instant. 
+ now = int(_TIMESTAMP) + DEFAULT_TOLERANCE_SECONDS + 1 + with pytest.raises(InvalidWebhookSignatureError, match="too old"): + _verifier(t=now).verify(_headers(signature), _BODY) + + +def test_verify_rejects_a_timestamp_too_far_in_the_future() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + now = int(_TIMESTAMP) - DEFAULT_TOLERANCE_SECONDS - 1 + with pytest.raises(InvalidWebhookSignatureError, match="in the future"): + _verifier(t=now).verify(_headers(signature), _BODY) + + +def test_verify_accepts_a_timestamp_at_the_edge_of_the_window() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + now = int(_TIMESTAMP) + DEFAULT_TOLERANCE_SECONDS + _verifier(t=now).verify(_headers(signature), _BODY) + + +def test_verify_rejects_a_non_integer_timestamp() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + headers = _headers(signature, timestamp="not-a-number") + with pytest.raises(InvalidWebhookSignatureError, match="malformed webhook-timestamp"): + _verifier().verify(headers, _BODY) + + +def test_whsec_prefix_is_stripped_and_base64_decoded() -> None: + # A verifier built from the whsec_-prefixed secret produces the same result + # as signing with the raw decoded key — i.e. the prefix was stripped and + # the body base64-decoded to recover exactly _RAW_KEY. 
+ signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + _verifier().verify(_headers(signature), _BODY) + + +def test_secret_without_whsec_prefix_is_accepted() -> None: + raw_b64 = base64.b64encode(_RAW_KEY).decode("ascii") + verifier = WebhookVerifier(raw_b64, clock=_FixedClock(1690000000.0)) + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + verifier.verify(_headers(signature), _BODY) + + +def test_malformed_secret_raises_at_construction() -> None: + with pytest.raises(InvalidWebhookSignatureError, match="malformed webhook secret"): + WebhookVerifier("whsec_not!valid!base64!") + + +def test_negative_tolerance_is_rejected() -> None: + with pytest.raises(ValueError, match="non-negative"): + WebhookVerifier(_SECRET, tolerance_seconds=-1) + + +def test_unwrap_returns_parsed_json_on_success() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + payload = _verifier().unwrap(_headers(signature), _BODY) + assert payload == {"event": "payment.succeeded", "amount": 4200} + + +def test_unwrap_parses_the_exact_verified_bytes() -> None: + body = json.dumps({"nested": {"k": [1, 2, 3]}, "flag": True}).encode() + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, body) + payload = _verifier().unwrap(_headers(signature), body) + assert payload == {"nested": {"k": [1, 2, 3]}, "flag": True} + + +def test_unwrap_does_not_parse_an_unverified_body() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + tampered = b'{"event":"tampered"}' + with pytest.raises(InvalidWebhookSignatureError, match="no matching signature"): + _verifier().unwrap(_headers(signature), tampered) + + +def test_unwrap_rejects_a_verified_but_non_json_body() -> None: + body = b"this is signed but is not json" + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, body) + with pytest.raises(InvalidWebhookSignatureError, match="not valid JSON"): + _verifier().unwrap(_headers(signature), body) + + +def test_invalid_signature_error_is_a_value_error() -> None: 
+ assert issubclass(InvalidWebhookSignatureError, ValueError) + + +def test_verifier_accepts_a_generic_mapping() -> None: + signature = _sign(_RAW_KEY, _WEBHOOK_ID, _TIMESTAMP, _BODY) + headers: Mapping[str, str] = _headers(signature) + _verifier().verify(headers, _BODY) diff --git a/pyproject.toml b/pyproject.toml index 2d69148..9c1082a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,10 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "**/tests/**" = ["B011", "RUF012"] +# HttpTracer is an adapter-style base: every event method ships a no-op default +# so subclasses override only the events they consume. B024/B027 flag exactly +# that intentional pattern. +"**/instrumentation/http_tracer.py" = ["B024", "B027"] [tool.mypy] python_version = "3.12" diff --git a/tools/surface_baseline.json b/tools/surface_baseline.json new file mode 100644 index 0000000..68f5564 --- /dev/null +++ b/tools/surface_baseline.json @@ -0,0 +1,1727 @@ +{ + "dexpace-sdk-core": { + "definitions": { + "dexpace.sdk.core.client.async_http_client": { + "AsyncHttpClient": { + "bases": [ + "Protocol" + ], + "methods": { + "execute": "async execute(self, request: Request) -> AsyncResponse" + } + }, + "asyncio_sleep": "async asyncio_sleep(duration: float) -> None" + }, + "dexpace.sdk.core.client.http_client": { + "HttpClient": { + "bases": [ + "Protocol" + ], + "methods": { + "execute": "execute(self, request: Request) -> Response" + } + } + }, + "dexpace.sdk.core.config.configuration": { + "Configuration": { + "bases": [], + "methods": { + "builder": "builder(cls) -> ConfigurationBuilder", + "get": "get(self, name: str, default: str | None) -> str | None", + "get_bool": "get_bool(self, name: str, default: bool) -> bool", + "get_duration": "get_duration(self, name: str, default_seconds: float) -> float", + "get_int": "get_int(self, name: str, default: int) -> int" + } + }, + "ConfigurationBuilder": { + "bases": [], + "methods": { + "build": "build(self) -> Configuration", + "env": 
"env(self, source: EnvSource) -> Self", + "put": "put(self, name: str, value: str) -> Self" + } + } + }, + "dexpace.sdk.core.errors.base": { + "SdkError": { + "bases": [ + "Exception" + ], + "methods": {} + }, + "ServiceRequestError": { + "bases": [ + "SdkError" + ], + "methods": {} + }, + "ServiceRequestTimeoutError": { + "bases": [ + "ServiceRequestError" + ], + "methods": {} + }, + "ServiceResponseError": { + "bases": [ + "SdkError" + ], + "methods": {} + }, + "ServiceResponseTimeoutError": { + "bases": [ + "ServiceResponseError" + ], + "methods": {} + } + }, + "dexpace.sdk.core.errors.error_map": { + "map_error": "map_error(status_code: int, response: Response, error_map: Mapping[int, type[HttpResponseError]] | None) -> None" + }, + "dexpace.sdk.core.errors.http": { + "ClientAuthenticationError": { + "bases": [ + "HttpResponseError[ModelT]" + ], + "methods": {} + }, + "DecodeError": { + "bases": [ + "HttpResponseError[ModelT]" + ], + "methods": {} + }, + "HttpResponseError": { + "bases": [ + "Generic[ModelT]", + "SdkError" + ], + "methods": { + "body_snapshot": "body_snapshot(self, max_bytes: int | None) -> bytes" + } + }, + "ResourceExistsError": { + "bases": [ + "HttpResponseError[ModelT]" + ], + "methods": {} + }, + "ResourceModifiedError": { + "bases": [ + "HttpResponseError[ModelT]" + ], + "methods": {} + }, + "ResourceNotFoundError": { + "bases": [ + "HttpResponseError[ModelT]" + ], + "methods": {} + }, + "ResourceNotModifiedError": { + "bases": [ + "HttpResponseError[ModelT]" + ], + "methods": {} + } + }, + "dexpace.sdk.core.errors.pipeline": { + "PipelineAbortedError": { + "bases": [ + "SdkError" + ], + "methods": {} + } + }, + "dexpace.sdk.core.errors.serialization": { + "DeserializationError": { + "bases": [ + "SdkError", + "ValueError" + ], + "methods": {} + }, + "SerializationError": { + "bases": [ + "SdkError", + "ValueError" + ], + "methods": {} + } + }, + "dexpace.sdk.core.errors.streaming": { + "ResponseNotReadError": { + "bases": [ + "SdkError" 
+ ], + "methods": {} + }, + "StreamClosedError": { + "bases": [ + "SdkError" + ], + "methods": {} + }, + "StreamConsumedError": { + "bases": [ + "SdkError" + ], + "methods": {} + }, + "StreamingError": { + "bases": [ + "SdkError" + ], + "methods": {} + } + }, + "dexpace.sdk.core.http.auth.access_token": { + "AccessTokenInfo": { + "bases": [], + "methods": { + "is_expired": "is_expired(self, *, now: float | None) -> bool", + "needs_refresh": "needs_refresh(self, *, now: float | None, leeway_seconds: int, clock: Clock | AsyncClock | None) -> bool" + } + }, + "TokenRequestOptions": { + "bases": [ + "TypedDict" + ], + "methods": {} + } + }, + "dexpace.sdk.core.http.auth.challenge": { + "AuthenticateChallenge": { + "bases": [], + "methods": {} + }, + "parse_challenges": "parse_challenges(header_value: str) -> list[AuthenticateChallenge]" + }, + "dexpace.sdk.core.http.auth.challenge_handler": { + "BasicChallengeHandler": { + "bases": [], + "methods": { + "can_handle": "can_handle(self, challenges: list[AuthenticateChallenge]) -> bool", + "handle": "handle(self, method: Method, url: Url, challenges: list[AuthenticateChallenge], *, is_proxy: bool) -> tuple[str, str] | None" + } + }, + "ChallengeHandler": { + "bases": [ + "Protocol" + ], + "methods": { + "can_handle": "can_handle(self, challenges: list[AuthenticateChallenge]) -> bool", + "handle": "handle(self, method: Method, url: Url, challenges: list[AuthenticateChallenge], *, is_proxy: bool) -> tuple[str, str] | None" + } + }, + "CompositeChallengeHandler": { + "bases": [], + "methods": { + "can_handle": "can_handle(self, challenges: list[AuthenticateChallenge]) -> bool", + "handle": "handle(self, method: Method, url: Url, challenges: list[AuthenticateChallenge], *, is_proxy: bool) -> tuple[str, str] | None" + } + } + }, + "dexpace.sdk.core.http.auth.credentials": { + "AsyncTokenCredential": { + "bases": [ + "Protocol" + ], + "methods": { + "close": "async close(self) -> None", + "get_token_info": "async 
get_token_info(self, *scopes: str, options: TokenRequestOptions | None) -> AccessTokenInfo" + } + }, + "BasicAuthCredential": { + "bases": [], + "methods": { + "encoded": "encoded(self) -> str" + } + }, + "KeyCredential": { + "bases": [], + "methods": { + "key": "key(self) -> str", + "update": "update(self, key: str) -> None" + } + }, + "NamedKeyCredential": { + "bases": [], + "methods": { + "key": "key(self) -> str", + "name": "name(self) -> str", + "update": "update(self, name: str, key: str) -> None" + } + }, + "TokenCredential": { + "bases": [ + "Protocol" + ], + "methods": { + "close": "close(self) -> None", + "get_token_info": "get_token_info(self, *scopes: str, options: TokenRequestOptions | None) -> AccessTokenInfo" + } + } + }, + "dexpace.sdk.core.http.auth.digest": { + "DigestChallengeHandler": { + "bases": [], + "methods": { + "can_handle": "can_handle(self, challenges: list[AuthenticateChallenge]) -> bool", + "handle": "handle(self, method: Method, url: Url, challenges: list[AuthenticateChallenge], *, is_proxy: bool) -> tuple[str, str] | None" + } + } + }, + "dexpace.sdk.core.http.auth.policies": { + "AsyncBearerTokenPolicy": { + "bases": [ + "AsyncPolicy" + ], + "methods": { + "on_challenge": "async on_challenge(self, request: Request, response: AsyncResponse) -> bool", + "send": "async send(self, request: Request, ctx: PipelineContext) -> AsyncResponse" + } + }, + "BasicAuthPolicy": { + "bases": [ + "Policy" + ], + "methods": { + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + }, + "BearerTokenPolicy": { + "bases": [ + "Policy" + ], + "methods": { + "on_challenge": "on_challenge(self, request: Request, response: Response) -> bool", + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + }, + "KeyCredentialPolicy": { + "bases": [ + "Policy" + ], + "methods": { + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + } + }, + "dexpace.sdk.core.http.auth.token_cache": { + 
"InMemoryTokenCache": { + "bases": [], + "methods": { + "clear": "clear(self) -> None", + "get": "get(self, scopes: Sequence[str], audience: str | None) -> AccessTokenInfo | None", + "set": "set(self, scopes: Sequence[str], token: AccessTokenInfo, audience: str | None) -> None" + } + }, + "TokenCache": { + "bases": [ + "Protocol" + ], + "methods": { + "clear": "clear(self) -> None", + "get": "get(self, scopes: Sequence[str], audience: str | None) -> AccessTokenInfo | None", + "set": "set(self, scopes: Sequence[str], token: AccessTokenInfo, audience: str | None) -> None" + } + } + }, + "dexpace.sdk.core.http.common.etag": { + "ETag": { + "bases": [], + "methods": { + "matches_strong": "matches_strong(self, other: ETag) -> bool", + "matches_weak": "matches_weak(self, other: ETag) -> bool", + "parse": "parse(cls, raw: str) -> Self" + } + } + }, + "dexpace.sdk.core.http.common.headers": { + "Headers": { + "bases": [], + "methods": { + "empty": "empty(cls) -> Headers", + "get": "get(self, name: _Name, default: str | None) -> str | None", + "items": "items(self) -> tuple[tuple[str, tuple[str, ...]], ...]", + "names": "names(self) -> tuple[str, ...]", + "values": "values(self, name: _Name) -> tuple[str, ...]", + "with_added": "with_added(self, name: _Name, value: str) -> Self", + "with_merged": "with_merged(self, other: Headers) -> Self", + "with_set": "with_set(self, name: _Name, *values: str) -> Self", + "without": "without(self, name: _Name) -> Self" + } + } + }, + "dexpace.sdk.core.http.common.http_header_name": { + "HttpHeaderName": { + "bases": [], + "methods": { + "of": "of(cls, canonical_name: str) -> Self" + } + } + }, + "dexpace.sdk.core.http.common.http_range": { + "HttpRange": { + "bases": [], + "methods": { + "end": "end(self) -> int | None", + "format": "format(self) -> str", + "format_many": "format_many(cls, ranges: Sequence[HttpRange]) -> str", + "suffix": "suffix(cls, count: int) -> _SuffixHttpRange", + "to_header_value": "to_header_value(self) -> str" + 
} + } + }, + "dexpace.sdk.core.http.common.media_type": { + "MediaType": { + "bases": [], + "methods": { + "charset": "charset(self) -> str | None", + "full_type": "full_type(self) -> str", + "includes": "includes(self, other: MediaType) -> bool", + "of": "of(cls, type: str, subtype: str, parameters: Mapping[str, str] | None) -> Self", + "parse": "parse(cls, value: str) -> Self" + } + } + }, + "dexpace.sdk.core.http.common.pagination": { + "AsyncItemPaged": { + "bases": [ + "AsyncIterator[T]" + ], + "methods": { + "by_page": "by_page(self, continuation_token: str | None) -> AsyncIterator[AsyncIterator[T]]" + } + }, + "AsyncPager": { + "bases": [ + "AsyncIterator[AsyncIterator[T]]" + ], + "methods": {} + }, + "ItemPaged": { + "bases": [ + "Iterator[T]" + ], + "methods": { + "by_page": "by_page(self, continuation_token: str | None) -> Iterator[Iterator[T]]" + } + }, + "Pager": { + "bases": [ + "Iterator[Iterator[T]]" + ], + "methods": {} + } + }, + "dexpace.sdk.core.http.common.protocol": { + "Protocol": { + "bases": [ + "StrEnum" + ], + "methods": { + "parse": "parse(cls, value: str) -> Self" + } + } + }, + "dexpace.sdk.core.http.common.request_conditions": { + "RequestConditions": { + "bases": [], + "methods": { + "apply_to": "apply_to(self, request: Request) -> Request" + } + } + }, + "dexpace.sdk.core.http.common.streaming": { + "aiter_chunked_frame": "async aiter_chunked_frame(chunks: AsyncIterable[bytes]) -> AsyncIterator[bytes]", + "aiter_jsonl": "async aiter_jsonl(chunks: AsyncIterable[bytes]) -> AsyncIterator[Any]", + "chunked_frame": "chunked_frame(chunks: Iterable[bytes]) -> Iterator[bytes]", + "iter_jsonl": "iter_jsonl(chunks: Iterable[bytes]) -> Iterator[Any]" + }, + "dexpace.sdk.core.http.common.url": { + "QueryParams": { + "bases": [], + "methods": { + "empty": "empty(cls) -> QueryParams", + "encode": "encode(self) -> str", + "flatten": "flatten(self) -> tuple[tuple[str, str], ...]", + "get": "get(self, name: str, default: str | None) -> str | None", + 
"items": "items(self) -> tuple[tuple[str, tuple[str, ...]], ...]", + "parse": "parse(cls, raw: str) -> Self", + "values": "values(self, name: str) -> tuple[str, ...]", + "with_added": "with_added(self, name: str, value: str) -> Self", + "with_set": "with_set(self, name: str, *values: str) -> Self", + "without": "without(self, name: str) -> Self" + } + }, + "Url": { + "bases": [], + "methods": { + "authority": "authority(self, *, with_userinfo: bool) -> str", + "parse": "parse(cls, raw: str) -> Self", + "wire_form": "wire_form(self) -> str", + "with_fragment": "with_fragment(self, fragment: str) -> Self", + "with_path": "with_path(self, path: str) -> Self", + "with_query": "with_query(self, query: QueryParams) -> Self" + } + } + }, + "dexpace.sdk.core.http.context.call_context": { + "CallContext": { + "bases": [], + "methods": { + "close": "close(self) -> None" + } + } + }, + "dexpace.sdk.core.http.context.dispatch_context": { + "DispatchContext": { + "bases": [ + "CallContext" + ], + "methods": { + "noop": "noop(cls) -> Self", + "to_request_context": "to_request_context(self, request: Request) -> RequestContext" + } + } + }, + "dexpace.sdk.core.http.context.exchange_context": { + "ExchangeContext": { + "bases": [ + "CallContext" + ], + "methods": {} + } + }, + "dexpace.sdk.core.http.context.request_context": { + "RequestContext": { + "bases": [ + "CallContext" + ], + "methods": { + "to_exchange_context": "to_exchange_context(self, response: Response | AsyncResponse) -> ExchangeContext" + } + } + }, + "dexpace.sdk.core.http.request.async_request_body": { + "AsyncRequestBody": { + "bases": [ + "ABC" + ], + "methods": { + "aiter_bytes": "aiter_bytes(self, chunk_size: int) -> AsyncIterator[bytes]", + "content_length": "content_length(self) -> int", + "from_async_iter": "from_async_iter(cls, chunks: AsyncIterable[bytes], media_type: MediaType | None, content_length: int) -> AsyncRequestBody", + "from_async_stream": "from_async_stream(cls, stream: SupportsAsyncRead, 
media_type: MediaType | None, content_length: int) -> AsyncRequestBody", + "from_bytes": "from_bytes(cls, data: bytes, media_type: MediaType | None) -> AsyncRequestBody", + "from_form": "from_form(cls, fields: Mapping[str, str], encoding: str) -> AsyncRequestBody", + "from_string": "from_string(cls, content: str, media_type: MediaType | None, encoding: str) -> AsyncRequestBody", + "is_replayable": "is_replayable(self) -> bool", + "media_type": "media_type(self) -> MediaType | None", + "to_replayable": "async to_replayable(self) -> AsyncRequestBody", + "write_to": "async write_to(self, stream: SupportsAsyncWrite, chunk_size: int) -> int" + } + }, + "SupportsAsyncRead": { + "bases": [ + "Protocol" + ], + "methods": { + "close": "async close(self) -> object", + "read": "async read(self, size: int) -> bytes" + } + }, + "SupportsAsyncWrite": { + "bases": [ + "Protocol" + ], + "methods": { + "write": "async write(self, data: bytes) -> object" + } + } + }, + "dexpace.sdk.core.http.request.file_request_body": { + "FileRequestBody": { + "bases": [ + "RequestBody" + ], + "methods": { + "content_length": "content_length(self) -> int", + "count": "count(self) -> int", + "is_replayable": "is_replayable(self) -> bool", + "iter_bytes": "iter_bytes(self, chunk_size: int) -> Iterator[bytes]", + "media_type": "media_type(self) -> MediaType | None", + "offset": "offset(self) -> int", + "path": "path(self) -> Path", + "to_replayable": "to_replayable(self) -> RequestBody" + } + } + }, + "dexpace.sdk.core.http.request.loggable_request_body": { + "LoggableRequestBody": { + "bases": [ + "RequestBody" + ], + "methods": { + "captured_size": "captured_size(self) -> int", + "content_length": "content_length(self) -> int", + "inner": "inner(self) -> RequestBody", + "is_replayable": "is_replayable(self) -> bool", + "iter_bytes": "iter_bytes(self, chunk_size: int) -> Iterator[bytes]", + "media_type": "media_type(self) -> MediaType | None", + "snapshot": "snapshot(self, max_bytes: int | None) -> 
bytes", + "to_replayable": "to_replayable(self) -> RequestBody" + } + } + }, + "dexpace.sdk.core.http.request.method": { + "Method": { + "bases": [ + "StrEnum" + ], + "methods": {} + } + }, + "dexpace.sdk.core.http.request.multipart": { + "MultipartField": { + "bases": [], + "methods": { + "with_utf8_filename": "with_utf8_filename(cls, *, name: str, value: bytes | str, filename: str, media_type: MediaType | None, headers: Sequence[tuple[str, str]], ascii_fallback: str) -> Self" + } + }, + "MultipartRequestBody": { + "bases": [ + "RequestBody" + ], + "methods": { + "boundary": "boundary(self) -> str", + "content_length": "content_length(self) -> int", + "is_replayable": "is_replayable(self) -> bool", + "iter_bytes": "iter_bytes(self, chunk_size: int) -> Iterator[bytes]", + "media_type": "media_type(self) -> MediaType | None", + "to_replayable": "to_replayable(self) -> RequestBody" + } + } + }, + "dexpace.sdk.core.http.request.request": { + "Request": { + "bases": [], + "methods": { + "with_added_header": "with_added_header(self, name: _Name, value: str) -> Self", + "with_body": "with_body(self, body: RequestBody | None) -> Self", + "with_header": "with_header(self, name: _Name, value: str) -> Self", + "with_headers": "with_headers(self, headers: Headers) -> Self", + "with_method": "with_method(self, method: Method) -> Self", + "with_url": "with_url(self, url: str | Url) -> Self", + "without_header": "without_header(self, name: _Name) -> Self" + } + } + }, + "dexpace.sdk.core.http.request.request_body": { + "RequestBody": { + "bases": [ + "ABC" + ], + "methods": { + "content_length": "content_length(self) -> int", + "from_bytes": "from_bytes(cls, data: bytes, media_type: MediaType | None) -> RequestBody", + "from_file": "from_file(cls, path: Path, media_type: MediaType | None, offset: int, count: int) -> RequestBody", + "from_form": "from_form(cls, fields: Mapping[str, str], encoding: str) -> RequestBody", + "from_iter": "from_iter(cls, chunks: Iterable[bytes], 
media_type: MediaType | None, content_length: int) -> RequestBody", + "from_multipart": "from_multipart(cls, fields: Sequence[MultipartField], *, boundary: str | None) -> RequestBody", + "from_stream": "from_stream(cls, stream: BinaryIO, media_type: MediaType | None, content_length: int) -> RequestBody", + "from_string": "from_string(cls, content: str, media_type: MediaType | None, encoding: str) -> RequestBody", + "is_replayable": "is_replayable(self) -> bool", + "iter_bytes": "iter_bytes(self, chunk_size: int) -> Iterator[bytes]", + "media_type": "media_type(self) -> MediaType | None", + "to_replayable": "to_replayable(self) -> RequestBody", + "write_to": "write_to(self, stream: BinaryIO, chunk_size: int) -> int" + } + } + }, + "dexpace.sdk.core.http.response.async_response": { + "AsyncResponse": { + "bases": [], + "methods": { + "close": "async close(self) -> None", + "is_client_error": "is_client_error(self) -> bool", + "is_redirect": "is_redirect(self) -> bool", + "is_server_error": "is_server_error(self) -> bool", + "is_success": "is_success(self) -> bool", + "with_body": "with_body(self, body: AsyncResponseBody | None) -> Self", + "with_header": "with_header(self, name: _Name, value: str) -> Self", + "with_headers": "with_headers(self, headers: Headers) -> Self", + "with_status": "with_status(self, status: Status) -> Self" + } + } + }, + "dexpace.sdk.core.http.response.async_response_body": { + "AsyncResponseBody": { + "bases": [ + "ABC" + ], + "methods": { + "aiter_bytes": "aiter_bytes(self, chunk_size: int) -> AsyncIterator[bytes]", + "bytes": "async bytes(self) -> _bytes", + "close": "async close(self) -> None", + "content_length": "content_length(self) -> int", + "from_async_stream": "from_async_stream(cls, stream: SupportsAsyncRead, media_type: MediaType | None, content_length: int) -> AsyncResponseBody", + "from_bytes": "from_bytes(cls, data: _bytes, media_type: MediaType | None) -> AsyncResponseBody", + "media_type": "media_type(self) -> MediaType | 
None", + "string": "async string(self, encoding: str | None) -> str" + } + } + }, + "dexpace.sdk.core.http.response.loggable_response_body": { + "LoggableResponseBody": { + "bases": [ + "ResponseBody" + ], + "methods": { + "captured_size": "captured_size(self) -> int", + "close": "close(self) -> None", + "content_length": "content_length(self) -> int", + "iter_bytes": "iter_bytes(self, chunk_size: int) -> Iterator[bytes]", + "media_type": "media_type(self) -> MediaType | None", + "snapshot": "snapshot(self, max_bytes: int | None) -> bytes" + } + } + }, + "dexpace.sdk.core.http.response.response": { + "Response": { + "bases": [], + "methods": { + "close": "close(self) -> None", + "is_client_error": "is_client_error(self) -> bool", + "is_redirect": "is_redirect(self) -> bool", + "is_server_error": "is_server_error(self) -> bool", + "is_success": "is_success(self) -> bool", + "with_body": "with_body(self, body: ResponseBody | None) -> Self", + "with_header": "with_header(self, name: _Name, value: str) -> Self", + "with_headers": "with_headers(self, headers: Headers) -> Self", + "with_status": "with_status(self, status: Status) -> Self" + } + } + }, + "dexpace.sdk.core.http.response.response_body": { + "ResponseBody": { + "bases": [ + "ABC" + ], + "methods": { + "bytes": "bytes(self) -> _bytes", + "close": "close(self) -> None", + "content_length": "content_length(self) -> int", + "from_bytes": "from_bytes(cls, data: _bytes, media_type: MediaType | None) -> ResponseBody", + "from_stream": "from_stream(cls, stream: BinaryIO, media_type: MediaType | None, content_length: int) -> ResponseBody", + "iter_bytes": "iter_bytes(self, chunk_size: int) -> Iterator[bytes]", + "media_type": "media_type(self) -> MediaType | None", + "string": "string(self, encoding: str | None) -> str" + } + } + }, + "dexpace.sdk.core.http.response.status": { + "Status": { + "bases": [ + "IntEnum" + ], + "methods": { + "is_client_error": "is_client_error(self) -> bool", + "is_error": "is_error(self) 
-> bool", + "is_informational": "is_informational(self) -> bool", + "is_redirect": "is_redirect(self) -> bool", + "is_server_error": "is_server_error(self) -> bool", + "is_success": "is_success(self) -> bool" + } + } + }, + "dexpace.sdk.core.http.sse.parser": { + "AsyncSseStream": { + "bases": [], + "methods": { + "aclose": "async aclose(self) -> None" + } + }, + "SseEvent": { + "bases": [], + "methods": {} + }, + "SseParser": { + "bases": [], + "methods": { + "drain": "drain(self) -> Iterator[SseEvent]", + "end": "end(self) -> Iterator[SseEvent]", + "feed": "feed(self, chunk: bytes) -> None" + } + }, + "parse_async_events": "parse_async_events(chunks: AsyncIterable[bytes]) -> AsyncSseStream", + "parse_events": "parse_events(chunks: Iterable[bytes]) -> Iterator[SseEvent]" + }, + "dexpace.sdk.core.http.webhooks.verification": { + "InvalidWebhookSignatureError": { + "bases": [ + "ValueError" + ], + "methods": {} + }, + "WebhookVerifier": { + "bases": [], + "methods": { + "unwrap": "unwrap(self, headers: Mapping[str, str], body: str | bytes) -> object", + "verify": "verify(self, headers: Mapping[str, str], body: str | bytes) -> None" + } + } + }, + "dexpace.sdk.core.instrumentation.client_logger": { + "ClientLogger": { + "bases": [], + "methods": { + "error": "error(self, message: str, **fields: Any) -> None", + "info": "info(self, message: str, **fields: Any) -> None", + "log": "log(self, level: LogLevel, message: str, **fields: Any) -> None", + "verbose": "verbose(self, message: str, **fields: Any) -> None", + "warning": "warning(self, message: str, **fields: Any) -> None" + } + }, + "CorrelationFilter": { + "bases": [ + "logging.Filter" + ], + "methods": { + "filter": "filter(self, record: logging.LogRecord) -> bool" + } + } + }, + "dexpace.sdk.core.instrumentation.correlation": { + "bind_correlation": "bind_correlation(*, trace_id: str | None, span_id: str | None) -> Iterator[None]", + "get_span_id": "get_span_id() -> str | None", + "get_trace_id": "get_trace_id() 
-> str | None", + "set_span_id": "set_span_id(value: str | None) -> Token[str | None]", + "set_trace_id": "set_trace_id(value: str | None) -> Token[str | None]" + }, + "dexpace.sdk.core.instrumentation.http_tracer": { + "HttpTracer": { + "bases": [ + "ABC" + ], + "methods": { + "attempt_failed": "attempt_failed(self, error: BaseException, next_delay: float) -> None", + "attempt_retries_exhausted": "attempt_retries_exhausted(self) -> None", + "attempt_started": "attempt_started(self, attempt: int) -> None", + "connection_acquired": "connection_acquired(self, host: str, port: int) -> None", + "operation_failed": "operation_failed(self, error: BaseException) -> None", + "operation_started": "operation_started(self) -> None", + "operation_succeeded": "operation_succeeded(self) -> None", + "request_sent": "request_sent(self, byte_count: int) -> None", + "request_url_resolved": "request_url_resolved(self, url: str) -> None", + "response_headers_received": "response_headers_received(self, status: int, headers: Mapping[str, str]) -> None", + "response_received": "response_received(self, byte_count: int) -> None" + } + }, + "HttpTracerFactory": { + "bases": [ + "Protocol" + ], + "methods": { + "create": "create(self) -> HttpTracer" + } + } + }, + "dexpace.sdk.core.instrumentation.identifiers": { + "SpanId": { + "bases": [], + "methods": {} + }, + "TraceFlags": { + "bases": [], + "methods": {} + }, + "TraceId": { + "bases": [], + "methods": {} + }, + "TraceIdType": { + "bases": [ + "StrEnum" + ], + "methods": {} + }, + "TraceState": { + "bases": [], + "methods": {} + } + }, + "dexpace.sdk.core.instrumentation.instrumentation_context": { + "InstrumentationContext": { + "bases": [], + "methods": { + "is_valid": "is_valid(self) -> bool" + } + } + }, + "dexpace.sdk.core.instrumentation.log_level": { + "LogLevel": { + "bases": [ + "Enum" + ], + "methods": {} + } + }, + "dexpace.sdk.core.instrumentation.metrics": { + "Counter": { + "bases": [ + "ABC" + ], + "methods": { + "add": 
"add(self, value: float, attributes: Mapping[str, str] | None) -> None" + } + }, + "Histogram": { + "bases": [ + "ABC" + ], + "methods": { + "record": "record(self, value: float, attributes: Mapping[str, str] | None) -> None" + } + }, + "MetricsContext": { + "bases": [ + "ABC" + ], + "methods": { + "counter": "counter(self, name: str, *, unit: str | None, description: str | None) -> Counter", + "histogram": "histogram(self, name: str, *, unit: str | None, description: str | None) -> Histogram", + "up_down_counter": "up_down_counter(self, name: str, *, unit: str | None, description: str | None) -> UpDownCounter" + } + }, + "UpDownCounter": { + "bases": [ + "ABC" + ], + "methods": { + "add": "add(self, value: float, attributes: Mapping[str, str] | None) -> None" + } + } + }, + "dexpace.sdk.core.instrumentation.span": { + "Span": { + "bases": [ + "ABC" + ], + "methods": { + "context": "context(self) -> InstrumentationContext", + "end": "end(self, error: BaseException | None) -> None", + "is_recording": "is_recording(self) -> bool", + "make_current": "make_current(self) -> TracingScope", + "set_attribute": "set_attribute(self, key: str, value: Any) -> Self", + "set_error": "set_error(self, error_type: str) -> Self" + } + } + }, + "dexpace.sdk.core.instrumentation.tracer": { + "Tracer": { + "bases": [ + "ABC" + ], + "methods": { + "start_span": "start_span(self, name: str, parent: InstrumentationContext | None) -> Span" + } + } + }, + "dexpace.sdk.core.instrumentation.tracing_scope": { + "TracingScope": { + "bases": [ + "ABC" + ], + "methods": { + "close": "close(self) -> None" + } + } + }, + "dexpace.sdk.core.instrumentation.url_redactor": { + "UrlRedactor": { + "bases": [], + "methods": { + "redact": "redact(self, url: str | Url) -> str" + } + } + }, + "dexpace.sdk.core.pagination.link_header": { + "find_rel": "find_rel(value: str, rel: str) -> str | None", + "parse_link_header": "parse_link_header(value: str) -> tuple[ParsedLink, ...]" + }, + 
"dexpace.sdk.core.pagination.page": { + "Page": { + "bases": [], + "methods": { + "aclose": "async aclose(self) -> None", + "close": "close(self) -> None", + "has_next": "has_next(self) -> bool" + } + } + }, + "dexpace.sdk.core.pagination.paginator": { + "AsyncPaginator": { + "bases": [], + "methods": { + "by_page": "async by_page(self) -> AsyncIterator[Page[T]]" + } + }, + "AsyncPipelineLike": { + "bases": [ + "Protocol" + ], + "methods": { + "run": "async run(self, request: Request, dispatch: DispatchContext) -> AsyncResponse" + } + }, + "Paginator": { + "bases": [], + "methods": { + "by_page": "by_page(self) -> Iterator[Page[T]]" + } + }, + "SyncPipelineLike": { + "bases": [ + "Protocol" + ], + "methods": { + "run": "run(self, request: Request, dispatch: DispatchContext) -> Response" + } + } + }, + "dexpace.sdk.core.pagination.strategy": { + "CursorStrategy": { + "bases": [], + "methods": { + "parse": "parse(self, response: HasHeaders, payload: object, template_request: Request) -> Page[T]" + } + }, + "HasHeaders": { + "bases": [ + "Protocol" + ], + "methods": { + "headers": "headers(self) -> Headers" + } + }, + "LinkHeaderStrategy": { + "bases": [], + "methods": { + "parse": "parse(self, response: HasHeaders, payload: object, template_request: Request) -> Page[T]" + } + }, + "PageNumberStrategy": { + "bases": [], + "methods": { + "parse": "parse(self, response: HasHeaders, payload: object, template_request: Request) -> Page[T]" + } + }, + "PaginationStrategy": { + "bases": [ + "Protocol" + ], + "methods": { + "parse": "parse(self, response: HasHeaders, payload: object, template_request: Request) -> Page[T]" + } + } + }, + "dexpace.sdk.core.pipeline.async_pipeline": { + "AsyncPipeline": { + "bases": [], + "methods": { + "run": "async run(self, request: Request, dispatch: DispatchContext, **options: Any) -> AsyncResponse" + } + } + }, + "dexpace.sdk.core.pipeline.async_policy": { + "AsyncPolicy": { + "bases": [ + "ABC" + ], + "methods": { + "send": "async 
send(self, request: Request, ctx: PipelineContext) -> AsyncResponse" + } + } + }, + "dexpace.sdk.core.pipeline.async_staged_builder": { + "AsyncStagedPipelineBuilder": { + "bases": [], + "methods": { + "append": "append(self, policy: AsyncPolicy, *, force: bool) -> Self", + "build": "build(self) -> AsyncPipeline", + "from_pipeline": "from_pipeline(cls, pipeline: AsyncPipeline) -> Self", + "insert_after": "insert_after(self, target: type[AsyncPolicy], new: AsyncPolicy) -> Self", + "insert_before": "insert_before(self, target: type[AsyncPolicy], new: AsyncPolicy) -> Self", + "prepend": "prepend(self, policy: AsyncPolicy, *, force: bool) -> Self", + "remove": "remove(self, target: type[AsyncPolicy]) -> Self", + "replace": "replace(self, target: type[AsyncPolicy], new: AsyncPolicy) -> Self" + } + } + }, + "dexpace.sdk.core.pipeline.context": { + "PipelineContext": { + "bases": [], + "methods": {} + } + }, + "dexpace.sdk.core.pipeline.defaults": { + "default_async_pipeline": "default_async_pipeline(client: AsyncHttpClient, *, redirect: AsyncRedirectPolicy | None, idempotency: AsyncIdempotencyPolicy | None, retry: AsyncRetryPolicy | None, set_date: AsyncSetDatePolicy | None, client_identity: AsyncClientIdentityPolicy | None, auth: AsyncPolicy | None) -> AsyncStagedPipelineBuilder", + "default_pipeline": "default_pipeline(client: HttpClient, *, redirect: RedirectPolicy | None, idempotency: IdempotencyPolicy | None, retry: RetryPolicy | None, set_date: SetDatePolicy | None, client_identity: ClientIdentityPolicy | None, auth: Policy | None, logging: LoggingPolicy | None, tracing: TracingPolicy | None) -> StagedPipelineBuilder" + }, + "dexpace.sdk.core.pipeline.pipeline": { + "Pipeline": { + "bases": [], + "methods": { + "run": "run(self, request: Request, dispatch: DispatchContext, **options: Any) -> Response" + } + } + }, + "dexpace.sdk.core.pipeline.policies._history": { + "RequestHistory": { + "bases": [], + "methods": {} + } + }, + 
"dexpace.sdk.core.pipeline.policies.async_client_identity": { + "AsyncClientIdentityPolicy": { + "bases": [ + "AsyncPolicy" + ], + "methods": { + "send": "async send(self, request: Request, ctx: PipelineContext) -> AsyncResponse" + } + } + }, + "dexpace.sdk.core.pipeline.policies.async_idempotency": { + "AsyncIdempotencyPolicy": { + "bases": [ + "AsyncPolicy" + ], + "methods": { + "send": "async send(self, request: Request, ctx: PipelineContext) -> AsyncResponse" + } + } + }, + "dexpace.sdk.core.pipeline.policies.async_redirect": { + "AsyncRedirectPolicy": { + "bases": [ + "AsyncPolicy" + ], + "methods": { + "send": "async send(self, request: Request, ctx: PipelineContext) -> AsyncResponse" + } + } + }, + "dexpace.sdk.core.pipeline.policies.async_retry": { + "AsyncRetryPolicy": { + "bases": [ + "AsyncPolicy" + ], + "methods": { + "no_retries": "no_retries(cls) -> AsyncRetryPolicy", + "send": "async send(self, request: Request, ctx: PipelineContext) -> AsyncResponse" + } + } + }, + "dexpace.sdk.core.pipeline.policies.async_set_date": { + "AsyncSetDatePolicy": { + "bases": [ + "AsyncPolicy" + ], + "methods": { + "send": "async send(self, request: Request, ctx: PipelineContext) -> AsyncResponse" + } + } + }, + "dexpace.sdk.core.pipeline.policies.client_identity": { + "ClientIdentityPolicy": { + "bases": [ + "Policy" + ], + "methods": { + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + }, + "default_user_agent": "default_user_agent() -> str" + }, + "dexpace.sdk.core.pipeline.policies.idempotency": { + "IdempotencyPolicy": { + "bases": [ + "Policy" + ], + "methods": { + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + } + }, + "dexpace.sdk.core.pipeline.policies.logging_policy": { + "LoggingPolicy": { + "bases": [ + "Policy" + ], + "methods": { + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + } + }, + "dexpace.sdk.core.pipeline.policies.redirect": { + "RedirectPolicy": { + 
"bases": [ + "Policy" + ], + "methods": { + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + }, + "resolve_http_tracer": "resolve_http_tracer(ctx: PipelineContext) -> HttpTracer" + }, + "dexpace.sdk.core.pipeline.policies.retry": { + "RetryMode": { + "bases": [ + "StrEnum" + ], + "methods": {} + }, + "RetryPolicy": { + "bases": [ + "Policy" + ], + "methods": { + "no_retries": "no_retries(cls) -> RetryPolicy", + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + } + }, + "dexpace.sdk.core.pipeline.policies.set_date": { + "SetDatePolicy": { + "bases": [ + "Policy" + ], + "methods": { + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + } + }, + "dexpace.sdk.core.pipeline.policies.tracing_policy": { + "TracingPolicy": { + "bases": [ + "Policy" + ], + "methods": { + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + } + }, + "dexpace.sdk.core.pipeline.policy": { + "Policy": { + "bases": [ + "ABC" + ], + "methods": { + "send": "send(self, request: Request, ctx: PipelineContext) -> Response" + } + } + }, + "dexpace.sdk.core.pipeline.stage": { + "Stage": { + "bases": [ + "IntEnum" + ], + "methods": { + "is_pillar": "is_pillar(self) -> bool" + } + } + }, + "dexpace.sdk.core.pipeline.staged_builder": { + "StagedPipelineBuilder": { + "bases": [], + "methods": { + "append": "append(self, policy: Policy, *, force: bool) -> Self", + "build": "build(self) -> Pipeline", + "from_pipeline": "from_pipeline(cls, pipeline: Pipeline) -> Self", + "insert_after": "insert_after(self, target: type[Policy], new: Policy) -> Self", + "insert_before": "insert_before(self, target: type[Policy], new: Policy) -> Self", + "prepend": "prepend(self, policy: Policy, *, force: bool) -> Self", + "remove": "remove(self, target: type[Policy]) -> Self", + "replace": "replace(self, target: type[Policy], new: Policy) -> Self" + } + } + }, + "dexpace.sdk.core.pipeline.step.config": { + 
"RetryConfig": { + "bases": [], + "methods": {} + }, + "StepMetadata": { + "bases": [], + "methods": {} + } + }, + "dexpace.sdk.core.pipeline.step.pipeline_step": { + "PipelineStep": { + "bases": [ + "Protocol" + ], + "methods": {} + } + }, + "dexpace.sdk.core.serde.codec": { + "Codec": { + "bases": [], + "methods": { + "decode": "decode(self, data: object, target: type[T]) -> T", + "encode": "encode(self, value: object) -> object" + } + }, + "CodecError": { + "bases": [ + "DeserializationError" + ], + "methods": {} + }, + "discriminated": "discriminated(tag_field: str, /) -> Callable[[type[T]], type[T]]", + "field_alias": "field_alias(wire_name: str, /, *, default: object, default_factory: Callable[[], object] | None) -> object", + "variant": "variant(tag_value: str, /) -> Callable[[type[T]], type[T]]" + }, + "dexpace.sdk.core.serde.json_serde": { + "JsonDeserializer": { + "bases": [], + "methods": { + "deserialize": "deserialize(self, value: str) -> Any", + "deserialize_bytes": "deserialize_bytes(self, value: bytes) -> Any", + "deserialize_stream": "deserialize_stream(self, stream: BinaryIO) -> Any" + } + }, + "JsonSerde": { + "bases": [], + "methods": { + "deserializer": "deserializer(self) -> JsonDeserializer", + "serializer": "serializer(self) -> JsonSerializer" + } + }, + "JsonSerializer": { + "bases": [], + "methods": { + "serialize": "serialize(self, value: Any) -> str", + "serialize_to_bytes": "serialize_to_bytes(self, value: Any) -> bytes", + "serialize_to_stream": "serialize_to_stream(self, value: Any, stream: BinaryIO) -> None" + } + } + }, + "dexpace.sdk.core.serde.serde": { + "Deserializer": { + "bases": [ + "Protocol" + ], + "methods": { + "deserialize": "deserialize(self, value: str) -> Any", + "deserialize_bytes": "deserialize_bytes(self, value: bytes) -> Any", + "deserialize_stream": "deserialize_stream(self, stream: BinaryIO) -> Any" + } + }, + "Serde": { + "bases": [ + "Protocol" + ], + "methods": { + "deserializer": "deserializer(self) -> 
Deserializer", + "serializer": "serializer(self) -> Serializer" + } + }, + "Serializer": { + "bases": [ + "Protocol" + ], + "methods": { + "serialize": "serialize(self, value: Any) -> str", + "serialize_to_bytes": "serialize_to_bytes(self, value: Any) -> bytes", + "serialize_to_stream": "serialize_to_stream(self, value: Any, stream: BinaryIO) -> None" + } + } + }, + "dexpace.sdk.core.serde.tristate": { + "Present": { + "bases": [], + "methods": {} + }, + "fold": "fold(state: Tristate[T], *, on_absent: Callable[[], R], on_null: Callable[[], R], on_present: Callable[[T], R]) -> R", + "is_absent": "is_absent(state: Tristate[T]) -> TypeGuard[_Absent]", + "is_null": "is_null(state: Tristate[T]) -> TypeGuard[_Null]", + "is_present": "is_present(state: Tristate[T]) -> TypeGuard[Present[T]]", + "of_optional": "of_optional(value: T | None) -> _Null | Present[T]", + "present": "present(value: T) -> Present[T]" + }, + "dexpace.sdk.core.util.clock": { + "AsyncClock": { + "bases": [ + "Protocol" + ], + "methods": { + "monotonic": "monotonic(self) -> float", + "now": "now(self) -> float", + "sleep": "async sleep(self, duration: float) -> None" + } + }, + "Clock": { + "bases": [ + "Protocol" + ], + "methods": { + "monotonic": "monotonic(self) -> float", + "now": "now(self) -> float", + "sleep": "sleep(self, duration: float) -> None" + } + } + }, + "dexpace.sdk.core.util.proxy": { + "ProxyOptions": { + "bases": [], + "methods": { + "bypasses_proxy": "bypasses_proxy(self, host: str) -> bool", + "from_configuration": "from_configuration(cls, config: Configuration) -> Self | None" + } + }, + "ProxyType": { + "bases": [ + "StrEnum" + ], + "methods": {} + } + } + }, + "exports": { + "dexpace.sdk.core.client": [ + "AsyncHttpClient", + "HttpClient", + "asyncio_sleep" + ], + "dexpace.sdk.core.config": [ + "Configuration", + "ConfigurationBuilder" + ], + "dexpace.sdk.core.errors": [ + "ClientAuthenticationError", + "DecodeError", + "DeserializationError", + "HttpResponseError", + 
"PipelineAbortedError", + "ResourceExistsError", + "ResourceModifiedError", + "ResourceNotFoundError", + "ResourceNotModifiedError", + "ResponseNotReadError", + "SdkError", + "SerializationError", + "ServiceRequestError", + "ServiceRequestTimeoutError", + "ServiceResponseError", + "ServiceResponseTimeoutError", + "StreamClosedError", + "StreamConsumedError", + "StreamingError", + "map_error" + ], + "dexpace.sdk.core.http.auth": [ + "AccessTokenInfo", + "AsyncBearerTokenPolicy", + "AsyncTokenCredential", + "AuthenticateChallenge", + "BasicAuthCredential", + "BasicAuthPolicy", + "BasicChallengeHandler", + "BearerTokenPolicy", + "ChallengeHandler", + "CompositeChallengeHandler", + "DigestAlgorithm", + "DigestChallengeHandler", + "InMemoryTokenCache", + "KeyCredential", + "KeyCredentialPolicy", + "NamedKeyCredential", + "TokenCache", + "TokenCredential", + "TokenRequestOptions", + "parse_challenges" + ], + "dexpace.sdk.core.http.common": [ + "AsyncItemPaged", + "AsyncPager", + "ETag", + "Headers", + "HttpHeaderName", + "HttpRange", + "ItemPaged", + "MediaType", + "Pager", + "Protocol", + "QueryParams", + "RequestConditions", + "Url", + "aiter_chunked_frame", + "aiter_jsonl", + "chunked_frame", + "common_media_types", + "http_header_name", + "iter_jsonl" + ], + "dexpace.sdk.core.http.context": [ + "CallContext", + "ContextStore", + "DispatchContext", + "ExchangeContext", + "RequestContext" + ], + "dexpace.sdk.core.http.request": [ + "AsyncRequestBody", + "FileRequestBody", + "LoggableRequestBody", + "Method", + "MultipartField", + "MultipartRequestBody", + "Request", + "RequestBody" + ], + "dexpace.sdk.core.http.response": [ + "AsyncResponse", + "AsyncResponseBody", + "LoggableResponseBody", + "Response", + "ResponseBody", + "Status" + ], + "dexpace.sdk.core.http.sse": [ + "AsyncSseStream", + "SseEvent", + "SseParser", + "parse_async_events", + "parse_events" + ], + "dexpace.sdk.core.http.webhooks": [ + "DEFAULT_TOLERANCE_SECONDS", + "InvalidWebhookSignatureError", + 
"WebhookVerifier" + ], + "dexpace.sdk.core.instrumentation": [ + "ClientLogger", + "CorrelationFilter", + "Counter", + "DEFAULT_QUERY_ALLOWLIST", + "Histogram", + "HttpTracer", + "HttpTracerFactory", + "InstrumentationContext", + "LogLevel", + "MetricsContext", + "NOOP_COUNTER", + "NOOP_HISTOGRAM", + "NOOP_HTTP_TRACER", + "NOOP_HTTP_TRACER_FACTORY", + "NOOP_INSTRUMENTATION_CONTEXT", + "NOOP_METRICS_CONTEXT", + "NOOP_SPAN", + "NOOP_TRACER", + "NOOP_UPDOWN_COUNTER", + "Span", + "SpanId", + "TraceFlags", + "TraceId", + "TraceIdType", + "TraceState", + "Tracer", + "TracingScope", + "UpDownCounter", + "UrlRedactor", + "bind_correlation", + "get_span_id", + "get_trace_id", + "set_span_id", + "set_trace_id" + ], + "dexpace.sdk.core.pagination": [ + "AsyncPaginator", + "AsyncPipelineLike", + "CursorStrategy", + "HasHeaders", + "LinkHeaderStrategy", + "Page", + "PageNumberStrategy", + "PaginationStrategy", + "Paginator", + "ParsedLink", + "SendAsync", + "SendSync", + "SyncPipelineLike", + "find_rel", + "parse_link_header" + ], + "dexpace.sdk.core.pipeline": [ + "AsyncPipeline", + "AsyncPolicy", + "AsyncStagedPipelineBuilder", + "Pipeline", + "PipelineContext", + "PipelineStep", + "Policy", + "RequestPipelineStep", + "ResponsePipelineStep", + "RetryConfig", + "Stage", + "StagedPipelineBuilder", + "StepMetadata", + "default_async_pipeline", + "default_pipeline" + ], + "dexpace.sdk.core.pipeline.policies": [ + "AsyncClientIdentityPolicy", + "AsyncIdempotencyPolicy", + "AsyncRedirectPolicy", + "AsyncRetryPolicy", + "AsyncSetDatePolicy", + "ClientIdentityPolicy", + "IdempotencyPolicy", + "LoggingPolicy", + "RedirectPolicy", + "RequestHistory", + "RetryMode", + "RetryPolicy", + "SetDatePolicy", + "TracingPolicy", + "default_user_agent" + ], + "dexpace.sdk.core.pipeline.step": [ + "PipelineStep", + "RequestPipelineStep", + "ResponsePipelineStep" + ], + "dexpace.sdk.core.serde": [ + "ABSENT", + "ALIAS_KEY", + "Codec", + "CodecError", + "DISCRIMINATOR_KEY", + "Deserializer", + 
"JSON_SERDE", + "JsonDeserializer", + "JsonSerde", + "JsonSerializer", + "NULL", + "Present", + "REGISTRY_KEY", + "Serde", + "Serializer", + "Tristate", + "discriminated", + "field_alias", + "fold", + "is_absent", + "is_null", + "is_present", + "of_optional", + "present", + "variant" + ], + "dexpace.sdk.core.util": [ + "ASYNC_SYSTEM_CLOCK", + "AsyncClock", + "Clock", + "ProxyOptions", + "ProxyType", + "SYSTEM_CLOCK" + ] + } + }, + "dexpace-sdk-http-aiohttp": { + "definitions": { + "dexpace.sdk.http.aiohttp.client": { + "AiohttpHttpClient": { + "bases": [], + "methods": { + "aclose": "async aclose(self) -> None", + "execute": "async execute(self, request: Request) -> AsyncResponse" + } + } + } + }, + "exports": { + "dexpace.sdk.http.aiohttp": [ + "AiohttpHttpClient" + ] + } + }, + "dexpace-sdk-http-httpx": { + "definitions": { + "dexpace.sdk.http.httpx.async_": { + "AsyncHttpxHttpClient": { + "bases": [], + "methods": { + "aclose": "async aclose(self) -> None", + "execute": "async execute(self, request: Request) -> AsyncResponse" + } + } + }, + "dexpace.sdk.http.httpx.sync": { + "HttpxHttpClient": { + "bases": [], + "methods": { + "close": "close(self) -> None", + "execute": "execute(self, request: Request) -> Response" + } + } + } + }, + "exports": { + "dexpace.sdk.http.httpx": [ + "AsyncHttpxHttpClient", + "HttpxHttpClient" + ] + } + }, + "dexpace-sdk-http-requests": { + "definitions": { + "dexpace.sdk.http.requests.client": { + "RequestsHttpClient": { + "bases": [], + "methods": { + "close": "close(self) -> None", + "execute": "execute(self, request: Request) -> Response" + } + } + } + }, + "exports": { + "dexpace.sdk.http.requests": [ + "RequestsHttpClient" + ] + } + }, + "dexpace-sdk-http-stdlib": { + "definitions": { + "dexpace.sdk.http.stdlib.asyncio_http_client": { + "AsyncioHttpClient": { + "bases": [], + "methods": { + "aclose": "async aclose(self) -> None", + "execute": "async execute(self, request: Request) -> AsyncResponse" + } + } + }, + 
"dexpace.sdk.http.stdlib.urllib_http_client": { + "UrllibHttpClient": { + "bases": [], + "methods": { + "close": "close(self) -> None", + "execute": "execute(self, request: Request) -> Response" + } + } + } + }, + "exports": { + "dexpace.sdk.http.stdlib": [ + "AsyncioHttpClient", + "UrllibHttpClient" + ] + } + } +} diff --git a/tools/surface_snapshot.py b/tools/surface_snapshot.py new file mode 100644 index 0000000..6593782 --- /dev/null +++ b/tools/surface_snapshot.py @@ -0,0 +1,307 @@ +# Copyright (c) 2026 dexpace and Omar Aljarrah. +# Licensed under the MIT License. See LICENSE.md in the repository root for details. + +"""Static extraction of the public API surface of every SDK distribution. + +This module walks each distribution's ``src/`` tree with the standard-library +``ast`` parser — it never imports or executes project code, so it works without +the workspace being installed and cannot be tricked by import-time side effects. + +The extracted surface has two halves per distribution: + +- ``exports``: every ``__init__.py``'s ``__all__`` list, keyed by the dotted + package path. This captures the deliberately narrow re-export surface that + the project conventions require to stay accurate. +- ``definitions``: every public top-level class / function / Protocol defined + anywhere in the tree, recorded with its signature (and, for classes, base + classes and public method signatures). This catches a signature drifting or a + ``with_*`` helper disappearing even when ``__all__`` is unchanged. + +The resulting nested ``dict`` is plain JSON-serialisable data, sorted for a +stable diff. ``build_surface`` produces it; ``main`` writes it to the committed +baseline path. The companion pytest (``test_public_surface.py``) compares the +live surface against that baseline and fails on any unexpected change. 
+ +Run as a script to regenerate the baseline: + + python tools/surface_snapshot.py --write +""" + +from __future__ import annotations + +import argparse +import ast +import json +import sys +from pathlib import Path + +# Distribution directory name -> the namespace sub-path under ``src`` that +# holds its modules. Every distribution shares the ``dexpace/sdk`` PEP-420 +# namespace prefix; only the leaf differs. +_DISTRIBUTIONS: dict[str, str] = { + "dexpace-sdk-core": "dexpace/sdk/core", + "dexpace-sdk-http-stdlib": "dexpace/sdk/http/stdlib", + "dexpace-sdk-http-httpx": "dexpace/sdk/http/httpx", + "dexpace-sdk-http-aiohttp": "dexpace/sdk/http/aiohttp", + "dexpace-sdk-http-requests": "dexpace/sdk/http/requests", +} + +type Surface = dict[str, object] + + +def repo_root() -> Path: + """Return the workspace root (the directory that holds ``packages/``). + + Returns: + The absolute path to the repository root, resolved from this file's + location (``tools/surface_snapshot.py`` sits directly under the root). + """ + return Path(__file__).resolve().parent.parent + + +def baseline_path() -> Path: + """Return the absolute path to the committed surface baseline JSON.""" + return repo_root() / "tools" / "surface_baseline.json" + + +def _is_public(name: str) -> bool: + """Return whether ``name`` is part of the public surface (no leading ``_``). + + Dunder names (``__enter__`` and friends) are protocol hooks, not public + API, so they are excluded alongside single-underscore privates. + """ + return not name.startswith("_") + + +def _format_arg(arg: ast.arg) -> str: + """Render one parameter as ``name`` or ``name: annotation``.""" + if arg.annotation is not None: + return f"{arg.arg}: {ast.unparse(arg.annotation)}" + return arg.arg + + +def _format_signature(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str: + """Render a function/method signature as a stable, source-independent string. 
+ + Defaults are deliberately omitted: neither their values nor their presence + is recorded, because inlining a literal default makes the baseline churn on + cosmetic edits. Parameter names, annotations, the positional/keyword-only + split, and the return annotation are all preserved. + + Args: + node: The (async) function definition node to render. + + Returns: + A canonical one-line signature string, e.g. + ``execute(self, request: Request) -> Response``. + """ + args = node.args + parts: list[str] = [] + parts.extend(_format_arg(a) for a in args.posonlyargs) + if args.posonlyargs: + parts.append("/") + parts.extend(_format_arg(a) for a in args.args) + if args.vararg is not None: + parts.append(f"*{_format_arg(args.vararg)}") + elif args.kwonlyargs: + parts.append("*") + parts.extend(_format_arg(a) for a in args.kwonlyargs) + if args.kwarg is not None: + parts.append(f"**{_format_arg(args.kwarg)}") + returns = f" -> {ast.unparse(node.returns)}" if node.returns is not None else "" + prefix = "async " if isinstance(node, ast.AsyncFunctionDef) else "" + return f"{prefix}{node.name}({', '.join(parts)}){returns}" + + +def _class_surface(node: ast.ClassDef) -> dict[str, object]: + """Extract the public surface of a class definition. + + Records its base classes (so a Protocol/ABC becoming a plain class, or a + base disappearing, is caught) and the signature of each public method. + + Args: + node: The class definition node. + + Returns: + A mapping with ``bases`` (sorted base expressions) and ``methods`` + (method name -> signature string, public methods only). 
+ """ + bases = sorted(ast.unparse(b) for b in node.bases) + methods: dict[str, str] = {} + for item in node.body: + if isinstance(item, ast.FunctionDef | ast.AsyncFunctionDef) and _is_public(item.name): + methods[item.name] = _format_signature(item) + return {"bases": bases, "methods": dict(sorted(methods.items()))} + + +def _extract_all(module: ast.Module) -> list[str] | None: + """Return the value of a module-level ``__all__`` list, or ``None``. + + Only a literal list/tuple of string constants is recognised — the project + convention is to declare ``__all__`` as a plain literal, never built + dynamically, so anything else is treated as "no static ``__all__``". + + Args: + module: The parsed module node. + + Returns: + The sorted list of exported names, or ``None`` if no static ``__all__`` + literal is present. + """ + for stmt in module.body: + if not isinstance(stmt, ast.Assign): + continue + targets = [t for t in stmt.targets if isinstance(t, ast.Name)] + if not any(t.id == "__all__" for t in targets): + continue + value = stmt.value + if isinstance(value, ast.List | ast.Tuple): + names = [ + el.value + for el in value.elts + if isinstance(el, ast.Constant) and isinstance(el.value, str) + ] + return sorted(names) + return None + + +def _module_definitions(module: ast.Module) -> dict[str, object]: + """Extract public top-level class and function definitions from a module. + + Nested definitions are intentionally ignored — only the top-level surface + of a module is API. + + Args: + module: The parsed module node. + + Returns: + A mapping of public symbol name -> its surface descriptor (a class + descriptor mapping, or a function signature string). 
+ """ + defs: dict[str, object] = {} + for stmt in module.body: + if isinstance(stmt, ast.ClassDef) and _is_public(stmt.name): + defs[stmt.name] = _class_surface(stmt) + elif isinstance(stmt, ast.FunctionDef | ast.AsyncFunctionDef) and _is_public(stmt.name): + defs[stmt.name] = _format_signature(stmt) + return dict(sorted(defs.items())) + + +def _dotted_package(src_root: Path, init_file: Path) -> str: + """Return the dotted package path for an ``__init__.py`` under ``src_root``. + + Args: + src_root: The distribution's ``src`` directory. + init_file: The ``__init__.py`` whose package path is wanted. + + Returns: + The dotted package name, e.g. ``dexpace.sdk.core.http.request``. + """ + rel = init_file.parent.relative_to(src_root) + return ".".join(rel.parts) + + +def _dotted_module(src_root: Path, module_file: Path) -> str: + """Return the dotted module path for a ``.py`` file under ``src_root``.""" + rel = module_file.relative_to(src_root).with_suffix("") + return ".".join(rel.parts) + + +def _parse(path: Path) -> ast.Module: + """Parse a Python source file into an AST module node.""" + return ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + + +def _distribution_surface(src_root: Path) -> dict[str, object]: + """Build the surface for a single distribution's ``src`` tree. + + Args: + src_root: The distribution's ``src`` directory (the parent of the + ``dexpace`` namespace directory). + + Returns: + A mapping with ``exports`` (package -> ``__all__``) and ``definitions`` + (module -> public definitions), both sorted for a stable diff. 
+ """ + exports: dict[str, list[str]] = {} + definitions: dict[str, object] = {} + for module_file in sorted(src_root.rglob("*.py")): + module = _parse(module_file) + if module_file.name == "__init__.py": + names = _extract_all(module) + if names is not None: + exports[_dotted_package(src_root, module_file)] = names + module_defs = _module_definitions(module) + if module_defs: + definitions[_dotted_module(src_root, module_file)] = module_defs + return { + "exports": dict(sorted(exports.items())), + "definitions": dict(sorted(definitions.items())), + } + + +def build_surface(root: Path | None = None) -> Surface: + """Build the public-API surface of every distribution by static analysis. + + Args: + root: The workspace root to scan. Defaults to the repository root + inferred from this file's location. + + Returns: + A JSON-serialisable mapping of distribution name -> its surface. Only + distributions whose ``src`` tree exists are included; the result is + sorted by distribution name for a stable diff. + + Raises: + FileNotFoundError: If the ``packages`` directory does not exist under + ``root``. + """ + base = (root or repo_root()).resolve() + packages = base / "packages" + if not packages.is_dir(): + raise FileNotFoundError(f"no packages directory under {base}") + surface: dict[str, object] = {} + for dist, namespace in sorted(_DISTRIBUTIONS.items()): + src_root = packages / dist / "src" + if not (src_root / namespace).is_dir(): + continue + surface[dist] = _distribution_surface(src_root) + return surface + + +def render(surface: Surface) -> str: + """Render a surface mapping as canonical, newline-terminated JSON text.""" + return json.dumps(surface, indent=2, sort_keys=True, ensure_ascii=False) + "\n" + + +def main(argv: list[str] | None = None) -> int: + """CLI entry point: print the live surface, or write it to the baseline. + + Args: + argv: Argument vector (defaults to ``sys.argv[1:]``). + + Returns: + A process exit code (always ``0`` on success). 
+ """ + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--write", + action="store_true", + help="write the live surface to the committed baseline instead of printing it", + ) + args = parser.parse_args(argv) + surface = build_surface() + text = render(surface) + if args.write: + baseline_path().write_text(text, encoding="utf-8") + print(f"wrote baseline to {baseline_path()}") + else: + sys.stdout.write(text) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/uv.lock b/uv.lock index 8b796c0..55cb1c4 100644 --- a/uv.lock +++ b/uv.lock @@ -426,7 +426,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "dexpace-sdk-core", editable = "packages/dexpace-sdk-core" }, - { name = "requests", specifier = ">=2.34.2,<3.0" }, + { name = "requests", specifier = ">=2.32,<3.0" }, ] [[package]]