diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/digest.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/digest.py index 7ec85ce..a9a3e3c 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/digest.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/digest.py @@ -8,9 +8,12 @@ verification are out of scope — matching the Java v1 cut. A single handler instance is intended to be reused across requests so the -per-client nonce counter (``nc``) advances monotonically. The counter is -guarded by a ``threading.Lock`` since CPython integer increment is not -atomic with respect to other threads' reads of the same variable. +per-nonce request counter (``nc``) advances correctly. RFC 7616 §3.4 defines +``nc`` as the count of requests the client has sent with the *current* server +nonce, so the counter restarts at ``00000001`` for each fresh nonce and +increments only while a nonce is reused. The per-nonce counts are tracked in a +bounded mapping guarded by a ``threading.Lock`` since CPython dict mutation and +integer increment are not atomic with respect to other threads. """ from __future__ import annotations @@ -18,6 +21,7 @@ import hashlib import secrets import threading +from collections import OrderedDict from collections.abc import Callable from dataclasses import dataclass from typing import Final @@ -58,6 +62,13 @@ class _ResolvedChallenge: "SHA-256-SESS": hashlib.sha256, } +# Cap on how many distinct server nonces retain a request count. A long-lived +# handler hitting many rotating nonces would otherwise grow the map without +# bound; the oldest entry is evicted past this size. The limit is generous — +# servers rotate nonces rarely — so eviction effectively never affects an +# active nonce. +_MAX_TRACKED_NONCES: Final[int] = 1024 + class DigestChallengeHandler: """Satisfy a ``Digest`` challenge per RFC 7616. @@ -79,8 +90,8 @@ class DigestChallengeHandler: __slots__ = ( "_cnonce_factory", - "_counter", "_lock", + "_nonce_counts", "_password", "_preferred", "_username", @@ -102,7 +113,10 @@ def __init__( self._password = password self._preferred = preferred_algorithms self._cnonce_factory = cnonce_factory or (lambda: secrets.token_hex(16)) - self._counter = 0 + # Maps a server nonce to the count of requests sent with it. Insertion + # order is preserved so the oldest nonce can be evicted once the map + # exceeds ``_MAX_TRACKED_NONCES``. + self._nonce_counts: OrderedDict[str, int] = OrderedDict() self._lock = threading.Lock() def can_handle(self, challenges: list[AuthenticateChallenge]) -> bool: @@ -122,7 +136,7 @@ def handle( resolved = self._resolve(selected) if resolved is None: return None - nc = self._next_nc() + nc = self._next_nc(resolved.nonce) cnonce = self._cnonce_factory() uri = _request_uri(url) response = self._compute_response( @@ -223,13 +237,33 @@ def _select(self, challenges: list[AuthenticateChallenge]) -> AuthenticateChalle return challenge return None - def _next_nc(self) -> str: + def _next_nc(self, nonce: str) -> str: + """Return the next ``nc`` value for ``nonce`` as 8 lowercase hex digits. + + RFC 7616 §3.4 defines ``nc`` as the count of requests sent with the + current server nonce, so the count starts at ``00000001`` for each + fresh nonce and increments on every reuse. A bounded, insertion-ordered + map tracks the per-nonce counts; the oldest nonce is evicted once the + map exceeds ``_MAX_TRACKED_NONCES`` to cap memory on a long-lived + handler that sees many rotating nonces. + + Args: + nonce: The server nonce the request is being signed against. + + Returns: + The request count for ``nonce`` formatted as 8 lowercase hex + digits, e.g. ``"00000001"`` for the first request. + """ with self._lock: # Clamp to 32 bits — ``nc`` is rendered as 8 hex digits per # RFC 7616, and wrapping after 2**32-1 is acceptable since the # server hashes the value (no monotonic check). - self._counter = (self._counter + 1) & 0xFFFFFFFF - value = self._counter + value = (self._nonce_counts.get(nonce, 0) + 1) & 0xFFFFFFFF + self._nonce_counts[nonce] = value + # Refresh recency so an actively reused nonce is never evicted. + self._nonce_counts.move_to_end(nonce) + if len(self._nonce_counts) > _MAX_TRACKED_NONCES: + self._nonce_counts.popitem(last=False) return f"{value:08x}" def _credentials_encodable(self, charset: str) -> bool: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/policies.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/policies.py index f36edac..e2543e9 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/policies.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/policies.py @@ -39,6 +39,12 @@ class KeyCredentialPolicy(Policy): SansIO-shaped (no chain wrapping needed) but implemented as a ``Policy`` so it integrates uniformly with the rest of the pipeline. + The credential header is stamped only while the request stays on the + origin recorded on the first pass through the policy. When a downstream + redirect reissues the request against a different origin (scheme, host, + or effective port), the credential is withheld so it never reaches a + foreign host. + Attributes: header_name: Header to write. prefix: Optional prefix (with trailing space) for the header value. @@ -63,12 +69,21 @@ def __init__( self.prefix = f"{prefix} " if prefix else "" def send(self, request: Request, ctx: PipelineContext) -> Response: + if _crosses_recorded_origin(request, ctx): + return self.next.send(request, ctx) value = f"{self.prefix}{self._credential.key}" return self.next.send(request.with_header(self.header_name, value), ctx) class BasicAuthPolicy(Policy): - """Stamp ``Authorization: Basic `` from a ``BasicAuthCredential``.""" + """Stamp ``Authorization: Basic `` from a ``BasicAuthCredential``. + + The credential is stamped only while the request stays on the origin + recorded on the first pass through the policy. When a downstream redirect + reissues the request against a different origin (scheme, host, or + effective port), the credential is withheld so it never reaches a foreign + host. + """ STAGE = Stage.AUTH __slots__ = ("_credential",) @@ -79,6 +94,8 @@ def __init__(self, credential: BasicAuthCredential) -> None: self._credential = credential def send(self, request: Request, ctx: PipelineContext) -> Response: + if _crosses_recorded_origin(request, ctx): + return self.next.send(request, ctx) value = f"Basic {self._credential.encoded}" return self.next.send(request.with_header("Authorization", value), ctx) @@ -91,6 +108,13 @@ class BearerTokenPolicy(Policy): returns True, or after a 401 response with ``WWW-Authenticate``. Enforces HTTPS unless ``enforce_https=False`` is passed in ``ctx.options``. + The token is acquired and stamped only while the request stays on the + origin recorded on the first pass. When a downstream redirect reissues the + request against a different origin (scheme, host, or effective port), the + policy forwards the request unchanged — it does not acquire, refresh, or + stamp a token — so the bearer token never reaches a foreign host. The + HTTPS-enforcement check applies only on that same-origin stamping path. + Concurrent refreshes are serialized via a ``threading.Lock`` using a double-checked pattern so the credential's ``get_token_info`` is invoked at most once per refresh window even under heavy concurrent send pressure. @@ -226,6 +250,11 @@ def _authorize( *, force_refresh: bool = False, ) -> Request: + # A redirect that crossed origin must not receive the bearer token: + # forward the request unchanged without acquiring or refreshing one, + # and skip the HTTPS enforcement that only governs the stamping path. + if _crosses_recorded_origin(request, ctx): + return request if ctx.options.get("enforce_https", True) and not _is_https(request.url): raise ServiceRequestError( "Bearer token authentication is not permitted for non-HTTPS URLs." @@ -257,6 +286,13 @@ class AsyncBearerTokenPolicy(AsyncPolicy): the returned ``(name, value)`` pair is stamped on the retried request. A 401 invalidates the cached origin token; a 407 leaves it alone because the proxy, not the origin, rejected the request. + + The token is acquired and stamped only while the request stays on the + origin recorded on the first pass. When a downstream redirect reissues the + request against a different origin (scheme, host, or effective port), the + policy forwards the request unchanged — it does not acquire, refresh, or + stamp a token — so the bearer token never reaches a foreign host. The + HTTPS-enforcement check applies only on that same-origin stamping path. """ STAGE = Stage.AUTH @@ -381,6 +417,11 @@ async def _authorize( *, force_refresh: bool = False, ) -> Request: + # A redirect that crossed origin must not receive the bearer token: + # forward the request unchanged without acquiring or refreshing one, + # and skip the HTTPS enforcement that only governs the stamping path. + if _crosses_recorded_origin(request, ctx): + return request if ctx.options.get("enforce_https", True) and not _is_https(request.url): raise ServiceRequestError( "Bearer token authentication is not permitted for non-HTTPS URLs." @@ -398,6 +439,51 @@ async def _authorize( return request.with_header("Authorization", f"{token.token_type} {token.token}") +_DEFAULT_PORTS: dict[str, int] = {"https": 443, "http": 80} +_AUTH_ORIGIN_KEY: str = "_auth_origin" + + +def _origin(url: Url) -> tuple[str, str, int | None]: + """Return the ``(scheme, host, port)`` origin tuple for ``url``. + + The scheme and host are lower-cased and the port is resolved to its + scheme default (443 for https, 80 for http) when not explicit, so two + URLs that differ only in an implied/explicit default port compare equal. + + Args: + url: The URL to derive an origin from. + + Returns: + A ``(scheme, host, effective_port)`` tuple suitable for equality + comparison. + """ + scheme = url.scheme.lower() + port = url.port if url.port is not None else _DEFAULT_PORTS.get(scheme) + return scheme, url.host.lower(), port + + +def _crosses_recorded_origin(request: Request, ctx: PipelineContext) -> bool: + """Report whether ``request`` left the origin recorded for this operation. + + On the first pass through an auth policy the request's origin is stored in + ``ctx.data`` (which is per-operation), so a later redirect reissue can be + compared against it. When the current origin differs, the credential must + not be stamped — this is what stops a redirect to a foreign host from + receiving the caller's credentials. + + Args: + request: The request the auth policy is about to forward. + ctx: The per-operation pipeline context. + + Returns: + ``True`` when the request's origin differs from the one recorded on + the first pass; ``False`` on the first pass or a same-origin reissue. + """ + current = _origin(request.url) + recorded: tuple[str, str, int | None] = ctx.data.setdefault(_AUTH_ORIGIN_KEY, current) + return recorded != current + + def _is_https(url: Url) -> bool: """Return True if ``url``'s scheme is ``https`` (case-insensitive).""" return url.scheme.lower() == "https" diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/media_type.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/media_type.py index 7f21c82..192fb20 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/media_type.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/media_type.py @@ -17,6 +17,47 @@ _TOKEN_SEPARATORS = frozenset('()<>@,;:\\"/[]?={} \t') +def _split_params(value: str) -> list[str]: + """Split a media-type wire string on ``;`` outside quoted-strings. + + A naive ``value.split(";")`` mis-parses a quoted parameter value that + legitimately contains a semicolon (legal inside an RFC 7230 §3.2.6 + quoted-string, e.g. an exotic multipart boundary). This walks the string + with a small state machine that treats ``;`` as a separator only when it + is not enclosed in double quotes, honouring the ``\\X`` quoted-pair escape + so an escaped quote does not toggle the quoted state. + + Args: + value: The full wire-form media-type string. + + Returns: + The segments split on top-level ``;`` separators, in order. Surrounding + whitespace is not stripped here — callers strip per segment. + """ + segments: list[str] = [] + current: list[str] = [] + in_quotes = False + i = 0 + while i < len(value): + ch = value[i] + if in_quotes and ch == "\\" and i + 1 < len(value): + current.append(ch) + current.append(value[i + 1]) + i += 2 + continue + if ch == '"': + in_quotes = not in_quotes + elif ch == ";" and not in_quotes: + segments.append("".join(current)) + current = [] + i += 1 + continue + current.append(ch) + i += 1 + segments.append("".join(current)) + return segments + + def _unquote(s: str) -> str: """Decode a quoted-string parameter value per RFC 7230 §3.2.6. @@ -146,7 +187,7 @@ def parse(cls, value: str) -> Self: """ if not value or not value.strip(): raise ValueError("media type must not be blank") - segments = [segment.strip() for segment in value.split(";")] + segments = [segment.strip() for segment in _split_params(value)] mime = segments[0] slash = mime.find("/") if slash <= 0 or slash == len(mime) - 1: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/pagination.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/pagination.py index bce316d..0c0ec3f 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/pagination.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/pagination.py @@ -11,6 +11,7 @@ from __future__ import annotations +import itertools from collections.abc import ( AsyncIterator, Awaitable, @@ -28,24 +29,40 @@ class Pager[T, R](Iterator[Iterator[T]]): Each ``__next__`` call returns an ``Iterator[T]`` over the items in the page that ``get_next`` produced. Iteration terminates when ``extract_data`` returns a ``None`` continuation token after producing - at least one page. + at least one page, or when the optional ``max_pages`` bound is reached. - On any ``SdkError`` raised by ``get_next``, the current - ``continuation_token`` is stamped onto the error so callers can resume. + On any ``SdkError`` raised by ``get_next`` or ``extract_data``, the + current ``continuation_token`` is stamped onto the error so callers can + resume from the page that failed. + + The optional ``max_pages`` guard caps how many pages are produced; it is + the safety valve against a buggy server that returns the same + continuation token forever, which would otherwise loop indefinitely. """ - __slots__ = ("_did_first_call", "_extract_data", "_get_next", "continuation_token") + __slots__ = ( + "_did_first_call", + "_extract_data", + "_get_next", + "_max_pages", + "_pages_yielded", + "continuation_token", + ) def __init__( self, get_next: Callable[[str | None], R], extract_data: Callable[[R], tuple[str | None, Iterable[T]]], continuation_token: str | None = None, + *, + max_pages: int | None = None, ) -> None: self._get_next = get_next self._extract_data = extract_data self.continuation_token = continuation_token self._did_first_call = False + self._max_pages = max_pages + self._pages_yielded = 0 def __iter__(self) -> Iterator[Iterator[T]]: return self @@ -53,14 +70,17 @@ def __iter__(self) -> Iterator[Iterator[T]]: def __next__(self) -> Iterator[T]: if self.continuation_token is None and self._did_first_call: raise StopIteration + if self._max_pages is not None and self._pages_yielded >= self._max_pages: + raise StopIteration try: response = self._get_next(self.continuation_token) + self._did_first_call = True + self.continuation_token, items = self._extract_data(response) except SdkError as err: if err.continuation_token is None: err.continuation_token = self.continuation_token raise - self._did_first_call = True - self.continuation_token, items = self._extract_data(response) + self._pages_yielded += 1 return iter(items) @@ -74,16 +94,19 @@ class ItemPaged[T, R](Iterator[T]): and consumed by ``extract_data``. """ - __slots__ = ("_extract_data", "_flat", "_get_next") + __slots__ = ("_extract_data", "_flat", "_get_next", "_max_pages") def __init__( self, get_next: Callable[[str | None], R], extract_data: Callable[[R], tuple[str | None, Iterable[T]]], + *, + max_pages: int | None = None, ) -> None: self._get_next = get_next self._extract_data = extract_data self._flat: Iterator[T] | None = None + self._max_pages = max_pages def by_page(self, continuation_token: str | None = None) -> Iterator[Iterator[T]]: """Return a page-level iterator, optionally resuming from a token. @@ -93,12 +116,14 @@ def by_page(self, continuation_token: str | None = None) -> Iterator[Iterator[T] than the first. Returns: - An iterator yielding one ``Iterator[T]`` per page. + An iterator yielding one ``Iterator[T]`` per page. The + ``max_pages`` bound supplied at construction is applied. """ return Pager( self._get_next, self._extract_data, continuation_token=continuation_token, + max_pages=self._max_pages, ) def __iter__(self) -> Iterator[T]: @@ -106,27 +131,43 @@ def __iter__(self) -> Iterator[T]: def __next__(self) -> T: if self._flat is None: - import itertools as _it - - self._flat = _it.chain.from_iterable(self.by_page()) + self._flat = itertools.chain.from_iterable(self.by_page()) return next(self._flat) class AsyncPager[T, R](AsyncIterator[AsyncIterator[T]]): - """Async iterator of pages.""" + """Async iterator of pages. - __slots__ = ("_did_first_call", "_extract_data", "_get_next", "continuation_token") + On any ``SdkError`` raised by ``get_next`` or ``extract_data``, the + current ``continuation_token`` is stamped onto the error so callers can + resume from the page that failed. The optional ``max_pages`` guard caps + how many pages are produced, guarding against a server that returns the + same continuation token forever. + """ + + __slots__ = ( + "_did_first_call", + "_extract_data", + "_get_next", + "_max_pages", + "_pages_yielded", + "continuation_token", + ) def __init__( self, get_next: Callable[[str | None], Awaitable[R]], extract_data: Callable[[R], Awaitable[tuple[str | None, Iterable[T]]]], continuation_token: str | None = None, + *, + max_pages: int | None = None, ) -> None: self._get_next = get_next self._extract_data = extract_data self.continuation_token = continuation_token self._did_first_call = False + self._max_pages = max_pages + self._pages_yielded = 0 def __aiter__(self) -> AsyncIterator[AsyncIterator[T]]: return self @@ -134,15 +175,18 @@ def __aiter__(self) -> AsyncIterator[AsyncIterator[T]]: async def __anext__(self) -> AsyncIterator[T]: if self.continuation_token is None and self._did_first_call: raise StopAsyncIteration + if self._max_pages is not None and self._pages_yielded >= self._max_pages: + raise StopAsyncIteration try: response = await self._get_next(self.continuation_token) + self._did_first_call = True + token, items = await self._extract_data(response) except SdkError as err: if err.continuation_token is None: err.continuation_token = self.continuation_token raise - self._did_first_call = True - token, items = await self._extract_data(response) self.continuation_token = token + self._pages_yielded += 1 return _SyncToAsync(iter(items)) @@ -153,17 +197,20 @@ class AsyncItemPaged[T, R](AsyncIterator[T]): and consumed by ``extract_data``. """ - __slots__ = ("_current", "_extract_data", "_get_next", "_pages") + __slots__ = ("_current", "_extract_data", "_get_next", "_max_pages", "_pages") def __init__( self, get_next: Callable[[str | None], Awaitable[R]], extract_data: Callable[[R], Awaitable[tuple[str | None, Iterable[T]]]], + *, + max_pages: int | None = None, ) -> None: self._get_next = get_next self._extract_data = extract_data self._pages: AsyncIterator[AsyncIterator[T]] | None = None self._current: AsyncIterator[T] | None = None + self._max_pages = max_pages def by_page( self, @@ -176,12 +223,14 @@ def by_page( than the first. Returns: - An async iterator yielding one ``AsyncIterator[T]`` per page. + An async iterator yielding one ``AsyncIterator[T]`` per page. The + ``max_pages`` bound supplied at construction is applied. """ return AsyncPager( self._get_next, self._extract_data, continuation_token=continuation_token, + max_pages=self._max_pages, ) def __aiter__(self) -> AsyncIterator[T]: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/streaming.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/streaming.py index 7800a09..afa7777 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/streaming.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/streaming.py @@ -14,6 +14,7 @@ from typing import Any from ...errors.serialization import DeserializationError +from ...errors.streaming import StreamingError def iter_jsonl(chunks: Iterable[bytes]) -> Iterator[Any]: @@ -30,6 +31,8 @@ def iter_jsonl(chunks: Iterable[bytes]) -> Iterator[Any]: Raises: DeserializationError: If a line is not valid JSON. + StreamingError: If a line is not valid UTF-8 (e.g. a non-UTF-8 byte + sequence or a codepoint truncated by a short final line). """ buffer = bytearray() for chunk in chunks: @@ -66,7 +69,10 @@ def _drain_lines(buffer: bytearray, *, final: bool = False) -> Iterator[Any]: def _parse_line(line: bytes) -> Iterator[Any]: - text = line.rstrip(b"\r").decode("utf-8", errors="strict").strip() + try: + text = line.rstrip(b"\r").decode("utf-8", errors="strict").strip() + except UnicodeDecodeError as err: + raise StreamingError("JSONL line is not valid UTF-8") from err if not text: return try: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/url.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/url.py index b1c794b..522a1ea 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/url.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/common/url.py @@ -108,6 +108,13 @@ def with_added(self, name: str, value: str) -> Self: return _construct_query(type(self), tuple(entries)) def with_set(self, name: str, *values: str) -> Self: + """Return a new ``QueryParams`` with ``name`` set to exactly ``values``. + + If no values are provided, the parameter is removed (mirroring + ``Headers.with_set``) rather than left behind as an empty entry. + """ + if not values: + return self.without(name) entries: list[tuple[str, tuple[str, ...]]] = [] replaced = False for key, existing in self._data: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/context_store.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/context_store.py index 80a9e69..2128a2a 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/context_store.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/context_store.py @@ -24,6 +24,16 @@ class _ContextStore: Thread-safe — every operation acquires the lock so the guarantee survives free-threaded CPython (PEP 703) and non-CPython runtimes that do not guarantee atomic dict ops. + + Two writers coexist deliberately. ``put`` is a *guarded install* that + raises on a duplicate trace id; it is part of the public surface for + callers that own a trace id exclusively and want a duplicate to surface + as a programming error. ``set`` is the *unconditional overwrite* the + promotion chain (``DispatchContext.to_request_context`` → + ``RequestContext.to_exchange_context``) relies on, where the first + promotion installs the entry and later promotions replace it in place. + ``put`` therefore has no internal callers, but removing it would narrow + the public, test-covered surface — so it stays. """ __slots__ = ("_contexts", "_lock") @@ -40,6 +50,14 @@ def get(self, trace_id: str) -> CallContext | None: def put(self, trace_id: str, context: CallContext) -> None: """Register ``context`` under ``trace_id``; reject duplicate ids. + Guarded install for callers that own a trace id exclusively: a + re-registration is treated as a programming error. The promotion + chain uses `set` instead, which overwrites unconditionally. + + Args: + trace_id: The key to register ``context`` under. + context: The context to store. + Raises: ValueError: if ``trace_id`` is already registered. """ diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/exchange_context.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/exchange_context.py index 5e761be..06b6629 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/exchange_context.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/exchange_context.py @@ -28,6 +28,11 @@ class ExchangeContext(CallContext): an ``AsyncResponse`` — the immutable snapshot is recorded regardless of which pipeline produced it. + ``request`` is the request that actually produced ``response`` (i.e. + ``response.request``), not necessarily the original request the call + started with. After a redirect the two differ; recording the per-hop + request keeps the ``request`` / ``response`` pair consistent. + Note: ``slots=True`` is intentionally omitted here. Mixing a slotted dataclass into a non-slotted ABC base (``CallContext``) produces a layout that still allocates ``__dict__``, so the slots flag would not diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/request_context.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/request_context.py index cb9ceb5..f3a3596 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/request_context.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/context/request_context.py @@ -36,14 +36,26 @@ def to_exchange_context( ) -> ExchangeContext: """Promote into an `ExchangeContext` bound to ``response``. - Stores the new context in `ContextStore` keyed by trace id. + The promoted context records ``response.request`` — the request that + actually produced ``response`` — rather than ``self.request``. After a + redirect the per-hop request differs from the original one this + ``RequestContext`` was opened with; recording the response's own + request keeps ``ExchangeContext.request`` and ``response`` a pair that + truly traveled together. + + Args: + response: The response that arrived for this call. + + Returns: + The new `ExchangeContext`, also stored in `ContextStore` keyed by + trace id. """ from .context_store import ContextStore from .exchange_context import ExchangeContext promoted = ExchangeContext( instrumentation_context=self.instrumentation_context, - request=self.request, + request=response.request, response=response, ) ContextStore.set(promoted.instrumentation_context.trace_id.value, promoted) diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/file_request_body.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/file_request_body.py index d29b2a6..5baee55 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/file_request_body.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/file_request_body.py @@ -102,6 +102,9 @@ def to_replayable(self) -> RequestBody: def iter_bytes(self, chunk_size: int = _DEFAULT_CHUNK) -> Iterator[bytes]: _check_chunk_size(chunk_size) + return self._iter(chunk_size) + + def _iter(self, chunk_size: int) -> Iterator[bytes]: remaining = self._count with self._path.open("rb") as stream: if self._offset: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/loggable_request_body.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/loggable_request_body.py index e149a56..8c69c46 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/loggable_request_body.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/loggable_request_body.py @@ -70,6 +70,16 @@ def to_replayable(self) -> RequestBody: return LoggableRequestBody(self._inner.to_replayable(), self._max) def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: + # Reset the tap at the START of each iteration so a replayed body + # (retry, redirect 307) captures a single copy of the payload rather + # than accumulating ``body + body`` across attempts. The reset is + # eager — before the first ``next()`` — so ``snapshot`` reflects only + # the most recent attempt even if the returned iterator is not drained. + self._tap.seek(0) + self._tap.truncate(0) + return self._iter(chunk_size) + + def _iter(self, chunk_size: int) -> Iterator[bytes]: for chunk in self._inner.iter_bytes(chunk_size): remaining = self._max - self._tap.tell() if remaining > 0: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/multipart.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/multipart.py index aa7367a..4da5bea 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/multipart.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/multipart.py @@ -39,6 +39,27 @@ def _has_filename_star_header(headers: Sequence[tuple[str, str]]) -> bool: return any("filename*=" in v for _, v in headers) +def _reject_control_chars(label: str, value: str) -> None: + """Reject CR, LF, and NUL to prevent multipart header injection. + + Field names, filenames, and custom part-header names/values are + interpolated verbatim into CRLF-delimited part headers. An attacker- + controlled ``\\r`` or ``\\n`` (both ASCII, so the ASCII guard lets them + through) could inject additional part headers or a fabricated boundary + line. This mirrors ``Headers._check_value`` in ``http.common.headers``. + + Args: + label: Human-readable description of the rejected value, used in + the error message (e.g. ``"field name"``). + value: The candidate string to validate. + + Raises: + ValueError: If ``value`` contains ``\\r``, ``\\n``, or ``\\0``. + """ + if "\r" in value or "\n" in value or "\0" in value: + raise ValueError(f"multipart {label} contains control characters: {value!r}") + + @dataclass(frozen=True, slots=True) class MultipartField: """One part of a ``multipart/form-data`` body. @@ -56,9 +77,10 @@ class MultipartField: headers: Optional extra headers as ``(name, value)`` pairs. Raises: - ValueError: If ``name`` is not ASCII, or if ``filename`` is not - ASCII and no ``filename*=`` parameter was provided through - ``headers``. + ValueError: If ``name`` is not ASCII; if ``filename`` is not ASCII + and no ``filename*=`` parameter was provided through ``headers``; + or if ``name``, ``filename``, the rendered ``media_type``, or any + custom header name/value contains CR, LF, or NUL. """ name: str @@ -70,6 +92,9 @@ class MultipartField: def __post_init__(self) -> None: if not _is_ascii(self.name): raise ValueError(f"multipart field name must be pure ASCII: {self.name!r}") + _reject_control_chars("field name", self.name) + if self.filename is not None: + _reject_control_chars("filename", self.filename) if ( self.filename is not None and not _is_ascii(self.filename) @@ -80,6 +105,14 @@ def __post_init__(self) -> None: "MultipartField.with_utf8_filename(...) or supply a " f"filename*=UTF-8''… header: {self.filename!r}" ) + if self.media_type is not None: + # The rendered media type becomes a ``Content-Type:`` part header, + # so a subtype or parameter value carrying CR/LF would inject an + # extra header line just like a malicious filename. + _reject_control_chars("media type", str(self.media_type)) + for header_name, header_value in self.headers: + _reject_control_chars("header name", header_name) + _reject_control_chars("header value", header_value) @classmethod def with_utf8_filename( @@ -187,6 +220,12 @@ class MultipartRequestBody(RequestBody): Build via ``RequestBody.from_multipart(fields)`` or instantiate directly. The boundary is generated once at construction so retries see identical bytes (and so loggable wrappers can capture the payload deterministically). + A caller-supplied ``boundary`` is rejected if it contains CR, LF, or NUL, + since it is interpolated into delimiter and header lines. + + Raises: + ValueError: If ``fields`` is empty, or if ``boundary`` contains CR, + LF, or NUL. """ __slots__ = ("_boundary", "_payload") @@ -200,6 +239,10 @@ def __init__( if not fields: raise ValueError("at least one field is required") self._boundary = boundary or _generate_boundary() + # The boundary is interpolated into every ``--boundary`` delimiter line + # and the ``Content-Type`` header, so a caller-supplied boundary with + # CR/LF/NUL would inject delimiter or header lines into the payload. + _reject_control_chars("boundary", self._boundary) parts: list[bytes] = [_build_part(f, self._boundary) for f in fields] parts.append(f"--{self._boundary}--\r\n".encode("ascii")) self._payload = b"".join(parts) @@ -222,6 +265,9 @@ def to_replayable(self) -> RequestBody: def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: _check_chunk_size(chunk_size) + return self._iter(chunk_size) + + def _iter(self, chunk_size: int) -> Iterator[bytes]: view = memoryview(self._payload) for start in range(0, len(view), chunk_size): yield bytes(view[start : start + chunk_size]) diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/request_body.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/request_body.py index adee0ab..4c6e7b8 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/request_body.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/request/request_body.py @@ -310,6 +310,9 @@ def to_replayable(self) -> RequestBody: def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: _check_chunk_size(chunk_size) + return self._iter(chunk_size) + + def _iter(self, chunk_size: int) -> Iterator[bytes]: view = memoryview(self._data) for start in range(0, len(view), chunk_size): yield bytes(view[start : start + chunk_size]) @@ -346,6 +349,9 @@ def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: "if retries may be needed." ) self._consumed = True + return self._iter(chunk_size) + + def _iter(self, chunk_size: int) -> Iterator[bytes]: try: while True: chunk = self._stream.read(chunk_size) @@ -388,6 +394,9 @@ def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: "if retries may be needed." ) self._consumed = True + return self._iter() + + def _iter(self) -> Iterator[bytes]: yield from self._chunks diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/loggable_response_body.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/loggable_response_body.py index 6916408..0b59c7d 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/loggable_response_body.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/loggable_response_body.py @@ -10,7 +10,7 @@ from typing import Final from ..common.media_type import MediaType -from .response_body import ResponseBody +from .response_body import ResponseBody, _check_chunk_size _DEFAULT_CAP: Final[int] = (1 << 31) - 9 @@ -72,6 +72,7 @@ def content_length(self) -> int: return self._inner.content_length() def iter_bytes(self, chunk_size: int = 64 * 1024) -> Iterator[bytes]: + _check_chunk_size(chunk_size) self._drain() if self._error is not None: raise self._error diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/status.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/status.py index f8d8854..522eac6 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/status.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/response/status.py @@ -13,6 +13,15 @@ class Status(IntEnum): Inheriting from `int` so callers can compare against integers and range-check directly: ``response.status == 200`` or ``200 <= status < 300``. + + Lookup is lenient: `Status(code)` for any integer in the HTTP range + 100..599 that is not a named member returns a synthesized pseudo-member + named ``UNKNOWN_`` carrying the raw integer value. This lets + responses with unregistered-but-valid codes (for example ``218`` from + Apache or ``599`` from a proxy) flow through the SDK with their band + classification (`is_success`, `is_redirect`, ...) and integer comparisons + intact, instead of being discarded. Integers outside 100..599 (for + example ``42`` or ``1000``) remain invalid and raise `ValueError`. """ # 1xx Informational @@ -87,6 +96,25 @@ class Status(IntEnum): NOT_EXTENDED = 510 NETWORK_AUTHENTICATION_REQUIRED = 511 + @classmethod + def _missing_(cls, value: object) -> Status | None: + """Synthesize a pseudo-member for an unregistered valid HTTP code. + + Args: + value: The lookup value passed to `Status(value)`. + + Returns: + A pseudo-member carrying `value` when it is an integer in the + HTTP range 100..599 with no named member, or `None` to let the + enum machinery raise `ValueError` for any other input. + """ + if isinstance(value, int) and 100 <= value <= 599: + pseudo = int.__new__(cls, value) + pseudo._name_ = f"UNKNOWN_{value}" + pseudo._value_ = value + return pseudo + return None + @property def is_informational(self) -> bool: return 100 <= self.value < 200 diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/sse/parser.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/sse/parser.py index 8ee9aff..eef03bb 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/sse/parser.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/http/sse/parser.py @@ -13,7 +13,10 @@ - Joins multi-line ``data:`` fields with a single ``\n``. - Skips comment lines (those beginning with ``:``). - Treats blank lines as event terminators; empty events are not emitted. -- Decodes payloads as UTF-8 (the spec mandates UTF-8 for ``text/event-stream``). +- Decodes payloads as UTF-8 (the spec mandates UTF-8 for ``text/event-stream``); + invalid bytes in a complete line are replaced with U+FFFD rather than raising, + per the spec. A codepoint truncated by the stream ending surfaces as a + ``StreamingError`` from `SseParser.end`. Per-event fields: - ``data``: the (possibly multi-line) message payload. @@ -238,8 +241,9 @@ def _process_line(self, line: str) -> None: def _dispatch(self) -> None: if not self._data_lines: - # Spec: blank line with no data buffered ⇒ no event emitted, - # but event name and retry reset. + # Spec: blank line with no data buffered ⇒ no event emitted, and the + # event name resets to the default. ``retry`` is connection-level and + # deliberately persists, so it is not reset here. self._event = "message" return event = SseEvent( @@ -281,23 +285,44 @@ def _read_line(buffer: bytearray, *, at_eos: bool) -> tuple[str | None, int]: """ for index, byte in enumerate(buffer): if byte == _LF: - return buffer[:index].decode("utf-8"), index + 1 + return _decode_line(buffer[:index]), index + 1 if byte == _CR: if index + 1 < len(buffer): # CRLF when the next byte is LF, otherwise a lone-CR terminator. consumed = index + 2 if buffer[index + 1] == _LF else index + 1 - return buffer[:index].decode("utf-8"), consumed + return _decode_line(buffer[:index]), consumed # CR is the final byte: hold it open until the next byte arrives so # a split CRLF is not mistaken for two terminators. At EOS it is a # complete terminator. if not at_eos: return None, 0 - return buffer[:index].decode("utf-8"), index + 1 + return _decode_line(buffer[:index]), index + 1 if at_eos and buffer: + # Unterminated trailing residue at end-of-stream. Decode strictly so a + # truncated multi-byte codepoint cut off by the stream ending surfaces + # as a ``UnicodeDecodeError`` for ``end`` to wrap as ``StreamingError``. return buffer.decode("utf-8"), len(buffer) return None, 0 +def _decode_line(raw: bytearray) -> str: + """Decode one complete SSE line as UTF-8 with replacement on bad bytes. + + The WHATWG ``text/event-stream`` spec mandates UTF-8 decoding with the + replacement character (U+FFFD) for invalid sequences rather than failing, + so malformed bytes inside an otherwise complete line never raise out of + `SseParser.feed`. Truncated codepoints at end-of-stream are handled + separately (strictly) by `_read_line` so they can surface as errors. + + Args: + raw: The line's bytes, excluding its terminator. + + Returns: + The decoded line text, with U+FFFD substituted for invalid bytes. + """ + return bytes(raw).decode("utf-8", errors="replace") + + def parse_events(chunks: Iterable[bytes]) -> Iterator[SseEvent]: """Drive an ``SseParser`` from a sync iterable of byte chunks. diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/client_logger.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/client_logger.py index d510a27..9c5901d 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/client_logger.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/client_logger.py @@ -6,6 +6,7 @@ from __future__ import annotations import logging +import threading from typing import Any, Final from .correlation import get_span_id, get_trace_id @@ -18,6 +19,12 @@ LogLevel.VERBOSE: logging.DEBUG, } +#: Serialises the check-then-act in ``_install_correlation_filter`` so two +#: threads constructing loggers concurrently cannot both add a filter. A single +#: process-wide lock is sufficient: installation is a one-time, low-contention +#: operation gated by an idempotent membership check. +_INSTALL_LOCK: Final[threading.Lock] = threading.Lock() + class CorrelationFilter(logging.Filter): """Stamps the active trace/span ids onto every record it sees. @@ -97,10 +104,16 @@ def verbose(self, message: str, **fields: Any) -> None: def _install_correlation_filter(logger: logging.Logger) -> None: - """Attach a `CorrelationFilter` to ``logger`` exactly once.""" - if any(isinstance(existing, CorrelationFilter) for existing in logger.filters): - return - logger.addFilter(CorrelationFilter()) + """Attach a `CorrelationFilter` to ``logger`` exactly once. + + The membership check and the ``addFilter`` call run under a process-wide + lock so concurrent ``ClientLogger`` construction on the same logger can + never install two filters (the check-then-act would otherwise race). + """ + with _INSTALL_LOCK: + if any(isinstance(existing, CorrelationFilter) for existing in logger.filters): + return + logger.addFilter(CorrelationFilter()) def _correlation_fields() -> dict[str, str]: @@ -119,7 +132,7 @@ def _format_fields(fields: dict[str, Any]) -> str: parts: list[str] = [] for key, value in fields.items(): rendered = str(value) - if any(c in rendered for c in ' \t"\n\r'): + if any(c in rendered for c in ' \t"\n\r='): rendered = ( '"' + rendered.replace("\\", "\\\\") diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/http_tracer.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/http_tracer.py index 09a574f..a93554c 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/http_tracer.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/http_tracer.py @@ -78,11 +78,13 @@ def request_url_resolved(self, url: str) -> None: url: The absolute URL the attempt targets. """ - def request_sent(self, byte_count: int) -> None: + def request_sent(self, byte_count: int | None) -> None: """The request body finished writing to the wire. Args: - byte_count: Number of body bytes written. + byte_count: Number of body bytes written, or ``None`` when the + length is unknown (for example a streamed body whose size is + not declared up front). """ def response_headers_received(self, status: int, headers: Mapping[str, str]) -> None: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/metrics.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/metrics.py index 288376d..05fc009 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/metrics.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/metrics.py @@ -1,7 +1,7 @@ # Copyright (c) 2026 dexpace and Omar Aljarrah. # Licensed under the MIT License. See LICENSE.md in the repository root for details. -"""Metrics SPI — Counter / Histogram / Gauge ABCs and no-op singletons. +"""Metrics SPI — Counter / UpDownCounter / Histogram ABCs and no-op singletons. Mirrors the OpenTelemetry-style metric primitives but stays no-deps. A real implementation lives in a sibling package (e.g. diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/url_redactor.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/url_redactor.py index 23b492c..760ab25 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/url_redactor.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/instrumentation/url_redactor.py @@ -38,6 +38,10 @@ _REDACTED: Final[str] = "REDACTED" _REDACTED_PATH: Final[str] = "/REDACTED" +#: Returned in place of a URL that cannot be parsed. Failing closed keeps an +#: unparseable, possibly secret-bearing string out of the logs entirely. +_REDACTED_UNPARSEABLE: Final[str] = "REDACTED:unparseable" + class UrlRedactor: """Strip userinfo and non-allowlisted query parameters from a URL. @@ -75,18 +79,20 @@ def redact(self, url: str | Url) -> str: Args: url: Either a parsed ``Url`` or a wire-form string. Strings are - parsed via ``Url.parse``; parse failures fall through to - returning the input unchanged (so logging never silently - drops a URL because it's malformed). + parsed via ``Url.parse``; parse failures fail closed and return + the constant ``"REDACTED:unparseable"`` rather than the input, + so a malformed string that may embed a secret never reaches the + log sink verbatim. Returns: A wire-form URL with userinfo stripped and each non-allowlisted parameter collapsed to ``REDACTED=REDACTED`` (both key and value), - so neither the secret nor the parameter name leaks. + so neither the secret nor the parameter name leaks. When the input + string cannot be parsed, returns ``"REDACTED:unparseable"``. """ parsed = url if isinstance(url, Url) else _safe_parse(url) if parsed is None: - return str(url) + return _REDACTED_UNPARSEABLE return str(self._redact_parsed(parsed)) def _redact_parsed(self, parsed: Url) -> Url: @@ -111,9 +117,24 @@ def _redact_parsed(self, parsed: Url) -> Url: def _safe_parse(raw: str) -> Url | None: + """Parse ``raw`` into a ``Url`` or return ``None`` if it is malformed. + + ``Url.parse`` delegates to ``furl``, which raises stdlib exceptions for + bad input (``ValueError`` for missing scheme/host or an invalid port, + ``UnicodeError`` for IDNA failures, and ``LookupError`` / ``TypeError`` + for other structural surprises). Catching the broad set keeps the caller + failing closed instead of letting an exotic furl error escape into the + logging path. + + Args: + raw: A wire-form URL string. + + Returns: + The parsed ``Url``, or ``None`` if ``raw`` could not be parsed. + """ try: return Url.parse(raw) - except ValueError: + except (ValueError, LookupError, TypeError): return None diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/link_header.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/link_header.py index c888ad8..efc50a1 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/link_header.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/link_header.py @@ -69,9 +69,17 @@ def _iter_links(value: str) -> Iterator[ParsedLink]: def _split_links(value: str) -> Iterator[str]: - """Split a header into link-value segments, ignoring commas inside quotes.""" + """Split a header into link-value segments on the link-value separators. + + Commas inside a quoted parameter value (between ``"``) or inside the + bracketed ```` target are not separators: a comma is a legal + unencoded URI sub-delim (e.g. ``?fields=a,b``), so splitting on it would + shred the target. The angle-bracket depth is tracked alongside quote + state and suppresses the split while inside ``<...>``. + """ buffer: list[str] = [] in_quotes = False + in_brackets = False escaped = False for char in value: if escaped: @@ -83,7 +91,13 @@ def _split_links(value: str) -> Iterator[str]: elif char == '"': in_quotes = not in_quotes buffer.append(char) - elif char == "," and not in_quotes: + elif char == "<" and not in_quotes: + in_brackets = True + buffer.append(char) + elif char == ">" and not in_quotes: + in_brackets = False + buffer.append(char) + elif char == "," and not in_quotes and not in_brackets: yield "".join(buffer) buffer = [] else: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/paginator.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/paginator.py index 1ac7c1c..89c35e1 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/paginator.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/paginator.py @@ -26,10 +26,12 @@ from __future__ import annotations +import inspect import json from collections.abc import AsyncIterator, Callable, Iterator from typing import TYPE_CHECKING +from ..errors import DeserializationError from ..http.context.dispatch_context import DispatchContext from ..pipeline.dispatch import ( AsyncPipelineLike, @@ -47,11 +49,59 @@ def _decode_body(raw: str) -> object: - """Decode a JSON body string into a Python value (``None`` when empty).""" + """Decode a JSON body string into a Python value (``None`` when empty). + + Args: + raw: The response body text. An empty or whitespace-only body + decodes to ``None``. + + Returns: + The decoded Python value, or ``None`` for an empty body. + + Raises: + DeserializationError: When the body is not well-formed JSON (e.g. an + HTML error page returned with a 200 by a load balancer). The + underlying ``json.JSONDecodeError`` never escapes the SDK error + hierarchy. + """ text = raw.strip() if not text: return None - return json.loads(text) + try: + return json.loads(text) + except json.JSONDecodeError as err: + raise DeserializationError("pagination response body is not valid JSON") from err + + +def _decode_for(raw: str | None, request: Request) -> object: + """Decode a page body, stamping the failing request for resumption. + + Mirrors ``Pager``'s resume contract: when decoding fails, the request + URL that produced the unparseable page is stamped onto the + ``DeserializationError`` (as its ``continuation_token``) so a caller can + rebuild the same request and retry from exactly that page rather than + restarting the whole sequence. + + Args: + raw: The response body text, or ``None`` when the response had no + body. + request: The request that produced ``raw``; its URL becomes the + resume token on a decode failure. + + Returns: + The decoded Python value, or ``None`` for an absent or empty body. + + Raises: + DeserializationError: When the body is not well-formed JSON. + """ + if raw is None: + return None + try: + return _decode_body(raw) + except DeserializationError as err: + if err.continuation_token is None: + err.continuation_token = str(request.url) + raise class Paginator[T]: @@ -89,11 +139,20 @@ def __init__( def _normalise(self, source: SyncPipelineLike | SendSync) -> SendSync: if isinstance(source, SyncPipelineLike): pipeline = source + if inspect.iscoroutinefunction(pipeline.run): + raise TypeError( + "Paginator was given an async pipeline; its run() is a " + "coroutine function. Use AsyncPaginator for async pipelines.", + ) def send(request: Request) -> Response: return pipeline.run(request, self._dispatch_factory()) return send + if inspect.iscoroutinefunction(source): + raise TypeError( + "Paginator was given an async send-callable; use AsyncPaginator.", + ) return source def by_page(self) -> Iterator[Page[T]]: @@ -116,7 +175,8 @@ def by_page(self) -> Iterator[Page[T]]: request = page.next_request def _parse(self, response: Response) -> Page[T]: - payload = _decode_body(response.body.string()) if response.body is not None else None + raw = response.body.string() if response.body is not None else None + payload = _decode_for(raw, response.request) return self._strategy.parse(response, payload, response.request) def __iter__(self) -> Iterator[T]: @@ -152,11 +212,20 @@ def __init__( def _normalise(self, source: AsyncPipelineLike | SendAsync) -> SendAsync: if isinstance(source, AsyncPipelineLike): pipeline = source + if not inspect.iscoroutinefunction(pipeline.run): + raise TypeError( + "AsyncPaginator was given a sync pipeline; its run() is " + "not a coroutine function. Use Paginator for sync pipelines.", + ) async def send(request: Request) -> AsyncResponse: return await pipeline.run(request, self._dispatch_factory()) return send + if not inspect.iscoroutinefunction(source): + raise TypeError( + "AsyncPaginator was given a sync send-callable; use Paginator.", + ) return source async def by_page(self) -> AsyncIterator[Page[T]]: @@ -173,7 +242,8 @@ async def by_page(self) -> AsyncIterator[Page[T]]: request = page.next_request async def _parse(self, response: AsyncResponse) -> Page[T]: - payload = _decode_body(await response.body.string()) if response.body is not None else None + raw = await response.body.string() if response.body is not None else None + payload = _decode_for(raw, response.request) return self._strategy.parse(response, payload, response.request) def __aiter__(self) -> AsyncIterator[T]: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/strategy.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/strategy.py index 5aa09e7..8e2d029 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/strategy.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pagination/strategy.py @@ -29,7 +29,7 @@ from urllib.parse import urljoin from ..http.common.url import Url -from .link_header import find_rel +from .link_header import parse_link_header from .page import Page if TYPE_CHECKING: @@ -102,6 +102,52 @@ def _with_query_param(request: Request, name: str, value: str) -> Request: return request.with_url(url.with_query(url.query.with_set(name, value))) +def _rel_targets(header: str) -> dict[str, str]: + """Map each relation type in ``header`` to its first link-value target. + + The header is parsed once and every link-value's space-separated ``rel`` + set is expanded so a single pass resolves any relation (``next``, + ``prev``, ...). The first target wins per relation, matching the + first-match semantics of a direct lookup. + + Args: + header: The folded raw ``Link`` header value (all lines joined). + + Returns: + A mapping from lower-cased relation type to its target URI. + """ + targets: dict[str, str] = {} + for target, params in parse_link_header(header): + for rel in params.get("rel", "").split(): + targets.setdefault(rel.casefold(), target) + return targets + + +def _coerce_cursor(cursor: object) -> str: + """Coerce a decoded cursor value to its query-parameter string form. + + Real APIs return cursors as strings *or* numbers (``"next_cursor": + 17283``); a numeric cursor must still drive the next request rather than + silently ending the sequence. Strings pass through, ``int`` and ``float`` + are stringified, and everything else (``None``, ``bool``, containers) + yields the empty string, which the caller treats as exhaustion. + + Args: + cursor: The decoded cursor value dug out of the response body. + + Returns: + The cursor as a non-empty string when it is a usable scalar, or the + empty string when the sequence is exhausted. + """ + if isinstance(cursor, str): + return cursor + if isinstance(cursor, bool): + return "" + if isinstance(cursor, (int, float)): + return str(cursor) + return "" + + @dataclass(frozen=True, slots=True) class CursorStrategy[T]: """Cursor / continuation-token pagination. @@ -115,7 +161,9 @@ class CursorStrategy[T]: items_field: Dotted path to the item list in the body (e.g. ``"data"`` or ``"result.items"``). cursor_response_field: Dotted path to the cursor in the body. An - absent, empty, or ``null`` value ends the sequence. + absent, empty, ``null``, or boolean value ends the sequence; a + non-empty scalar cursor (string, ``int``, or ``float``) is sent + verbatim, coerced to its string form for the query parameter. cursor_param: Query-parameter name to carry the cursor on the next request. """ @@ -131,9 +179,9 @@ def parse( template_request: Request, ) -> Page[T]: items: list[T] = _items_at(payload, self.items_field.split(".")) - cursor = _dig(payload, self.cursor_response_field.split(".")) + cursor = _coerce_cursor(_dig(payload, self.cursor_response_field.split("."))) next_request: Request | None = None - if isinstance(cursor, str) and cursor: + if cursor: next_request = _with_query_param(template_request, self.cursor_param, cursor) return Page(items=items, next_request=next_request, raw=response) @@ -213,6 +261,12 @@ class LinkHeaderStrategy[T]: RFC 5988) is resolved against the template request's URL, so an API that returns ```` rather than an absolute URI still paginates. + RFC 9110 permits the ``Link`` header to be split across multiple header + lines (a ``rel="prev"`` line and a ``rel="next"`` line, say); every line + is folded together before parsing so a ``next`` relation on a later line + is never dropped. The header is parsed a single time and both the + ``next`` and ``prev`` targets are resolved from that one pass. + Args: items_field: Dotted path to the item list in the body. link_header_name: Header to read link relations from (default @@ -229,9 +283,10 @@ def parse( template_request: Request, ) -> Page[T]: items: list[T] = _items_at(payload, self.items_field.split(".")) - header = response.headers.get(self.link_header_name) or "" - next_request = self._request_for(header, "next", template_request) - prev_request = self._request_for(header, "prev", template_request) + header = ", ".join(response.headers.values(self.link_header_name)) + targets = _rel_targets(header) + next_request = self._request_for(targets.get("next"), template_request) + prev_request = self._request_for(targets.get("prev"), template_request) return Page( items=items, next_request=next_request, @@ -240,8 +295,7 @@ def parse( ) @staticmethod - def _request_for(header: str, rel: str, template: Request) -> Request | None: - target = find_rel(header, rel) + def _request_for(target: str | None, template: Request) -> Request | None: if target is None: return None absolute = urljoin(str(template.url), target) diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/_async_transport_runner.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/_async_transport_runner.py index 38b387a..9977c5d 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/_async_transport_runner.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/_async_transport_runner.py @@ -21,7 +21,10 @@ class _AsyncTransportRunner(AsyncPolicy): Like the sync transport runner, side-effects the promotion-chain context to record the exchange in the ``ContextStore`` after the - response arrives, but does not reassign ``ctx.call``. + response arrives, but does not reassign ``ctx.call``. The recorded + ``ExchangeContext.request`` is ``response.request`` — the per-hop request + that produced the response, which differs from ``ctx.call.request`` after + a redirect. """ __slots__ = ("_client",) diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/_transport_runner.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/_transport_runner.py index 9f02b95..16f00ce 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/_transport_runner.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/_transport_runner.py @@ -23,7 +23,10 @@ class _TransportRunner(Policy): runner promotes the immutable telemetry context to an ``ExchangeContext`` once the response is in hand so post-exchange observers (logging, tracing) can look up the latest snapshot via ``ContextStore``. The - promotion is a snapshot update; ``ctx.call`` itself is not reassigned. + promotion records ``response.request`` — the per-hop request that actually + produced the response, which differs from ``ctx.call.request`` after a + redirect. The promotion is a snapshot update; ``ctx.call`` itself is not + reassigned. """ __slots__ = ("_client",) diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/async_pipeline.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/async_pipeline.py index 40b8d2a..9323695 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/async_pipeline.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/async_pipeline.py @@ -6,6 +6,7 @@ from __future__ import annotations from collections.abc import Awaitable, Callable, Sequence +from itertools import pairwise from types import TracebackType from typing import TYPE_CHECKING, Any, Self @@ -42,9 +43,14 @@ class AsyncPipeline: """Composes an ordered sequence of async policies around an ``AsyncHttpClient``. - Mirrors ``Pipeline`` exactly with ``async`` semantics. SansIO steps are - auto-wrapped via the ``side`` attribute (``"request"`` / ``"response"``, - default ``"request"``). The terminal node is an ``_AsyncTransportRunner``. + Mirrors ``Pipeline`` exactly with ``async`` semantics. Bare-callable + SansIO steps are auto-wrapped via an optional ``side`` attribute on the + callable (``"request"`` / ``"response"``, default ``"request"``). The + terminal node is an ``_AsyncTransportRunner``. + + Each ``AsyncPolicy`` instance is owned by a single pipeline: passing one + already wired into another pipeline raises ``ValueError`` rather than + silently re-pointing the original chain. Use as an async context manager so transport ``aclose`` (when defined) runs deterministically:: @@ -65,11 +71,8 @@ def __init__( entry if isinstance(entry, AsyncPolicy) else _wrap_step(entry) for entry in (policies or []) ] - for i, policy in enumerate(wrapped[:-1]): - policy.next = wrapped[i + 1] terminal = _AsyncTransportRunner(transport) - if wrapped: - wrapped[-1].next = terminal + _wire_chain(wrapped, terminal) self._chain: AsyncPolicy = wrapped[0] if wrapped else terminal async def __aenter__(self) -> Self: @@ -105,6 +108,36 @@ async def run( request_ctx.close() +def _wire_chain(wrapped: list[AsyncPolicy], terminal: AsyncPolicy) -> None: + """Link each policy's ``.next`` to the following node, ending at ``terminal``. + + Detects reuse before mutating any state: a caller-supplied policy whose + ``.next`` is already set belongs to another pipeline, and re-pointing it + here would silently corrupt that pipeline's chain. Such reuse raises + ``ValueError`` instead, leaving every instance untouched. + + Args: + wrapped: In-order policies; freshly wrapped SansIO runners carry no + ``.next`` yet, so only reused caller policies trip the guard. + terminal: The transport runner appended after the last policy. + + Raises: + ValueError: If any policy already has its ``.next`` wired, which + means it is owned by a different pipeline. + """ + for policy in wrapped: + if getattr(policy, "next", None) is not None: + raise ValueError( + f"{type(policy).__name__} is already wired into another pipeline; " + f"an AsyncPolicy instance is owned by a single pipeline. Construct a " + f"fresh instance for each pipeline instead of sharing one." + ) + for current, following in pairwise(wrapped): + current.next = following + if wrapped: + wrapped[-1].next = terminal + + def _wrap_step(step: Any) -> AsyncPolicy: if not callable(step): raise TypeError(f"Pipeline step {step!r} is neither an AsyncPolicy nor a callable.") diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/async_staged_builder.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/async_staged_builder.py index 34be840..9b256e9 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/async_staged_builder.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/async_staged_builder.py @@ -11,6 +11,7 @@ from __future__ import annotations +import contextlib from typing import TYPE_CHECKING, Self from .async_pipeline import AsyncPipeline @@ -55,11 +56,15 @@ def prepend(self, policy: AsyncPolicy, *, force: bool = False) -> Self: def replace(self, target: type[AsyncPolicy], new: AsyncPolicy) -> Self: """Replace the first instance of ``target`` with ``new``.""" - for stage, pillar in self._pillars.items(): - if isinstance(pillar, target): - del self._pillars[stage] - self.append(new, force=True) - return self + pillar_stage = next( + (stage for stage, pillar in self._pillars.items() if isinstance(pillar, target)), + None, + ) + if pillar_stage is not None: + # The lookup above finished iterating before we mutate ``_pillars``. + del self._pillars[pillar_stage] + self.append(new, force=True) + return self for stage, bucket in self._buckets.items(): for i, p in enumerate(bucket): if isinstance(p, target): @@ -96,6 +101,11 @@ def build(self) -> AsyncPipeline: def from_pipeline(cls, pipeline: AsyncPipeline) -> Self: """Seed a builder from an existing `AsyncPipeline`. + The harvested policy instances are detached from ``pipeline`` (their + ``.next`` links are cleared) so they can be re-wired into the rebuilt + pipeline. ``pipeline`` is consumed by this call — each policy is owned + by a single pipeline, so the source pipeline must not be run again. + Raises: ValueError: If the input pipeline's policies do not satisfy stage ordering, or if the chain contains a list-constructor @@ -126,6 +136,7 @@ def from_pipeline(cls, pipeline: AsyncPipeline) -> Self: f"non-decreasing stage order. Use the list constructor instead." ) last_stage = policy.STAGE + _detach(policy) builder.append(policy, force=True) return builder @@ -170,4 +181,19 @@ def _reload(self, policies: list[AsyncPolicy]) -> None: self._buckets.setdefault(p.STAGE, []).append(p) +def _detach(policy: AsyncPolicy) -> None: + """Clear ``policy.next`` so the instance can be re-wired into a new chain. + + A policy harvested from an existing pipeline still points at that + pipeline's chain. Clearing the link makes it look freshly constructed to + ``AsyncPipeline``'s single-ownership guard, allowing the rebuild to + re-wire it. + + Args: + policy: The policy whose ``.next`` link is removed if present. + """ + with contextlib.suppress(AttributeError): + del policy.next + + __all__ = ["AsyncStagedPipelineBuilder"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/pipeline.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/pipeline.py index 3b65ae7..4a481e5 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/pipeline.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/pipeline.py @@ -6,6 +6,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence +from itertools import pairwise from types import TracebackType from typing import TYPE_CHECKING, Any, Self @@ -40,11 +41,17 @@ class Pipeline: with Pipeline(transport, policies=[retry, auth, logger]) as p: response = p.run(request, dispatch_ctx) - SansIO steps in the list are auto-wrapped depending on whether they have - a ``request_side`` or ``response_side`` attribute (set on the callable - by the SDK's built-in step decorators) — by default the runner assumes a - request-side step. For policies that need explicit chain control - (retry, auth challenges), implement the ``Policy`` ABC directly. + Bare-callable SansIO steps in the list are auto-wrapped according to an + optional ``side`` attribute on the callable, valued ``"request"`` or + ``"response"``. Callables without a ``side`` attribute default to the + request side, which suits the common case (header stamping, redaction). + For policies that need explicit chain control (retry, auth challenges), + implement the ``Policy`` ABC directly. + + Each ``Policy`` instance is owned by a single pipeline: its ``.next`` is + wired in place at construction. Passing a policy instance that is already + wired into another pipeline raises ``ValueError`` rather than silently + re-pointing the original chain. Attributes: transport: The terminal HTTP client. @@ -69,16 +76,16 @@ def __init__( Raises: TypeError: If an entry in ``policies`` is neither a ``Policy`` nor a callable matching the SansIO step shape. + ValueError: If a ``Policy`` instance is already wired into + another pipeline (its ``.next`` is set). Each policy + instance is owned by a single pipeline. """ self.transport = transport wrapped: list[Policy] = [ entry if isinstance(entry, Policy) else _wrap_step(entry) for entry in (policies or []) ] - for i, policy in enumerate(wrapped[:-1]): - policy.next = wrapped[i + 1] terminal = _TransportRunner(transport) - if wrapped: - wrapped[-1].next = terminal + _wire_chain(wrapped, terminal) self._chain: Policy = wrapped[0] if wrapped else terminal def __enter__(self) -> Self: @@ -125,6 +132,36 @@ def run( request_ctx.close() +def _wire_chain(wrapped: list[Policy], terminal: Policy) -> None: + """Link each policy's ``.next`` to the following node, ending at ``terminal``. + + Detects reuse before mutating any state: a caller-supplied policy whose + ``.next`` is already set belongs to another pipeline, and re-pointing it + here would silently corrupt that pipeline's chain. Such reuse raises + ``ValueError`` instead, leaving every instance untouched. + + Args: + wrapped: In-order policies; freshly wrapped SansIO runners carry no + ``.next`` yet, so only reused caller policies trip the guard. + terminal: The transport runner appended after the last policy. + + Raises: + ValueError: If any policy already has its ``.next`` wired, which + means it is owned by a different pipeline. + """ + for policy in wrapped: + if getattr(policy, "next", None) is not None: + raise ValueError( + f"{type(policy).__name__} is already wired into another pipeline; " + f"a Policy instance is owned by a single pipeline. Construct a fresh " + f"instance for each pipeline instead of sharing one." + ) + for current, following in pairwise(wrapped): + current.next = following + if wrapped: + wrapped[-1].next = terminal + + def _wrap_step(step: Any) -> Policy: """Wrap a SansIO step in the right runner Policy. diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_redirect.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_redirect.py index 84ce23a..d60227b 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_redirect.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_redirect.py @@ -4,9 +4,11 @@ """Async twin of ``RedirectPolicy``. Mirrors `RedirectPolicy` exactly — same status-code matrix, same -credential stripping, same loop guard — but ``send`` is ``async`` and -operates on ``AsyncResponse``. The per-hop decision helpers are shared via -delegation to a wrapped sync ``RedirectPolicy`` instance. +cross-origin credential stripping (a caller-set ``Authorization`` header is +dropped only when the reissue crosses origin), same loop guard — but +``send`` is ``async`` and operates on ``AsyncResponse``. The per-hop +decision helpers are shared via delegation to a wrapped sync +``RedirectPolicy`` instance, so the cross-origin behaviour is identical. """ from __future__ import annotations diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_retry.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_retry.py index 22747dd..2426ba9 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_retry.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/async_retry.py @@ -7,6 +7,10 @@ delegating into the same private methods on ``RetryPolicy``. The async twin reimplements only the dispatch loop (using ``await``) and the sleep helper (using an async sleep callable). + +Auto-buffering a single-use body for replay drains a synchronous +iterator/stream, which is blocking; the async loop offloads that drain to a +worker thread via ``asyncio.to_thread`` so the event loop keeps running. """ from __future__ import annotations @@ -136,9 +140,15 @@ def no_retries(cls) -> AsyncRetryPolicy: async def send(self, request: Request, ctx: PipelineContext) -> AsyncResponse: cfg = self.config - if cfg.total_retries > 0 and request.body is not None and not request.body.is_replayable(): - request = request.with_body(request.body.to_replayable()) settings = cfg._configure_settings(ctx.options) + body = request.body + if settings["total"] > 0 and body is not None and not body.is_replayable(): + # ``to_replayable`` synchronously drains the sync iterator/stream + # backing the body. Offload that blocking read to a worker thread + # so the event loop is not stalled while a file or socket-backed + # body is buffered into memory. + replayed = await asyncio.to_thread(body.to_replayable) + request = request.with_body(replayed) absolute_deadline = self._clock.monotonic() + settings["timeout"] history: list[RequestHistory[AsyncResponse]] = settings["history"] tracer = resolve_http_tracer(ctx) @@ -156,7 +166,7 @@ async def send(self, request: Request, ctx: PipelineContext) -> AsyncResponse: raise except SdkError as err: history.append(RequestHistory(request=request, error=err)) - if not cfg._decrement_for_error(settings, err): + if not cfg._decrement_for_error(settings, request, err): tracer.attempt_retries_exhausted() ctx.data["retry_history"] = tuple(history) raise diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/redirect.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/redirect.py index 1d68a37..fcef930 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/redirect.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/redirect.py @@ -5,11 +5,12 @@ Walks the response's ``Location`` header through up to ``max_hops`` intermediate responses, following the per-status method/body rules from -RFC 7231 §6.4 and RFC 7538 / RFC 7231 §6.4.7. Credentials are stripped on -every reissue by default (``Authorization`` header dropped, ``userinfo`` -in the ``Location`` URL discarded); loops are detected via a visited-URL -set and cause the policy to return the current response instead of -raising. +RFC 7231 §6.4 and RFC 7538 / RFC 7231 §6.4.7. A caller-set +``Authorization`` header is stripped by default only when the reissue +crosses origin relative to the request being redirected; same-origin hops +keep it. ``userinfo`` in the ``Location`` URL is always discarded. Loops +are detected via a visited-URL set and cause the policy to return the +current response instead of raising. Status-code matrix: @@ -42,6 +43,27 @@ _REDIRECT_STATUSES: frozenset[int] = frozenset({301, 302, 303, 307, 308}) _CONTENT_HEADER_PREFIX: str = "content-" +_DEFAULT_PORTS: dict[str, int] = {"https": 443, "http": 80} + + +def _origin(url: Url) -> tuple[str, str, int | None]: + """Return the ``(scheme, host, port)`` origin tuple for ``url``. + + The scheme and host are lower-cased and the port is resolved to its + scheme default (443 for https, 80 for http) when not explicit, so two + URLs that differ only in an implied/explicit default port compare equal. + + Args: + url: The URL to derive an origin from. + + Returns: + A ``(scheme, host, effective_port)`` tuple suitable for equality + comparison. + """ + scheme = url.scheme.lower() + port = url.port if url.port is not None else _DEFAULT_PORTS.get(scheme) + return scheme, url.host.lower(), port + #: ``ctx.data`` key holding the per-operation ``HttpTracer``. The first policy #: in the chain to need it mints one from the call's @@ -91,10 +113,13 @@ class RedirectPolicy(Policy): allowed_methods: Methods that are followed on ``301`` / ``302`` / ``307`` / ``308``. ``303`` is always rewritten to ``GET`` (which is implicitly allowed). Defaults to ``{GET, HEAD}``. - strip_authorization: When ``True`` (the default), the - ``Authorization`` header is stripped before every redirect - reissue. Set ``False`` only when the redirect chain is - same-origin and the caller has audited the destinations. + strip_authorization: When ``True`` (the default), a caller-set + ``Authorization`` header is stripped before a redirect reissue + only when the reissue crosses origin — a change in scheme, host, + or effective port — relative to the request being redirected. + Same-origin hops (e.g. a trailing-slash 301) keep the header. + Set ``False`` to never strip, e.g. when the caller has audited + every destination in the chain. Example: ```python @@ -182,19 +207,20 @@ def _build_next_request( RuntimeError: 307/308 with a non-replayable body. """ next_url = self._resolve_location(request.url, location) + cross_origin = _origin(request.url) != _origin(next_url) if status == 303: if not self.follow_303: return None - return self._reissue_as_get(request, next_url) + return self._reissue_as_get(request, next_url, cross_origin=cross_origin) # 301, 302, 307, 308 all require the original method to be allowed. if request.method not in self.allowed_methods: return None if status in (307, 308): - return self._reissue_preserving_body(request, next_url) + return self._reissue_preserving_body(request, next_url, cross_origin=cross_origin) # 301 / 302: follow with the original method; body carries over # (matches Java's DefaultRedirectStep — caller can downgrade to GET # by setting follow_303 plus allowing only safe methods). - return self._reissue_preserving_body(request, next_url) + return self._reissue_preserving_body(request, next_url, cross_origin=cross_origin) def _resolve_location(self, base: Url, location: str) -> Url: """Resolve a possibly-relative Location header into an absolute Url. @@ -208,27 +234,58 @@ def _resolve_location(self, base: Url, location: str) -> Url: return parsed return replace(parsed, userinfo=None) - def _reissue_as_get(self, request: Request, next_url: Url) -> Request: + def _reissue_as_get( + self, + request: Request, + next_url: Url, + *, + cross_origin: bool, + ) -> Request: """Build the reissued GET for a 303 hop. Drops the request body and every ``Content-*`` header (per RFC 7231 - §6.4.4 — the body no longer applies to a GET). + §6.4.4 — the body no longer applies to a GET). A caller-set + ``Authorization`` header is stripped only on a cross-origin hop when + ``strip_authorization`` is enabled. + + Args: + request: The request being redirected (the current hop). + next_url: The resolved absolute target of the redirect. + cross_origin: Whether ``next_url`` differs in origin from + ``request.url``. """ stripped = request.with_method(Method.GET).with_url(next_url).with_body(None) for name in tuple(stripped.headers): if name.startswith(_CONTENT_HEADER_PREFIX): stripped = stripped.without_header(name) - if self.strip_authorization: + if self.strip_authorization and cross_origin: stripped = stripped.without_header("Authorization") return stripped - def _reissue_preserving_body(self, request: Request, next_url: Url) -> Request: + def _reissue_preserving_body( + self, + request: Request, + next_url: Url, + *, + cross_origin: bool, + ) -> Request: """Build the reissued request for 301/302/307/308 hops. 307/308 must preserve the body, so a non-replayable body raises ``RuntimeError`` — sending the same payload twice with a single-use body is not possible. 301/302 also carry the body (matches Java's - ``DefaultRedirectStep``); the same replay requirement applies. + ``DefaultRedirectStep``); the same replay requirement applies. A + caller-set ``Authorization`` header is stripped only on a cross-origin + hop when ``strip_authorization`` is enabled. + + Args: + request: The request being redirected (the current hop). + next_url: The resolved absolute target of the redirect. + cross_origin: Whether ``next_url`` differs in origin from + ``request.url``. + + Raises: + RuntimeError: 307/308 with a non-replayable body. """ body = request.body if body is not None and not body.is_replayable(): @@ -238,7 +295,7 @@ def _reissue_preserving_body(self, request: Request, next_url: Url) -> Request: "expected." ) reissued = request.with_url(next_url) - if self.strip_authorization: + if self.strip_authorization and cross_origin: reissued = reissued.without_header("Authorization") return reissued diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/retry.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/retry.py index fb2e758..0b94b73 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/retry.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/retry.py @@ -10,11 +10,14 @@ Single-use request bodies (``RequestBody.from_stream`` / ``RequestBody.from_iter``) are auto-buffered at the top of ``send`` when -``total_retries > 0``: the policy calls ``body.to_replayable()`` and swaps -the result onto the request before the first attempt, so a retry can -re-emit the same payload without raising ``RuntimeError``. The buffering -step is skipped when ``total_retries == 0`` so callers who explicitly opt -out of retries pay no memory cost. +the *effective* per-call retry total is positive: the policy calls +``body.to_replayable()`` and swaps the result onto the request before the +first attempt, so a retry can re-emit the same payload without raising +``RuntimeError``. The decision reads the effective total after merging +per-call overrides (``retry_total`` in ``ctx.options``) with the +constructor default, so a per-call ``retry_total=3`` over an instance built +with ``total_retries=0`` still buffers, and a per-call ``retry_total=0`` +over a retrying instance skips the buffering and pays no memory cost. """ from __future__ import annotations @@ -220,9 +223,9 @@ def no_retries(cls) -> RetryPolicy: # ----- main loop ------------------------------------------------------ def send(self, request: Request, ctx: PipelineContext) -> Response: - if self.total_retries > 0 and request.body is not None and not request.body.is_replayable(): - request = request.with_body(request.body.to_replayable()) settings = self._configure_settings(ctx.options) + if settings["total"] > 0 and request.body is not None and not request.body.is_replayable(): + request = request.with_body(request.body.to_replayable()) absolute_deadline = self._clock.monotonic() + settings["timeout"] history: list[RequestHistory[Response]] = settings["history"] tracer = resolve_http_tracer(ctx) @@ -234,7 +237,7 @@ def send(self, request: Request, ctx: PipelineContext) -> Response: raise except SdkError as err: history.append(RequestHistory(request=request, error=err)) - if not self._decrement_for_error(settings, err): + if not self._decrement_for_error(settings, request, err): tracer.attempt_retries_exhausted() ctx.data["retry_history"] = tuple(history) raise @@ -327,17 +330,35 @@ def _decrement_status(self, settings: dict[str, Any]) -> bool: def _decrement_for_error( self, settings: dict[str, Any], + request: Request, error: BaseException, ) -> bool: """Decrement counters after a network-side error. + ``ServiceRequestError`` is a connect-phase failure: the request never + left the client, so re-sending it is safe for every method. A + ``ServiceResponseError`` is a read-phase failure — the request may + have been fully processed before the read broke — so re-sending a + non-idempotent method (POST/PATCH not in the allowlist) risks + duplicating the write. Those methods are not retried on the read path, + mirroring the careful status-path rule in ``_method_is_retryable``. + + Args: + settings: Mutable per-call settings dict. + request: The request that triggered the error, used to gate + read-phase retries on the method's idempotency. + error: The error raised by the downstream chain. + Returns: ``True`` if the budget allows another attempt. """ - settings["total"] -= 1 if isinstance(error, ServiceRequestError): + settings["total"] -= 1 settings["connect"] -= 1 elif isinstance(error, ServiceResponseError): + if not self._method_is_retryable(settings, request, None): + return False + settings["total"] -= 1 settings["read"] -= 1 else: # pragma: no cover - upstream raised something we don't classify return False diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/tracing_policy.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/tracing_policy.py index 4e5760e..3af7d33 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/tracing_policy.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/policies/tracing_policy.py @@ -36,6 +36,14 @@ from ...instrumentation import HttpTracer, Span from ..context import PipelineContext +#: ``ctx.data`` flag marking that ``operation_started`` has already fired for +#: this operation. Because ``TracingPolicy`` sits inside RETRY / REDIRECT it is +#: re-entered once per attempt / hop; the flag de-duplicates the operation-level +#: lifecycle events so ``operation_started`` fires once on the outermost entry +#: and ``operation_succeeded`` / ``operation_failed`` fire once on the +#: outermost exit. Per-attempt span behaviour is unaffected. +_OPERATION_STARTED_KEY: str = "tracing_operation_started" + class TracingPolicy(Policy): """Wrap each request in a tracing span. @@ -75,9 +83,15 @@ def send(self, request: Request, ctx: PipelineContext) -> Response: http_tracer = resolve_http_tracer(ctx) span = self._tracer.start_span(f"HTTP {request.method}", parent=parent) _set_request_attributes(span, request) - http_tracer.operation_started() + # ``operation_started`` fires once per operation. Because this policy is + # re-entered per retry attempt / redirect hop, only the outermost entry + # (the one that mints the flag) emits the operation lifecycle events. + is_outermost = _OPERATION_STARTED_KEY not in ctx.data + if is_outermost: + ctx.data[_OPERATION_STARTED_KEY] = True + http_tracer.operation_started() with bind_correlation(trace_id=_trace_id(span), span_id=_span_id(span)): - return self._dispatch(request, ctx, span, http_tracer) + return self._dispatch(request, ctx, span, http_tracer, is_outermost) def _dispatch( self, @@ -85,8 +99,15 @@ def _dispatch( ctx: PipelineContext, span: Span, http_tracer: HttpTracer, + is_outermost: bool, ) -> Response: - """Run the downstream chain, emitting tracer events around it.""" + """Run the downstream chain, emitting tracer events around it. + + The per-attempt span is opened and closed on every entry, but the + operation-level ``operation_succeeded`` / ``operation_failed`` events + fire only when the outermost entry unwinds (``is_outermost``), so a + retried or redirected call reports a single operation outcome. + """ _notify_request_sent(http_tracer, request) try: with span.make_current(): @@ -94,7 +115,8 @@ def _dispatch( except BaseException as err: span.set_error(type(err).__name__) span.end(error=err) - http_tracer.operation_failed(err) + if is_outermost: + http_tracer.operation_failed(err) raise _notify_response(http_tracer, response) span.set_attribute("http.response.status_code", int(response.status)) @@ -102,7 +124,8 @@ def _dispatch( if isinstance(retry_count, int) and retry_count > 0: span.set_attribute("http.request.resend_count", retry_count) span.end() - http_tracer.operation_succeeded() + if is_outermost: + http_tracer.operation_succeeded() return response @@ -118,14 +141,19 @@ def _set_request_attributes(span: Span, request: Request) -> None: def _notify_request_sent(http_tracer: HttpTracer, request: Request) -> None: - """Emit ``request_sent`` with the known body byte count, if any.""" + """Emit ``request_sent`` with the body byte count, or ``None`` if unknown. + + A bodyless request reports ``0``; a body with a known length reports that + count; a body whose length is unknown (``content_length()`` returns ``-1``, + e.g. a streamed upload) reports ``None`` so the event still fires and + consumers see a symmetric request_sent stream regardless of body shape. + """ body = request.body if body is None: http_tracer.request_sent(0) return length = body.content_length() - if length >= 0: - http_tracer.request_sent(length) + http_tracer.request_sent(length if length >= 0 else None) def _notify_response(http_tracer: HttpTracer, response: Response) -> None: diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/staged_builder.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/staged_builder.py index 8aa37f2..1be7c36 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/staged_builder.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/pipeline/staged_builder.py @@ -15,6 +15,7 @@ from __future__ import annotations +import contextlib from typing import TYPE_CHECKING, Self from .pipeline import Pipeline @@ -98,12 +99,16 @@ def replace(self, target: type[Policy], new: Policy) -> Self: Raises: ValueError: If no instance of ``target`` exists in the builder. """ - for stage, pillar in self._pillars.items(): - if isinstance(pillar, target): - # Install new at its declared stage; remove the old pillar. - del self._pillars[stage] - self.append(new, force=True) - return self + pillar_stage = next( + (stage for stage, pillar in self._pillars.items() if isinstance(pillar, target)), + None, + ) + if pillar_stage is not None: + # Install new at its declared stage; remove the old pillar. The + # lookup above finished iterating before we mutate ``_pillars``. + del self._pillars[pillar_stage] + self.append(new, force=True) + return self for stage, bucket in self._buckets.items(): for i, p in enumerate(bucket): if isinstance(p, target): @@ -145,6 +150,11 @@ def from_pipeline(cls, pipeline: Pipeline) -> Self: for "build a default pipeline, then surgically swap one piece" workflows. + The harvested policy instances are detached from ``pipeline`` (their + ``.next`` links are cleared) so they can be re-wired into the rebuilt + pipeline. ``pipeline`` is consumed by this call — each policy is owned + by a single pipeline, so the source pipeline must not be run again. + Raises: ValueError: If the input pipeline's policies do not satisfy stage ordering — i.e. their declared stages do not appear @@ -176,6 +186,7 @@ def from_pipeline(cls, pipeline: Pipeline) -> Self: f"non-decreasing stage order. Use the list constructor instead." ) last_stage = policy.STAGE + _detach(policy) builder.append(policy, force=True) return builder @@ -221,4 +232,18 @@ def _reload(self, policies: list[Policy]) -> None: self._buckets.setdefault(p.STAGE, []).append(p) +def _detach(policy: Policy) -> None: + """Clear ``policy.next`` so the instance can be re-wired into a new chain. + + A policy harvested from an existing pipeline still points at that + pipeline's chain. Clearing the link makes it look freshly constructed to + ``Pipeline``'s single-ownership guard, allowing the rebuild to re-wire it. + + Args: + policy: The policy whose ``.next`` link is removed if present. + """ + with contextlib.suppress(AttributeError): + del policy.next + + __all__ = ["StagedPipelineBuilder"] diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/codec.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/codec.py index 836a9e3..80e0afb 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/codec.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/codec.py @@ -27,7 +27,7 @@ import types import typing import uuid -from typing import Final, Union, cast, get_args, get_origin, get_type_hints +from typing import Annotated, Final, Union, cast, get_args, get_origin, get_type_hints from ..errors import DeserializationError, SerializationError from .tristate import ABSENT, NULL, Present, Tristate @@ -112,6 +112,16 @@ class _ModelInfo: # in practice; the lack of an explicit size cap is acceptable for that reason. _MODEL_CACHE: dict[type, _ModelInfo] = {} +_MAX_DEPTH: Final = 200 +"""Recursion ceiling for ``_decode_value``. + +A hostile, deeply nested document (e.g. ``[[[[...]]]]`` thousands of levels +deep) would otherwise exhaust the interpreter stack and surface a bare +``RecursionError``, escaping the codec's ``CodecError`` contract. The guard +trips well before CPython's default limit so the failure is a clean +``CodecError`` instead. +""" + def field_alias( wire_name: str, @@ -133,8 +143,16 @@ def field_alias( Returns: A ``dataclasses.Field`` carrying the alias metadata. + + Raises: + ValueError: If both ``default`` and ``default_factory`` are supplied. """ metadata = {ALIAS_KEY: wire_name} + if default is not dataclasses.MISSING and default_factory is not None: + # ``dataclasses.field`` rejects this combo too, but only at class-body + # evaluation; catching it here keeps the failure local to the call and + # avoids silently ignoring the factory. + raise ValueError("field_alias: pass at most one of default / default_factory") if default is not dataclasses.MISSING: return dataclasses.field(default=default, metadata=metadata) if default_factory is not None: @@ -209,9 +227,12 @@ class Codec: """Stateless engine converting between documents and typed models. Constructed once and reused. Effectively immutable after construction; its - only mutable state is a shared module-level type-hint cache whose dict - operations are atomic under CPython's GIL, so instances are safe to share - across threads. + only mutable state is a shared, append-only module-level type-hint cache. + A cache entry is computed independently per model and never mutated once + stored, so a concurrent miss merely recomputes the same value and the last + write wins with an identical result. The codec therefore does not rely on + the GIL or atomic dict operations for correctness and is safe to share + across threads, including under free-threaded CPython. """ __slots__ = ("_tolerate_unknown",) @@ -244,7 +265,7 @@ def decode[T](self, data: object, target: type[T]) -> T: CodecError: On any structural mismatch or conversion failure, with a wire-name path pointing at the offending location. """ - return cast("T", _decode_value(data, target, (), self._tolerate_unknown)) + return cast("T", _decode_value(data, target, (), self._tolerate_unknown, 0)) def encode(self, value: object) -> object: """Encode a typed value into a plain document. @@ -273,18 +294,59 @@ def _decode_value( target: object, path: tuple[str, ...], tolerate_unknown: bool, + depth: int, ) -> object: """Decode ``data`` into ``target``, dispatching on the type's shape.""" + if depth > _MAX_DEPTH: + raise CodecError( + f"maximum decode depth {_MAX_DEPTH} exceeded (input nested too deeply)", + path=path, + ) + if get_origin(target) is Annotated: + # ``Annotated[X, ...]`` carries no decode meaning here; strip the + # metadata and decode as the underlying ``X`` (without this the value + # would fall through to the container branch and be returned undecoded). + target = get_args(target)[0] if target is object or target is typing.Any: return data if _is_tristate(target): - return _decode_tristate(data, target, path, tolerate_unknown) + return _decode_tristate(data, target, path, tolerate_unknown, depth) origin = get_origin(target) if origin is None: - return _decode_atomic(data, target, path, tolerate_unknown) + return _decode_atomic(data, target, path, tolerate_unknown, depth) if origin in (Union, types.UnionType): - return _decode_union(data, target, path, tolerate_unknown) - return _decode_container(data, target, origin, path, tolerate_unknown) + return _decode_union(data, target, path, tolerate_unknown, depth) + if isinstance(origin, type) and dataclasses.is_dataclass(origin): + # A parametrised generic dataclass target (``Box[int]``) has a real + # dataclass origin; decode it as that dataclass rather than letting it + # fall through the container branch and return the raw dict undecoded. + # The type arguments are mapped onto the class's type parameters so + # generic fields decode against their concrete substitution. + return _decode_dataclass( + data, origin, path, tolerate_unknown, depth, type_args=_type_arg_map(origin, target) + ) + return _decode_container(data, target, origin, path, tolerate_unknown, depth) + + +def _type_arg_map(origin: type, target: object) -> Mapping[object, object]: + """Map a generic dataclass's type parameters onto the supplied arguments. + + For ``Box[int]`` with ``class Box[T]`` this returns ``{T: int}``. Mismatched + counts (or a non-generic origin) yield an empty map, leaving each field's + declared hint untouched. + + Args: + origin: The dataclass origin of the parametrised target. + target: The parametrised generic alias (e.g. ``Box[int]``). + + Returns: + A mapping from each type parameter to its concrete argument. + """ + params = getattr(origin, "__type_params__", ()) or getattr(origin, "__parameters__", ()) + args = get_args(target) + if not params or len(params) != len(args): + return {} + return dict(zip(params, args, strict=True)) def _decode_atomic( @@ -292,13 +354,14 @@ def _decode_atomic( target: object, path: tuple[str, ...], tolerate_unknown: bool, + depth: int, ) -> object: """Decode into a non-parametrised target (dataclass, datetime, enum, scalar).""" if isinstance(target, type): if REGISTRY_KEY in target.__dict__: - return _dispatch_union(data, target, path, tolerate_unknown) + return _dispatch_union(data, target, path, tolerate_unknown, depth) if dataclasses.is_dataclass(target): - return _decode_dataclass(data, target, path, tolerate_unknown) + return _decode_dataclass(data, target, path, tolerate_unknown, depth) if issubclass(target, enum.Enum): return _decode_enum(data, target, path) if issubclass(target, (_dt.datetime, _dt.date, _dt.time)): @@ -313,8 +376,10 @@ def _decode_dataclass( target: type, path: tuple[str, ...], tolerate_unknown: bool, + depth: int, *, exempt_key: str | None = None, + type_args: Mapping[object, object] | None = None, ) -> object: """Decode a mapping into a plain dataclass, field by field. @@ -326,6 +391,9 @@ def _decode_dataclass( exempt_key: A wire key always permitted under strict mode even when no field claims it — used for a discriminated union's tag, which is a structural key rather than a stray unknown one. + type_args: For a parametrised generic dataclass, a map from the class's + type parameters to their concrete arguments, applied to each field's + declared hint before decoding. """ if not isinstance(data, cabc.Mapping): raise CodecError( @@ -334,7 +402,7 @@ def _decode_dataclass( target_name=target.__name__, ) info = _resolve_info(target) - kwargs = _decode_fields(data, target, info, path, tolerate_unknown) + kwargs = _decode_fields(data, target, info, path, tolerate_unknown, depth, type_args) if not tolerate_unknown: _reject_unknown(data, info, path, target.__name__, exempt_key=exempt_key) try: @@ -349,34 +417,92 @@ def _decode_fields( info: _ModelInfo, path: tuple[str, ...], tolerate_unknown: bool, + depth: int, + type_args: Mapping[object, object] | None = None, ) -> dict[str, object]: """Build the constructor kwargs for ``target`` from ``data``.""" kwargs: dict[str, object] = {} for f in dataclasses.fields(target): wire = info.field_to_wire[f.name] hint = info.hints[f.name] + if type_args: + hint = _substitute_type_vars(hint, type_args) if wire not in data: - _require_present_or_default(f, wire, path, target.__name__) + if (default := _missing_field_value(f, hint, wire, path, target.__name__)) is not _OMIT: + kwargs[f.name] = default continue - kwargs[f.name] = _decode_value(data[wire], hint, (*path, wire), tolerate_unknown) + kwargs[f.name] = _decode_value(data[wire], hint, (*path, wire), tolerate_unknown, depth + 1) return kwargs -def _require_present_or_default( +def _substitute_type_vars(hint: object, type_args: Mapping[object, object]) -> object: + """Replace any type parameters in ``hint`` with their concrete arguments. + + Recurses through parametrised generics so a nested ``list[T]`` resolves to + ``list[int]``. A bare type parameter is substituted directly; anything with + no parameter to replace is returned unchanged. + + Args: + hint: A resolved type hint, possibly mentioning a type parameter. + type_args: Map from type parameters to their concrete arguments. + + Returns: + ``hint`` with every known type parameter substituted. + """ + if hint in type_args: + return type_args[hint] + args = get_args(hint) + if not args: + return hint + origin = get_origin(hint) + new_args = tuple(_substitute_type_vars(a, type_args) for a in args) + if new_args == args or origin is None: + return hint + return origin[new_args] + + +_OMIT: Final = object() +"""Sentinel meaning "supply no kwarg; let the constructor's own default apply".""" + + +def _missing_field_value( f: dataclasses.Field[object], + hint: object, wire: str, path: tuple[str, ...], target_name: str, -) -> None: - """Ensure a missing field has a default; otherwise raise ``CodecError``.""" - has_default = f.default is not dataclasses.MISSING - has_factory = f.default_factory is not dataclasses.MISSING - if not (has_default or has_factory): - raise CodecError( - f"missing required field {f.name!r} (wire {wire!r})", - path=path, - target_name=target_name, - ) +) -> object: + """Resolve the value for a field whose wire key is absent. + + A field carrying its own default or default-factory is left for the + constructor to fill (signalled by returning ``_OMIT``). A ``Tristate`` field + with no declared default still has a meaningful "absent" value — ``ABSENT`` + is exactly the type's omitted-key inhabitant — so it is supplied rather than + treated as a missing required field. Any other defaultless field is a + genuine omission and raises. + + Args: + f: The dataclass field whose key was absent from the document. + hint: The field's resolved type hint. + wire: The wire name that was looked up and not found. + path: Wire-name breadcrumb to this location. + target_name: Name of the dataclass being decoded. + + Returns: + ``ABSENT`` for a defaultless ``Tristate`` field, otherwise ``_OMIT``. + + Raises: + CodecError: If the field has neither a default nor ``Tristate`` typing. + """ + if f.default is not dataclasses.MISSING or f.default_factory is not dataclasses.MISSING: + return _OMIT + if _is_tristate(hint): + return ABSENT + raise CodecError( + f"missing required field {f.name!r} (wire {wire!r})", + path=path, + target_name=target_name, + ) def _reject_unknown( @@ -410,15 +536,16 @@ def _decode_container( origin: object, path: tuple[str, ...], tolerate_unknown: bool, + depth: int, ) -> object: """Decode a parametrised container (list/tuple/set/dict/mapping).""" args = get_args(target) if origin is dict or _is_mapping_origin(origin): - return _decode_mapping(data, args, path, tolerate_unknown) + return _decode_mapping(data, args, path, tolerate_unknown, depth) if origin is tuple: - return _decode_tuple(data, args, path, tolerate_unknown) + return _decode_tuple(data, args, path, tolerate_unknown, depth) if origin in (list, set, frozenset) or _is_sequence_origin(origin): - return _decode_sequence(data, origin, args, path, tolerate_unknown) + return _decode_sequence(data, origin, args, path, tolerate_unknown, depth) return data @@ -428,27 +555,57 @@ def _decode_sequence( args: tuple[object, ...], path: tuple[str, ...], tolerate_unknown: bool, + depth: int, ) -> object: """Decode a homogeneous sequence into list/set/frozenset (default list).""" if not isinstance(data, cabc.Iterable) or isinstance(data, (str, bytes, cabc.Mapping)): raise CodecError(f"expected an array, got {type(data).__name__}", path=path) elem = args[0] if args else object items = [ - _decode_value(item, elem, (*path, f"[{i}]"), tolerate_unknown) + _decode_value(item, elem, (*path, f"[{i}]"), tolerate_unknown, depth + 1) for i, item in enumerate(data) ] if origin in (set, cabc.Set, cabc.MutableSet): - return set(items) + return _build_hashed(set, items, path) if origin is frozenset: - return frozenset(items) + return _build_hashed(frozenset, items, path) return items +def _build_hashed[C]( + factory: Callable[[list[object]], C], + items: list[object], + path: tuple[str, ...], +) -> C: + """Build a ``set`` / ``frozenset``, mapping unhashable elements to ``CodecError``. + + A decoded element may be unhashable (e.g. a ``list`` decoded under a + ``set[object]`` field), in which case ``set()`` / ``frozenset()`` raises a + bare ``TypeError`` that would escape the codec's ``CodecError`` contract. + + Args: + factory: ``set`` or ``frozenset``. + items: The decoded elements to collect. + path: Wire-name breadcrumb to this location. + + Returns: + The constructed set or frozenset. + + Raises: + CodecError: If any element is unhashable. + """ + try: + return factory(items) + except TypeError as err: + raise CodecError(f"unhashable element in {factory.__name__}", path=path, error=err) from err + + def _decode_tuple( data: object, args: tuple[object, ...], path: tuple[str, ...], tolerate_unknown: bool, + depth: int, ) -> object: """Decode a homogeneous (``tuple[X, ...]``) or fixed-arity tuple.""" if not isinstance(data, cabc.Iterable) or isinstance(data, (str, bytes, cabc.Mapping)): @@ -457,7 +614,8 @@ def _decode_tuple( if len(args) == 2 and args[1] is Ellipsis: elem = args[0] return tuple( - _decode_value(v, elem, (*path, f"[{i}]"), tolerate_unknown) for i, v in enumerate(seq) + _decode_value(v, elem, (*path, f"[{i}]"), tolerate_unknown, depth + 1) + for i, v in enumerate(seq) ) arity = len(args) if len(seq) != arity: @@ -466,7 +624,7 @@ def _decode_tuple( path=path, ) return tuple( - _decode_value(v, t, (*path, f"[{i}]"), tolerate_unknown) + _decode_value(v, t, (*path, f"[{i}]"), tolerate_unknown, depth + 1) for i, (v, t) in enumerate(zip(seq, args, strict=True)) ) @@ -476,6 +634,7 @@ def _decode_mapping( args: tuple[object, ...], path: tuple[str, ...], tolerate_unknown: bool, + depth: int, ) -> object: """Decode a mapping, recovering each key and value through its declared type. @@ -491,11 +650,12 @@ def _decode_mapping( key_type = args[0] if len(args) == 2 else object value_type = args[1] if len(args) == 2 else object return { - _decode_value(key, key_type, (*path, str(key)), tolerate_unknown): _decode_value( + _decode_value(key, key_type, (*path, str(key)), tolerate_unknown, depth + 1): _decode_value( val, value_type, (*path, str(key)), tolerate_unknown, + depth + 1, ) for key, val in data.items() } @@ -506,22 +666,24 @@ def _decode_union( target: object, path: tuple[str, ...], tolerate_unknown: bool, + depth: int, ) -> object: - """Decode an ``X | None`` union: ``None`` passthrough, else decode ``X``. - - Only single-arm optionals (``X | None``) recover their inner type; ``None`` - is passed through only when ``NoneType`` is genuinely a union member, so a - non-optional union such as ``int | str`` does not silently accept ``None``. - Unions with two or more non-``None`` arms are tagless and cannot be resolved - structurally, so their payload passes through untouched — use a discriminated - union (``@discriminated`` / ``@variant``) when an arm must be reconstructed. + """Decode a union, recovering the inner type only for single-arm optionals. + + A single-arm optional (``X | None``) recovers ``X`` for a non-``None`` + payload and yields ``None`` for a ``None`` payload. Any union with two or + more non-``None`` arms (``int | str``, ``A | B | None``) is tagless and + cannot be resolved structurally, so its payload passes through untouched — + including a ``None`` payload, which is returned as-is rather than rejected. + Use a discriminated union (``@discriminated`` / ``@variant``) when an arm + must be reconstructed. """ all_args = get_args(target) args = [a for a in all_args if a is not type(None)] if data is None and type(None) in all_args: return None if len(args) == 1: - return _decode_value(data, args[0], path, tolerate_unknown) + return _decode_value(data, args[0], path, tolerate_unknown, depth + 1) return data @@ -530,6 +692,7 @@ def _decode_tristate( target: object, path: tuple[str, ...], tolerate_unknown: bool, + depth: int, ) -> object: """Decode a present key into ``NULL`` or ``Present(inner)``. @@ -539,7 +702,7 @@ def _decode_tristate( if data is None: return NULL inner = _tristate_inner(target) - return Present(_decode_value(data, inner, path, tolerate_unknown)) + return Present(_decode_value(data, inner, path, tolerate_unknown, depth + 1)) def _decode_enum(data: object, target: type[enum.Enum], path: tuple[str, ...]) -> object: @@ -598,6 +761,7 @@ def _dispatch_union( base: type, path: tuple[str, ...], tolerate_unknown: bool, + depth: int, ) -> object: """Resolve a discriminated union to a concrete variant and decode it.""" if not isinstance(data, cabc.Mapping): @@ -628,6 +792,7 @@ def _dispatch_union( concrete, path, tolerate_unknown, + depth, exempt_key=tag_field, ) @@ -735,11 +900,34 @@ def _resolve_info(target: type) -> _ModelInfo: info = _MODEL_CACHE.get(target) if info is not None: return info - hints = get_type_hints(target, include_extras=True) + # A PEP 695 generic dataclass (``class Box[T]``) annotates fields with its + # type parameters (``item: T``). Python 3.13+ resolves those automatically, + # but 3.12's ``get_type_hints`` does not see them and raises ``NameError``; + # supply them via ``localns`` so resolution works on every supported version. + type_params = getattr(target, "__type_params__", ()) + localns = {tp.__name__: tp for tp in type_params} or None + try: + hints = get_type_hints(target, include_extras=True, localns=localns) + except NameError as err: + # An unresolvable forward reference (a string annotation whose name is + # not in scope) surfaces as a bare ``NameError`` from ``get_type_hints``; + # wrap it so the codec keeps its ``CodecError`` contract. + raise CodecError( + f"cannot resolve a type hint on {target.__name__}: {err}", + target_name=target.__name__, + error=err, + ) from err field_to_wire: dict[str, str] = {} wire_to_field: dict[str, str] = {} for f in dataclasses.fields(target): wire = f.metadata.get(ALIAS_KEY, f.name) + if wire in wire_to_field: + # Two fields claiming the same wire alias would silently shadow each + # other (last-wins on decode, double-write on encode); reject it. + raise CodecError( + f"fields {wire_to_field[wire]!r} and {f.name!r} both map to wire name {wire!r}", + target_name=target.__name__, + ) field_to_wire[f.name] = wire wire_to_field[wire] = f.name info = _ModelInfo(hints=hints, field_to_wire=field_to_wire, wire_to_field=wire_to_field) @@ -756,6 +944,11 @@ def _is_tristate(target: object) -> bool: both shapes are recognised. A bare, non-parametrised ``Tristate`` field is treated as ``Tristate[object]`` (inner type ``object``). """ + if get_origin(target) is Annotated: + # An ``Annotated[Tristate[X], ...]`` field is still a Tristate; unwrap + # so a defaultless annotated Tristate resolves to ABSENT on an omitted + # key (matching the bare ``Tristate[X]`` contract). + target = get_args(target)[0] if target is Tristate: return True if get_origin(target) is Tristate: @@ -767,6 +960,8 @@ def _is_tristate(target: object) -> bool: def _tristate_inner(target: object) -> object: """Recover ``X`` from a ``Tristate[X]`` (or its expanded union form).""" + if get_origin(target) is Annotated: + target = get_args(target)[0] if get_origin(target) is Tristate: args = get_args(target) return args[0] if args else object diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/json_serde.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/json_serde.py index 0ea77fc..3a23634 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/json_serde.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/serde/json_serde.py @@ -28,10 +28,56 @@ def default(self, o: Any) -> Any: return super().default(o) +def _build_encoder( + encoder_cls: type[json.JSONEncoder], + default: Callable[[Any], Any] | None, + sort_keys: bool, + allow_nan: bool, +) -> json.JSONEncoder: + """Build an encoder, chaining the built-in ``default`` to a user one. + + Called once per ``JsonSerializer``. When a custom ``default`` is supplied, + a small subclass routes its callable behind the encoder class so the + built-in datetime / date / time / bytes handling still runs first. + + Args: + encoder_cls: The base ``json.JSONEncoder`` subclass. + default: Optional user fallback for unencodable values. + sort_keys: Whether to emit object keys in sorted order. + allow_nan: Whether to permit ``NaN`` / ``Infinity`` tokens. + + Returns: + A configured, reusable encoder instance. + """ + kwargs: dict[str, Any] = { + "sort_keys": sort_keys, + "allow_nan": allow_nan, + "separators": (",", ":"), + } + if default is None: + return encoder_cls(**kwargs) + fallback: Callable[[Any], Any] = default + + class _ChainedEncoder(encoder_cls): # type: ignore[valid-type, misc] + def default(self, o: Any) -> Any: + try: + return super().default(o) + except TypeError: + return fallback(o) + + return _ChainedEncoder(**kwargs) + + class JsonSerializer: - """Serialise Python values into JSON strings / bytes / streams.""" + """Serialise Python values into JSON strings / bytes / streams. + + The encoder is built once at construction and reused for every + ``serialize`` call. A ``json.JSONEncoder`` instance is stateless across + ``encode`` calls, so sharing it costs nothing and avoids re-deriving an + encoder class on every serialize when a custom ``default`` is configured. + """ - __slots__ = ("_allow_nan", "_default", "_encoder_cls", "_sort_keys") + __slots__ = ("_encoder",) def __init__( self, @@ -52,10 +98,7 @@ def __init__( non-standard JSON tokens. encoder_cls: ``json.JSONEncoder`` subclass to use. """ - self._default = default - self._sort_keys = sort_keys - self._allow_nan = allow_nan - self._encoder_cls = encoder_cls + self._encoder = _build_encoder(encoder_cls, default, sort_keys, allow_nan) def serialize(self, value: Any) -> str: """Serialise ``value`` to a JSON string. @@ -70,33 +113,11 @@ def serialize(self, value: Any) -> str: Raises: SerializationError: If encoding fails. """ - encoder = self._build_encoder() try: - return encoder.encode(value) + return self._encoder.encode(value) except (TypeError, ValueError, UnicodeDecodeError) as err: raise SerializationError(str(err), error=err) from err - def _build_encoder(self) -> json.JSONEncoder: - """Build an encoder chaining the built-in ``default`` to the user one.""" - encoder_cls = self._encoder_cls - kwargs: dict[str, Any] = { - "sort_keys": self._sort_keys, - "allow_nan": self._allow_nan, - "separators": (",", ":"), - } - if self._default is None: - return encoder_cls(**kwargs) - fallback: Callable[[Any], Any] = self._default - - class _ChainedEncoder(encoder_cls): # type: ignore[valid-type, misc] - def default(self, o: Any) -> Any: - try: - return super().default(o) - except TypeError: - return fallback(o) - - return _ChainedEncoder(**kwargs) - def serialize_to_bytes(self, value: Any) -> bytes: """Serialise ``value`` to UTF-8-encoded JSON bytes.""" return self.serialize(value).encode("utf-8") diff --git a/packages/dexpace-sdk-core/src/dexpace/sdk/core/util/proxy.py b/packages/dexpace-sdk-core/src/dexpace/sdk/core/util/proxy.py index 4e99740..bfa2509 100644 --- a/packages/dexpace-sdk-core/src/dexpace/sdk/core/util/proxy.py +++ b/packages/dexpace-sdk-core/src/dexpace/sdk/core/util/proxy.py @@ -18,9 +18,12 @@ The ``ProxyOptions.from_configuration`` factory bridges the proxy value type to the layered ``Configuration`` lookup: it reads ``HTTPS_PROXY`` -(preferred) or ``HTTP_PROXY`` as full URLs and ``NO_PROXY`` as a -comma-separated bypass list. Parse failures degrade to ``None`` rather than -raising — bad proxy configuration should never bring down the caller. +(preferred) or ``HTTP_PROXY`` as proxy URLs and ``NO_PROXY`` as a +comma-separated bypass list. The URL scheme selects the transport flavour, a +missing port defaults by scheme, scheme-less ``host:port`` forms are accepted, +and percent-encoded credentials are decoded. Bad proxy configuration degrades +to ``None`` rather than raising — but because a silently-unused proxy is an +outage-grade misconfiguration, an unusable value is logged at WARNING. """ from __future__ import annotations @@ -32,7 +35,7 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Self -from urllib.parse import urlsplit +from urllib.parse import SplitResult, unquote, urlsplit from ..config.configuration import Configuration @@ -44,6 +47,16 @@ # Glob metacharacters that switch an entry into ``fnmatch`` mode. _GLOB_CHARS: frozenset[str] = frozenset("*?[") +# Default proxy port per URL scheme when the proxy URL omits one. SOCKS +# proxies conventionally listen on 1080. +_DEFAULT_PORTS: dict[str, int] = { + "http": 80, + "https": 443, + "socks4": 1080, + "socks5": 1080, + "socks5h": 1080, +} + def _strip_port(host: str) -> str: """Drop a trailing ``:port`` (and IPv6 brackets) so matching is host-only. @@ -75,8 +88,10 @@ def _compile_bypass(pattern: str) -> Callable[[str], bool]: their ``fnmatch`` semantics. Every other (bare) entry uses conventional suffix matching: a candidate matches when it equals the entry or ends with ``"." + entry``. Leading dot(s) on the entry are stripped so - ``.example.com`` and ``example.com`` behave identically, and a trailing - ``:port`` is dropped so ``example.com:443`` matches on its host part. + ``.example.com`` and ``example.com`` behave identically. A trailing + ``:port`` is dropped from both bare and glob entries — candidate hosts are + matched on their host part alone, so a ported glob like + ``*.example.com:443`` would otherwise never match. Args: pattern: A raw ``NO_PROXY`` list entry (already stripped). @@ -85,7 +100,8 @@ def _compile_bypass(pattern: str) -> Callable[[str], bool]: A predicate mapping a lower-cased candidate host to a bypass boolean. """ if any(char in pattern for char in _GLOB_CHARS): - regex = re.compile(fnmatch.translate(pattern), re.IGNORECASE) + glob = _strip_port(pattern) + regex = re.compile(fnmatch.translate(glob), re.IGNORECASE) return lambda host: regex.match(host) is not None suffix = _strip_port(pattern).lstrip(".").lower() dotted = "." + suffix @@ -97,6 +113,26 @@ def matches(host: str) -> bool: return matches +def _split_proxy_url(proxy_url: str) -> SplitResult: + """Parse a proxy URL, tolerating a missing ``scheme://`` prefix. + + ``urlsplit`` parses a scheme-less ``proxy:8080`` as scheme ``proxy`` with + path ``8080`` — losing the host and port entirely. When the value has no + recognised ``scheme://`` authority marker, this prepends ``//`` so the + whole value is parsed as a network location (host:port), matching how + callers conventionally write a bare proxy address. + + Args: + proxy_url: A proxy URL, possibly scheme-less. + + Returns: + The ``SplitResult`` of parsing the (possibly normalised) URL. + """ + if "//" not in proxy_url: + return urlsplit("//" + proxy_url) + return urlsplit(proxy_url) + + class ProxyType(StrEnum): """Supported proxy transport flavours. @@ -111,6 +147,56 @@ class ProxyType(StrEnum): SOCKS5 = "SOCKS5" +# Map a proxy URL scheme to the modelled transport flavour. A scheme absent +# from this table is unsupported and is rejected (with a WARNING) rather than +# silently downgraded to HTTP. ``socks5h`` (remote DNS) maps to ``SOCKS5``. +_SCHEME_TO_TYPE: dict[str, ProxyType] = { + "http": ProxyType.HTTP, + "https": ProxyType.HTTP, + "socks4": ProxyType.SOCKS4, + "socks5": ProxyType.SOCKS5, + "socks5h": ProxyType.SOCKS5, +} + + +def _resolve_endpoint(proxy_url: str) -> tuple[SplitResult, ProxyType, str, int] | None: + """Resolve a proxy URL into its parsed parts, type, host, and port. + + Applies scheme→type mapping, scheme-by-default port resolution, and + scheme-less ``host:port`` handling. Any value that cannot yield a usable + endpoint is logged at WARNING (a silently-disabled proxy is outage-grade) + and yields ``None``. + + Args: + proxy_url: The raw proxy URL string. + + Returns: + A ``(SplitResult, ProxyType, host, port)`` tuple, or ``None`` if the + URL is unusable. + """ + try: + split = _split_proxy_url(proxy_url) + except ValueError: + _LOG.warning("ignoring proxy URL %r: failed to parse", proxy_url) + return None + scheme = split.scheme.lower() + proxy_type = _SCHEME_TO_TYPE["http"] if scheme == "" else _SCHEME_TO_TYPE.get(scheme) + if proxy_type is None: + _LOG.warning("ignoring proxy URL %r: unsupported scheme %r", proxy_url, scheme) + return None + if not split.hostname: + _LOG.warning("ignoring proxy URL %r: missing hostname", proxy_url) + return None + try: + port = split.port + except ValueError: + _LOG.warning("ignoring proxy URL %r: invalid port", proxy_url) + return None + if port is None: + port = _DEFAULT_PORTS.get(scheme, 80) + return split, proxy_type, split.hostname, port + + @dataclass(frozen=True, slots=True) class ProxyOptions: """Immutable proxy configuration with pre-compiled bypass matchers. @@ -195,55 +281,58 @@ def __repr__(self) -> str: def from_configuration(cls, config: Configuration) -> Self | None: """Build a ``ProxyOptions`` from layered configuration env vars. - Reads ``HTTPS_PROXY`` (preferred) or ``HTTP_PROXY`` as full proxy - URLs (``http://user:pass@proxy.corp:8080``). Reads ``NO_PROXY`` as - a comma-separated bypass list. A ``NO_PROXY`` value of ``"*"`` - bypasses everything and short-circuits to ``None``. + Reads ``HTTPS_PROXY`` (preferred) or ``HTTP_PROXY`` as proxy URLs and + ``NO_PROXY`` as a comma-separated bypass list. A ``NO_PROXY`` value of + ``"*"`` bypasses everything and short-circuits to ``None``. + + The proxy URL is parsed leniently so common real-world forms work: + + - The URL ``scheme`` selects the transport flavour: ``http``/``https`` + map to ``HTTP``, ``socks4`` to ``SOCKS4``, ``socks5``/``socks5h`` to + ``SOCKS5``. An *unsupported* scheme is rejected (logged at WARNING), + never silently downgraded to HTTP. + - A missing port defaults by scheme (``http`` 80, ``https`` 443, SOCKS + 1080) instead of dropping the proxy. + - A scheme-less ``proxy:8080`` is parsed as host:port (assumed HTTP). + - Percent-encoded credentials are ``unquote()``-decoded. + + Because a silently-unused proxy is an outage-grade misconfiguration, a + genuinely unusable proxy value is logged at WARNING (not DEBUG). Args: config: Layered configuration to read from. Returns: A populated ``ProxyOptions``, or ``None`` when no proxy is - configured, when ``NO_PROXY=*``, or when the proxy URL fails to - parse (a debug-level log line records the failure). + configured, when ``NO_PROXY=*``, or when the proxy URL is + unusable (a WARNING log line records why). """ no_proxy_raw = config.get(Configuration.NO_PROXY) if no_proxy_raw is not None and no_proxy_raw.strip() == "*": return None proxy_url = config.get(Configuration.HTTPS_PROXY) or config.get(Configuration.HTTP_PROXY) - if proxy_url is None or proxy_url == "": - return None - try: - parsed = urlsplit(proxy_url) - except ValueError: - _LOG.debug("failed to parse proxy URL %r", proxy_url) - return None - if not parsed.hostname: - _LOG.debug("proxy URL %r missing hostname", proxy_url) - return None - try: - port = parsed.port - except ValueError: - _LOG.debug("proxy URL %r has invalid port", proxy_url) + if not proxy_url: return None - if port is None: - _LOG.debug("proxy URL %r missing port", proxy_url) + endpoint = _resolve_endpoint(proxy_url) + if endpoint is None: return None + split, proxy_type, host, port = endpoint non_proxy_hosts: tuple[str, ...] = () if no_proxy_raw is not None and no_proxy_raw.strip(): non_proxy_hosts = tuple( entry.strip() for entry in no_proxy_raw.split(",") if entry.strip() ) + username = unquote(split.username) if split.username is not None else None + password = unquote(split.password) if split.password is not None else None try: return cls( - type=ProxyType.HTTP, - host=parsed.hostname, + type=proxy_type, + host=host, port=port, non_proxy_hosts=non_proxy_hosts, - username=parsed.username, - password=parsed.password, + username=username, + password=password, ) except ValueError: - _LOG.debug("proxy URL %r failed ProxyOptions validation", proxy_url) + _LOG.warning("ignoring proxy URL %r: failed ProxyOptions validation", proxy_url) return None diff --git a/packages/dexpace-sdk-core/tests/auth/test_digest.py b/packages/dexpace-sdk-core/tests/auth/test_digest.py index 800bed2..7ee618a 100644 --- a/packages/dexpace-sdk-core/tests/auth/test_digest.py +++ b/packages/dexpace-sdk-core/tests/auth/test_digest.py @@ -112,7 +112,9 @@ def test_digest_prefers_sha256_when_both_offered(self) -> None: params = _parse_auth(value) assert params["algorithm"] == "SHA-256" - def test_digest_nonce_counter_increments(self) -> None: + def test_digest_nonce_counter_increments_on_reuse(self) -> None: + # Reusing the same server nonce across requests increments ``nc`` + # (RFC 7616 §3.4: count of requests sent with this nonce). handler = DigestChallengeHandler( _USERNAME, _PASSWORD, @@ -129,6 +131,78 @@ def test_digest_nonce_counter_increments(self) -> None: assert _parse_auth(first[1])["nc"] == "00000001" assert _parse_auth(second[1])["nc"] == "00000002" + def test_digest_nc_resets_for_new_nonce(self) -> None: + # A fresh server nonce restarts ``nc`` at 00000001 (RFC 7616 §3.4), + # even after a prior nonce advanced the count. A single global counter + # would wrongly emit 00000003 here. + handler = DigestChallengeHandler( + _USERNAME, + _PASSWORD, + preferred_algorithms=("MD5",), + cnonce_factory=lambda: _FIXED_CNONCE, + ) + first_nonce = AuthenticateChallenge( + scheme="Digest", + parameters={**_RFC_PARAMS, "algorithm": "MD5", "nonce": "nonce-aaa"}, + ) + second_nonce = AuthenticateChallenge( + scheme="Digest", + parameters={**_RFC_PARAMS, "algorithm": "MD5", "nonce": "nonce-bbb"}, + ) + # Advance the count on the first nonce. + handler.handle(Method.GET, _URL, [first_nonce], is_proxy=False) + second = handler.handle(Method.GET, _URL, [first_nonce], is_proxy=False) + assert second is not None + assert _parse_auth(second[1])["nc"] == "00000002" + # A different nonce must reset to 00000001. + fresh = handler.handle(Method.GET, _URL, [second_nonce], is_proxy=False) + assert fresh is not None + assert _parse_auth(fresh[1])["nc"] == "00000001" + + def test_digest_nc_resumes_per_nonce_when_alternating(self) -> None: + # Each nonce keeps its own independent count: alternating between two + # nonces must resume each one's count rather than share a global one. + handler = DigestChallengeHandler( + _USERNAME, + _PASSWORD, + preferred_algorithms=("MD5",), + cnonce_factory=lambda: _FIXED_CNONCE, + ) + + def nc_for(nonce: str) -> str: + challenge = AuthenticateChallenge( + scheme="Digest", + parameters={**_RFC_PARAMS, "algorithm": "MD5", "nonce": nonce}, + ) + result = handler.handle(Method.GET, _URL, [challenge], is_proxy=False) + assert result is not None + return _parse_auth(result[1])["nc"] + + assert nc_for("nonce-aaa") == "00000001" + assert nc_for("nonce-bbb") == "00000001" + assert nc_for("nonce-aaa") == "00000002" + assert nc_for("nonce-bbb") == "00000002" + assert nc_for("nonce-aaa") == "00000003" + + def test_digest_nonce_count_map_is_bounded(self) -> None: + # A long-lived handler hitting many distinct nonces must not grow the + # per-nonce map without bound; the oldest entry is evicted past the cap. + from dexpace.sdk.core.http.auth.digest import _MAX_TRACKED_NONCES + + handler = DigestChallengeHandler( + _USERNAME, + _PASSWORD, + preferred_algorithms=("MD5",), + cnonce_factory=lambda: _FIXED_CNONCE, + ) + for index in range(_MAX_TRACKED_NONCES + 50): + challenge = AuthenticateChallenge( + scheme="Digest", + parameters={**_RFC_PARAMS, "algorithm": "MD5", "nonce": f"nonce-{index}"}, + ) + handler.handle(Method.GET, _URL, [challenge], is_proxy=False) + assert len(handler._nonce_counts) == _MAX_TRACKED_NONCES + def test_digest_is_proxy_returns_proxy_authorization_header(self) -> None: handler = DigestChallengeHandler( _USERNAME, diff --git a/packages/dexpace-sdk-core/tests/auth/test_policies.py b/packages/dexpace-sdk-core/tests/auth/test_policies.py index 5cc85a5..1ff79ea 100644 --- a/packages/dexpace-sdk-core/tests/auth/test_policies.py +++ b/packages/dexpace-sdk-core/tests/auth/test_policies.py @@ -111,6 +111,53 @@ def test_basic_auth_policy_stamps_header() -> None: assert client.calls[0].headers.get("authorization") == "Basic dXNlcjpwYXNz" +class _RedirectingClient(HttpClient): + """302s the first request to ``location`` then replies 200. + + Records every request so a test can inspect which headers reached the + foreign host on the reissued hop. + """ + + def __init__(self, location: str) -> None: + self._location = location + self.calls: list[Request] = [] + + def execute(self, request: Request) -> Response: + from dexpace.sdk.core.http.common import Headers + + self.calls.append(request) + if len(self.calls) == 1: + return Response( + request=request, + protocol=Protocol.HTTP_1_1, + status=Status.FOUND, + headers=Headers([("Location", self._location)]), + ) + return Response(request=request, protocol=Protocol.HTTP_1_1, status=Status.OK) + + +def test_key_credential_policy_withholds_credential_cross_origin() -> None: + from dexpace.sdk.core.pipeline.policies.redirect import RedirectPolicy + + client = _RedirectingClient("https://attacker.example.net/loot") + policy = KeyCredentialPolicy(KeyCredential("hunter2"), "X-API-Key") + with Pipeline(client, policies=[RedirectPolicy(), policy]) as p: + p.run(_request(), DispatchContext(_instr("0" * 15 + "30"))) + assert client.calls[0].headers.get("x-api-key") == "hunter2" + assert "x-api-key" not in client.calls[1].headers + + +def test_basic_auth_policy_withholds_credential_cross_origin() -> None: + from dexpace.sdk.core.pipeline.policies.redirect import RedirectPolicy + + client = _RedirectingClient("https://attacker.example.net/loot") + policy = BasicAuthPolicy(BasicAuthCredential("user", "pass")) + with Pipeline(client, policies=[RedirectPolicy(), policy]) as p: + p.run(_request(), DispatchContext(_instr("0" * 15 + "31"))) + assert client.calls[0].headers.get("authorization") == "Basic dXNlcjpwYXNz" + assert "authorization" not in client.calls[1].headers + + class _StaticCredential: """Minimal TokenCredential — returns the same token unless explicitly told.""" @@ -214,6 +261,36 @@ def on_challenge(self, request: Request, response: Response) -> bool: assert len(client.calls) == 2 +def test_bearer_token_policy_withholds_token_cross_origin() -> None: + """A redirect to a foreign host must not receive the bearer token.""" + from dexpace.sdk.core.pipeline.policies.redirect import RedirectPolicy + + client = _RedirectingClient("https://attacker.example.net/loot") + cred = _StaticCredential() + policy = BearerTokenPolicy(cred, "scope-a") + with Pipeline(client, policies=[RedirectPolicy(), policy]) as p: + p.run(_request(), DispatchContext(_instr("0" * 15 + "32"))) + assert client.calls[0].headers.get("authorization") == "Bearer abc" + assert "authorization" not in client.calls[1].headers + # The foreign hop neither acquired nor refreshed a token. + assert cred.calls == 1 + + +def test_bearer_token_policy_skips_https_enforcement_cross_origin() -> None: + """A cross-origin reissue to http:// forwards unchanged, no HTTPS error.""" + from dexpace.sdk.core.pipeline.policies.redirect import RedirectPolicy + + # http:// target is cross-origin (scheme + host differ); the bearer policy + # must forward it unchanged rather than raising the HTTPS-only error. + client = _RedirectingClient("http://other.example.org/next") + cred = _StaticCredential() + policy = BearerTokenPolicy(cred, "scope-a") + with Pipeline(client, policies=[RedirectPolicy(), policy]) as p: + response = p.run(_request(), DispatchContext(_instr("0" * 15 + "33"))) + assert response.status is Status.OK + assert "authorization" not in client.calls[1].headers + + class _SlowCredential: """TokenCredential whose token fetch is slow — exercises concurrent refresh.""" @@ -248,15 +325,20 @@ def test_bearer_token_policy_serializes_concurrent_refresh() -> None: trace_ids = [f"{i:032x}" for i in range(1, 9)] - def _send(trace: str) -> None: - with Pipeline(client, policies=[policy]) as p: + # A single pipeline (and thus a single policy instance, whose lock and + # cache serialize the refresh) is shared across the worker threads. The + # pipeline run is concurrency-safe; reusing one policy across separately + # constructed pipelines is not — each Policy is owned by one pipeline. + with Pipeline(client, policies=[policy]) as p: + + def _send(trace: str) -> None: p.run(_request(), DispatchContext(_instr(trace))) - threads = [threading.Thread(target=_send, args=(t,)) for t in trace_ids] - for t in threads: - t.start() - for t in threads: - t.join() + threads = [threading.Thread(target=_send, args=(t,)) for t in trace_ids] + for t in threads: + t.start() + for t in threads: + t.join() assert cred.calls == 1 diff --git a/packages/dexpace-sdk-core/tests/http/test_loggable_body_fixes.py b/packages/dexpace-sdk-core/tests/http/test_loggable_body_fixes.py index ffdba97..5aae5e6 100644 --- a/packages/dexpace-sdk-core/tests/http/test_loggable_body_fixes.py +++ b/packages/dexpace-sdk-core/tests/http/test_loggable_body_fixes.py @@ -238,3 +238,37 @@ def test_snapshot_after_cap_does_not_block_further_writes(self) -> None: list(body.iter_bytes()) assert body.snapshot(max_bytes=2) == b"01" assert body.snapshot() == b"0123456789" + + +class TestRequestTapResetOnReplay: + """A replayed loggable request body captures one copy, not body+body.""" + + def test_second_iteration_does_not_double_the_capture(self) -> None: + # A replayable inner body iterated twice (retry / 307 redirect) must + # not accumulate ``body + body`` in the tap. Each iter_bytes resets + # the tap so snapshot/captured_size reflect a single payload. + body = LoggableRequestBody(RequestBody.from_bytes(b"payload")) + assert b"".join(body.iter_bytes()) == b"payload" + assert b"".join(body.iter_bytes()) == b"payload" + assert body.snapshot() == b"payload" + assert body.captured_size == len(b"payload") + + def test_capture_reflects_only_the_latest_attempt(self) -> None: + # After many replays the capture is still one copy, never N copies. + body = LoggableRequestBody(RequestBody.from_bytes(b"abc")) + for _ in range(5): + assert b"".join(body.iter_bytes()) == b"abc" + assert body.snapshot() == b"abc" + assert body.captured_size == 3 + + def test_tap_reset_is_eager_before_first_chunk(self) -> None: + # The reset happens at call time, so a fresh (undrained) iterator from + # a replay already shows an empty capture rather than the prior copy. + body = LoggableRequestBody(RequestBody.from_bytes(b"hello")) + list(body.iter_bytes()) + assert body.captured_size == 5 + # Obtain a new iterator without draining it; the tap must already be + # cleared by the eager reset inside iter_bytes. + body.iter_bytes() + assert body.captured_size == 0 + assert body.snapshot() == b"" diff --git a/packages/dexpace-sdk-core/tests/http/test_media_type.py b/packages/dexpace-sdk-core/tests/http/test_media_type.py index 3ddedef..8bc45df 100644 --- a/packages/dexpace-sdk-core/tests/http/test_media_type.py +++ b/packages/dexpace-sdk-core/tests/http/test_media_type.py @@ -48,6 +48,12 @@ def test_parse_quoted_pair(self) -> None: mt = MediaType.parse('text/plain; foo="a\\"b"') assert dict(mt.parameters)["foo"] == 'a"b' + def test_parse_quoted_value_with_semicolon(self) -> None: + # A ``;`` inside a quoted-string is part of the value, not a separator. + mt = MediaType.parse('multipart/mixed; boundary="a;b"') + assert dict(mt.parameters)["boundary"] == "a;b" + assert len(mt.parameters) == 1 + class TestNormalisation: def test_lower_cases_type_and_subtype(self) -> None: @@ -108,6 +114,14 @@ def test_str_quotes_boundary_with_spaces(self) -> None: # Round-trip: parsing the rendered form recovers the original value. assert MediaType.parse(rendered) == mt + def test_str_quotes_boundary_with_semicolon_round_trips(self) -> None: + mt = MediaType.of("multipart", "mixed", {"boundary": "a;b;c"}) + rendered = str(mt) + assert 'boundary="a;b;c"' in rendered + # Round-trip: parsing the rendered form recovers the original value + # despite the embedded semicolon. + assert MediaType.parse(rendered) == mt + def test_equality_independent_of_param_order(self) -> None: a = MediaType.of("text", "plain", {"a": "1", "b": "2"}) b = MediaType.of("text", "plain", {"b": "2", "a": "1"}) diff --git a/packages/dexpace-sdk-core/tests/http/test_media_type_params.py b/packages/dexpace-sdk-core/tests/http/test_media_type_params.py index 054dea5..2afe598 100644 --- a/packages/dexpace-sdk-core/tests/http/test_media_type_params.py +++ b/packages/dexpace-sdk-core/tests/http/test_media_type_params.py @@ -60,6 +60,35 @@ def test_quoted_value_with_separators_preserved(self) -> None: mt = MediaType.parse('multipart/form-data; boundary="foo bar"') assert dict(mt.parameters)["boundary"] == "foo bar" + def test_quoted_value_with_semicolon_not_split(self) -> None: + # A ``;`` inside a quoted-string is part of the value, not a parameter + # separator: a naive ``value.split(";")`` would mis-parse this. + mt = MediaType.parse('multipart/mixed; boundary="a;b;c"') + assert dict(mt.parameters)["boundary"] == "a;b;c" + assert len(mt.parameters) == 1 + + def test_quoted_semicolon_does_not_create_spurious_param(self) -> None: + # Two real parameters; the first carries a quoted ``;``. Both must + # survive, and the embedded ``;`` must not split off a third entry. + mt = MediaType.parse('multipart/mixed; boundary="a;b"; charset=utf-8') + params = dict(mt.parameters) + assert params["boundary"] == "a;b" + assert params["charset"] == "utf-8" + assert len(mt.parameters) == 2 + + def test_quoted_escaped_quote_before_semicolon(self) -> None: + # An escaped quote inside the quoted-string must not prematurely close + # the quote, so the following ``;`` stays inside the value. + mt = MediaType.parse('text/plain; foo="a\\";b"') + assert dict(mt.parameters)["foo"] == 'a";b' + assert len(mt.parameters) == 1 + + def test_quoted_semicolon_round_trips(self) -> None: + # parse -> str -> parse recovers the original value with the embedded + # semicolon intact. + original = MediaType.of("multipart", "mixed", {"boundary": "a;b;c"}) + assert MediaType.parse(str(original)) == original + class TestCaseFolding: def test_type_and_subtype_lowercased(self) -> None: diff --git a/packages/dexpace-sdk-core/tests/http/test_multipart.py b/packages/dexpace-sdk-core/tests/http/test_multipart.py index b616394..cf780b8 100644 --- a/packages/dexpace-sdk-core/tests/http/test_multipart.py +++ b/packages/dexpace-sdk-core/tests/http/test_multipart.py @@ -129,3 +129,68 @@ def test_ascii_filename_still_works() -> None: body = MultipartRequestBody([MultipartField(name="file", value=b"x", filename="upload.bin")]) drained = _drain(body) assert b'filename="upload.bin"' in drained + + +@pytest.mark.parametrize("ctrl", ["\r", "\n", "\0"]) +def test_control_char_in_name_rejected(ctrl: str) -> None: + # CR/LF/NUL in a field name would let an attacker inject extra part + # headers or a fabricated boundary line into the multipart payload. + with pytest.raises(ValueError, match="control characters"): + MultipartField(name=f"key{ctrl}", value=b"v") + + +@pytest.mark.parametrize("ctrl", ["\r", "\n", "\0"]) +def test_control_char_in_filename_rejected(ctrl: str) -> None: + # Filenames are the classic attacker-controlled value on file uploads. + with pytest.raises(ValueError, match="control characters"): + MultipartField(name="file", value=b"v", filename=f"a{ctrl}b.txt") + + +def test_crlf_injection_filename_rejected() -> None: + # The canonical header-injection payload: a filename that smuggles a new + # Content-Type header and body after a CRLF must be refused outright. + payload = 'a"\r\nContent-Type: text/html\r\n\r\n