Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions packages/dexpace-sdk-core/src/dexpace/sdk/core/http/auth/digest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,20 @@
verification are out of scope — matching the Java v1 cut.

A single handler instance is intended to be reused across requests so the
per-nonce request counter (``nc``) advances correctly. RFC 7616 §3.4 defines
``nc`` as the count of requests the client has sent with the *current* server
nonce, so the counter restarts at ``00000001`` for each fresh nonce and
increments only while a nonce is reused. The per-nonce counts are tracked in a
bounded mapping guarded by a ``threading.Lock`` since CPython dict mutation and
integer increment are not atomic with respect to other threads.
"""

from __future__ import annotations

import hashlib
import secrets
import threading
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Final
Expand Down Expand Up @@ -58,6 +62,13 @@ class _ResolvedChallenge:
"SHA-256-SESS": hashlib.sha256,
}

# Cap on how many distinct server nonces retain a request count. A long-lived
# handler hitting many rotating nonces would otherwise grow the map without
# bound; the oldest entry is evicted past this size. The limit is generous —
# servers rotate nonces rarely — so eviction effectively never affects an
# active nonce.
_MAX_TRACKED_NONCES: Final[int] = 1024


class DigestChallengeHandler:
"""Satisfy a ``Digest`` challenge per RFC 7616.
Expand All @@ -79,8 +90,8 @@ class DigestChallengeHandler:

__slots__ = (
"_cnonce_factory",
"_counter",
"_lock",
"_nonce_counts",
"_password",
"_preferred",
"_username",
Expand All @@ -102,7 +113,10 @@ def __init__(
self._password = password
self._preferred = preferred_algorithms
self._cnonce_factory = cnonce_factory or (lambda: secrets.token_hex(16))
self._counter = 0
# Maps a server nonce to the count of requests sent with it. Insertion
# order is preserved so the oldest nonce can be evicted once the map
# exceeds ``_MAX_TRACKED_NONCES``.
self._nonce_counts: OrderedDict[str, int] = OrderedDict()
self._lock = threading.Lock()

def can_handle(self, challenges: list[AuthenticateChallenge]) -> bool:
Expand All @@ -122,7 +136,7 @@ def handle(
resolved = self._resolve(selected)
if resolved is None:
return None
nc = self._next_nc()
nc = self._next_nc(resolved.nonce)
cnonce = self._cnonce_factory()
uri = _request_uri(url)
response = self._compute_response(
Expand Down Expand Up @@ -223,13 +237,33 @@ def _select(self, challenges: list[AuthenticateChallenge]) -> AuthenticateChalle
return challenge
return None

def _next_nc(self) -> str:
def _next_nc(self, nonce: str) -> str:
"""Return the next ``nc`` value for ``nonce`` as 8 lowercase hex digits.

RFC 7616 §3.4 defines ``nc`` as the count of requests sent with the
current server nonce, so the count starts at ``00000001`` for each
fresh nonce and increments on every reuse. A bounded, insertion-ordered
map tracks the per-nonce counts; the oldest nonce is evicted once the
map exceeds ``_MAX_TRACKED_NONCES`` to cap memory on a long-lived
handler that sees many rotating nonces.

Args:
nonce: The server nonce the request is being signed against.

Returns:
The request count for ``nonce`` formatted as 8 lowercase hex
digits, e.g. ``"00000001"`` for the first request.
"""
with self._lock:
# Clamp to 32 bits — ``nc`` is rendered as 8 hex digits per
# RFC 7616, and wrapping after 2**32-1 is acceptable since the
# server hashes the value (no monotonic check).
self._counter = (self._counter + 1) & 0xFFFFFFFF
value = self._counter
value = (self._nonce_counts.get(nonce, 0) + 1) & 0xFFFFFFFF
self._nonce_counts[nonce] = value
# Refresh recency so an actively reused nonce is never evicted.
self._nonce_counts.move_to_end(nonce)
if len(self._nonce_counts) > _MAX_TRACKED_NONCES:
self._nonce_counts.popitem(last=False)
return f"{value:08x}"

def _credentials_encodable(self, charset: str) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class KeyCredentialPolicy(Policy):
SansIO-shaped (no chain wrapping needed) but implemented as a ``Policy``
so it integrates uniformly with the rest of the pipeline.

The credential header is stamped only while the request stays on the
origin recorded on the first pass through the policy. When a downstream
redirect reissues the request against a different origin (scheme, host,
or effective port), the credential is withheld so it never reaches a
foreign host.

Attributes:
header_name: Header to write.
prefix: Optional prefix (with trailing space) for the header value.
Expand All @@ -63,12 +69,21 @@ def __init__(
self.prefix = f"{prefix} " if prefix else ""

def send(self, request: Request, ctx: PipelineContext) -> Response:
    """Forward ``request``, adding the key header only on the recorded origin.

    Args:
        request: The outgoing request.
        ctx: The pipeline context passed on to the next policy.

    Returns:
        Whatever the next policy in the chain returns.
    """
    if not _crosses_recorded_origin(request, ctx):
        # Same origin as the first pass: attach the (optionally prefixed)
        # key. A request that moved to another origin is passed on as-is.
        credential_value = f"{self.prefix}{self._credential.key}"
        request = request.with_header(self.header_name, credential_value)
    return self.next.send(request, ctx)


class BasicAuthPolicy(Policy):
"""Stamp ``Authorization: Basic <base64>`` from a ``BasicAuthCredential``."""
"""Stamp ``Authorization: Basic <base64>`` from a ``BasicAuthCredential``.

The credential is stamped only while the request stays on the origin
recorded on the first pass through the policy. When a downstream redirect
reissues the request against a different origin (scheme, host, or
effective port), the credential is withheld so it never reaches a foreign
host.
"""

STAGE = Stage.AUTH
__slots__ = ("_credential",)
Expand All @@ -79,6 +94,8 @@ def __init__(self, credential: BasicAuthCredential) -> None:
self._credential = credential

def send(self, request: Request, ctx: PipelineContext) -> Response:
    """Forward ``request``, adding ``Authorization: Basic`` only on the recorded origin.

    Args:
        request: The outgoing request.
        ctx: The pipeline context passed on to the next policy.

    Returns:
        Whatever the next policy in the chain returns.
    """
    if not _crosses_recorded_origin(request, ctx):
        # Same origin as the first pass: attach the pre-encoded credential.
        # A request that moved to another origin is passed on as-is.
        authorization = f"Basic {self._credential.encoded}"
        request = request.with_header("Authorization", authorization)
    return self.next.send(request, ctx)

Expand All @@ -91,6 +108,13 @@ class BearerTokenPolicy(Policy):
returns True, or after a 401 response with ``WWW-Authenticate``. Enforces
HTTPS unless ``enforce_https=False`` is passed in ``ctx.options``.

The token is acquired and stamped only while the request stays on the
origin recorded on the first pass. When a downstream redirect reissues the
request against a different origin (scheme, host, or effective port), the
policy forwards the request unchanged — it does not acquire, refresh, or
stamp a token — so the bearer token never reaches a foreign host. The
HTTPS-enforcement check applies only on that same-origin stamping path.

Concurrent refreshes are serialized via a ``threading.Lock`` using a
double-checked pattern so the credential's ``get_token_info`` is invoked
at most once per refresh window even under heavy concurrent send pressure.
Expand Down Expand Up @@ -226,6 +250,11 @@ def _authorize(
*,
force_refresh: bool = False,
) -> Request:
# A redirect that crossed origin must not receive the bearer token:
# forward the request unchanged without acquiring or refreshing one,
# and skip the HTTPS enforcement that only governs the stamping path.
if _crosses_recorded_origin(request, ctx):
return request
if ctx.options.get("enforce_https", True) and not _is_https(request.url):
raise ServiceRequestError(
"Bearer token authentication is not permitted for non-HTTPS URLs."
Expand Down Expand Up @@ -257,6 +286,13 @@ class AsyncBearerTokenPolicy(AsyncPolicy):
the returned ``(name, value)`` pair is stamped on the retried request.
A 401 invalidates the cached origin token; a 407 leaves it alone because
the proxy, not the origin, rejected the request.

The token is acquired and stamped only while the request stays on the
origin recorded on the first pass. When a downstream redirect reissues the
request against a different origin (scheme, host, or effective port), the
policy forwards the request unchanged — it does not acquire, refresh, or
stamp a token — so the bearer token never reaches a foreign host. The
HTTPS-enforcement check applies only on that same-origin stamping path.
"""

STAGE = Stage.AUTH
Expand Down Expand Up @@ -381,6 +417,11 @@ async def _authorize(
*,
force_refresh: bool = False,
) -> Request:
# A redirect that crossed origin must not receive the bearer token:
# forward the request unchanged without acquiring or refreshing one,
# and skip the HTTPS enforcement that only governs the stamping path.
if _crosses_recorded_origin(request, ctx):
return request
if ctx.options.get("enforce_https", True) and not _is_https(request.url):
raise ServiceRequestError(
"Bearer token authentication is not permitted for non-HTTPS URLs."
Expand All @@ -398,6 +439,51 @@ async def _authorize(
return request.with_header("Authorization", f"{token.token_type} {token.token}")


# Default port per scheme; ``_origin`` falls back to these when a URL carries
# no explicit port, so an implied and an explicit default port compare equal.
_DEFAULT_PORTS: dict[str, int] = {"https": 443, "http": 80}
# Key in ``ctx.data`` under which ``_crosses_recorded_origin`` stores the
# origin seen on the first pass through an auth policy.
_AUTH_ORIGIN_KEY: str = "_auth_origin"


def _origin(url: Url) -> tuple[str, str, int | None]:
"""Return the ``(scheme, host, port)`` origin tuple for ``url``.

The scheme and host are lower-cased and the port is resolved to its
scheme default (443 for https, 80 for http) when not explicit, so two
URLs that differ only in an implied/explicit default port compare equal.

Args:
url: The URL to derive an origin from.

Returns:
A ``(scheme, host, effective_port)`` tuple suitable for equality
comparison.
"""
scheme = url.scheme.lower()
port = url.port if url.port is not None else _DEFAULT_PORTS.get(scheme)
return scheme, url.host.lower(), port


def _crosses_recorded_origin(request: Request, ctx: PipelineContext) -> bool:
    """Tell whether ``request`` targets a different origin than the first pass.

    The first call for a given ``ctx`` stores the request's origin in
    ``ctx.data`` under ``_AUTH_ORIGIN_KEY``; later calls compare against that
    stored value. Auth policies use the result to withhold credentials from a
    request that a redirect has moved to another scheme, host, or port.

    Args:
        request: The request the auth policy is about to forward.
        ctx: The pipeline context whose ``data`` mapping holds the recorded
            origin.

    Returns:
        ``True`` when the request's origin differs from the recorded one;
        ``False`` on the first pass or when the origin is unchanged.
    """
    here = _origin(request.url)
    # ``setdefault`` records ``here`` when nothing is stored yet and otherwise
    # hands back the origin recorded earlier, in a single mapping operation.
    baseline: tuple[str, str, int | None] = ctx.data.setdefault(_AUTH_ORIGIN_KEY, here)
    return baseline != here


def _is_https(url: Url) -> bool:
"""Return True if ``url``'s scheme is ``https`` (case-insensitive)."""
return url.scheme.lower() == "https"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,47 @@
# Delimiter characters: ( ) < > @ , ; : \ " / [ ] ? = { } plus space and tab.
# NOTE(review): presumably used to decide whether a value is a bare token or
# needs quoting — the consumers are not visible here, confirm against them.
_TOKEN_SEPARATORS = frozenset('()<>@,;:\\"/[]?={} \t')


def _split_params(value: str) -> list[str]:
"""Split a media-type wire string on ``;`` outside quoted-strings.

A naive ``value.split(";")`` mis-parses a quoted parameter value that
legitimately contains a semicolon (legal inside an RFC 7230 §3.2.6
quoted-string, e.g. an exotic multipart boundary). This walks the string
with a small state machine that treats ``;`` as a separator only when it
is not enclosed in double quotes, honouring the ``\\X`` quoted-pair escape
so an escaped quote does not toggle the quoted state.

Args:
value: The full wire-form media-type string.

Returns:
The segments split on top-level ``;`` separators, in order. Surrounding
whitespace is not stripped here — callers strip per segment.
"""
segments: list[str] = []
current: list[str] = []
in_quotes = False
i = 0
while i < len(value):
ch = value[i]
if in_quotes and ch == "\\" and i + 1 < len(value):
current.append(ch)
current.append(value[i + 1])
i += 2
continue
if ch == '"':
in_quotes = not in_quotes
elif ch == ";" and not in_quotes:
segments.append("".join(current))
current = []
i += 1
continue
current.append(ch)
i += 1
segments.append("".join(current))
return segments


def _unquote(s: str) -> str:
"""Decode a quoted-string parameter value per RFC 7230 §3.2.6.

Expand Down Expand Up @@ -146,7 +187,7 @@ def parse(cls, value: str) -> Self:
"""
if not value or not value.strip():
raise ValueError("media type must not be blank")
segments = [segment.strip() for segment in value.split(";")]
segments = [segment.strip() for segment in _split_params(value)]
mime = segments[0]
slash = mime.find("/")
if slash <= 0 or slash == len(mime) - 1:
Expand Down
Loading