From 78ac323aaf192a94e54ab2911a8e3aded1ac42c8 Mon Sep 17 00:00:00 2001 From: Manoj Bajaj Date: Wed, 13 May 2026 15:51:13 +0530 Subject: [PATCH 1/2] refactor: use authlib for oauth flows Entire-Checkpoint: 3fd27c6c5d9c --- pyproject.toml | 3 +- src/authsome/auth/flows/base.py | 44 ++--- src/authsome/auth/flows/dcr_pkce.py | 75 ++------- src/authsome/auth/flows/oauth2_client.py | 151 ++++++++++++++++++ src/authsome/auth/flows/pkce.py | 82 ++-------- src/authsome/auth/utils.py | 11 -- tests/auth/test_oauth2_client.py | 194 +++++++++++++++++++++++ uv.lock | 27 ++++ 8 files changed, 419 insertions(+), 168 deletions(-) create mode 100644 src/authsome/auth/flows/oauth2_client.py create mode 100644 tests/auth/test_oauth2_client.py diff --git a/pyproject.toml b/pyproject.toml index 639d1c7..b5be3e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "python-multipart>=0.0.27", "jinja2>=3.1", "py-key-value-aio[disk]", + "authlib>=1.7.2", ] [project.optional-dependencies] @@ -74,4 +75,4 @@ select = ["E", "F", "I", "UP", "N", "C90"] ignore = [] [tool.ruff.lint.mccabe] -max-complexity = 20 \ No newline at end of file +max-complexity = 20 diff --git a/src/authsome/auth/flows/base.py b/src/authsome/auth/flows/base.py index aba31cb..cd0ffa9 100644 --- a/src/authsome/auth/flows/base.py +++ b/src/authsome/auth/flows/base.py @@ -7,9 +7,9 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any -import requests as http_client from loguru import logger +from authsome.auth.flows.oauth2_client import refresh_oauth_token, revoke_oauth_token from authsome.auth.models.connection import ConnectionRecord, ProviderClientRecord from authsome.auth.models.enums import ConnectionStatus from authsome.auth.models.provider import ProviderDefinition @@ -92,27 +92,23 @@ async def revoke( if not revocation_url: return - def _do_revoke(token: str, token_type: str) -> None: - payload = {"token": token} - if client_id: - payload["client_id"] = client_id - if client_secret: - payload["client_secret"] = client_secret - + def _do_revoke(token: str, token_type: str, token_type_hint: str) -> None: try: - http_client.post( - revocation_url, - data=payload, - timeout=15, + revoke_oauth_token( + provider=provider, + token=token, + token_type_hint=token_type_hint, + client_id=client_id, + client_secret=client_secret, ) except Exception as exc: logger.warning(f"{token_type.capitalize()} token revocation failed (continuing): {{}}", exc) if record.access_token: - _do_revoke(record.access_token, "access") + _do_revoke(record.access_token, "access", "access_token") if record.refresh_token: - _do_revoke(record.refresh_token, "refresh") + _do_revoke(record.refresh_token, "refresh", "refresh_token") def refresh( self, @@ -134,22 +130,12 @@ def refresh( if not client_id: raise RefreshFailedError("No client_id available for refresh", provider=provider.name) - payload: dict[str, str] = { - "grant_type": "refresh_token", - "refresh_token": record.refresh_token, - "client_id": client_id, - } - if client_secret: - payload["client_secret"] = client_secret - - resp = http_client.post( - provider.oauth.token_url, - data=payload, - headers={"Accept": "application/json"}, - timeout=30, + token = refresh_oauth_token( + provider=provider, + refresh_token=record.refresh_token, + client_id=client_id, + client_secret=client_secret, ) - resp.raise_for_status() - token = resp.json() now = utc_now() record.access_token = token["access_token"] diff --git a/src/authsome/auth/flows/dcr_pkce.py b/src/authsome/auth/flows/dcr_pkce.py index 1978e7e..4a9ad94 100644 --- a/src/authsome/auth/flows/dcr_pkce.py +++ b/src/authsome/auth/flows/dcr_pkce.py @@ -3,7 +3,6 @@ from __future__ import annotations import json -import secrets import urllib.parse from datetime import timedelta from typing import TYPE_CHECKING, Any @@ -11,10 +10,11 @@ import requests as http_client from authsome.auth.flows.base import AuthFlow, FlowResult +from authsome.auth.flows.oauth2_client import create_pkce_authorization, exchange_authorization_code from authsome.auth.models.connection import AccountInfo, ConnectionRecord, ProviderClientRecord from authsome.auth.models.enums import AuthType, ConnectionStatus from authsome.auth.models.provider import ProviderDefinition -from authsome.auth.utils import generate_pkce, resolve_callback_url +from authsome.auth.utils import resolve_callback_url from authsome.errors import AuthenticationFailedError, DiscoveryError from authsome.utils import utc_now @@ -51,21 +51,13 @@ async def begin( assert client_id is not None # guaranteed: either passed in or registered above - code_verifier, code_challenge = generate_pkce() - - state = secrets.token_urlsafe(32) - auth_params: dict[str, str] = { - "response_type": "code", - "client_id": client_id, - "redirect_uri": redirect_uri, - "state": state, - "code_challenge": code_challenge, - "code_challenge_method": "S256", - } - if effective_scopes: - auth_params["scope"] = " ".join(effective_scopes) - - auth_url = f"{provider.oauth.authorization_url}?{urllib.parse.urlencode(auth_params)}" + auth_url, state, code_verifier = create_pkce_authorization( + provider=provider, + client_id=client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + scopes=effective_scopes, + ) runtime_session.state = "waiting_for_user" runtime_session.payload["auth_url"] = auth_url @@ -99,10 +91,10 @@ async def resume( if not auth_code: raise AuthenticationFailedError("Authorization timed out or no code received", provider=provider.name) - returned_state = callback_data.get("state") + returned_state = callback_data.get("state", "") expected_state = runtime_session.payload.get("internal_state") - if returned_state != expected_state: - raise AuthenticationFailedError("OAuth state mismatch — potential CSRF attack", provider=provider.name) + if not expected_state: + raise AuthenticationFailedError("OAuth state missing from session", provider=provider.name) # If DCR registered a client, it's stored in payload if "internal_client_id" in runtime_session.payload: @@ -119,9 +111,11 @@ async def resume( redirect_uri = runtime_session.payload.get("callback_url", "") effective_scopes = json.loads(runtime_session.payload.get("internal_scopes", "[]")) - token_data = await self._exchange_code( + token_data = exchange_authorization_code( provider=provider, auth_code=auth_code, + expected_state=expected_state, + returned_state=returned_state, redirect_uri=redirect_uri, client_id=client_id, client_secret=client_secret, @@ -223,42 +217,3 @@ async def _register_client( if not client_id: raise AuthenticationFailedError("DCR response missing client_id", provider=provider.name) return client_id, reg_data.get("client_secret") - - @staticmethod - async def _exchange_code( - *, - provider: ProviderDefinition, - auth_code: str, - redirect_uri: str, - client_id: str, - client_secret: str | None, - code_verifier: str, - ) -> dict[str, Any]: - assert provider.oauth is not None - payload: dict[str, str] = { - "grant_type": "authorization_code", - "code": auth_code, - "redirect_uri": redirect_uri, - "client_id": client_id, - "code_verifier": code_verifier, - } - if client_secret: - payload["client_secret"] = client_secret - - try: - resp = http_client.post( - provider.oauth.token_url, data=payload, headers={"Accept": "application/json"}, timeout=30 - ) - resp.raise_for_status() - data = resp.json() - except http_client.RequestException as exc: - raise AuthenticationFailedError(f"Token exchange failed: {exc}", provider=provider.name) from exc - except json.JSONDecodeError as exc: - raise AuthenticationFailedError("Token response was not valid JSON", provider=provider.name) from exc - - if "access_token" not in data: - raise AuthenticationFailedError( - f"Token exchange error: {data.get('error', '')} — {data.get('error_description', 'Unknown error')}", - provider=provider.name, - ) - return data diff --git a/src/authsome/auth/flows/oauth2_client.py b/src/authsome/auth/flows/oauth2_client.py new file mode 100644 index 0000000..4179eb2 --- /dev/null +++ b/src/authsome/auth/flows/oauth2_client.py @@ -0,0 +1,151 @@ +"""Small Authlib-backed helpers for OAuth2 client flows.""" + +from __future__ import annotations + +import urllib.parse +from typing import Any + +import requests as http_client +from authlib.common.security import generate_token +from authlib.integrations.base_client.errors import OAuthError +from authlib.integrations.requests_client import OAuth2Session +from authlib.oauth2 import OAuth2Error + +from authsome.auth.models.provider import ProviderDefinition +from authsome.errors import AuthenticationFailedError, RefreshFailedError + +_PKCE_VERIFIER_LENGTH = 64 + + +def create_pkce_authorization( + *, + provider: ProviderDefinition, + client_id: str, + client_secret: str | None, + redirect_uri: str, + scopes: list[str], +) -> tuple[str, str, str]: + """Create an authorization URL and state using Authlib's PKCE support.""" + assert provider.oauth is not None + + session = OAuth2Session( + client_id=client_id, + client_secret=client_secret, + scope=" ".join(scopes) if scopes else None, + redirect_uri=redirect_uri, + code_challenge_method="S256", + token_endpoint_auth_method=_token_endpoint_auth_method(client_secret), + ) + code_verifier = generate_token(_PKCE_VERIFIER_LENGTH) + authorization_url, state = session.create_authorization_url( + provider.oauth.authorization_url, + code_verifier=code_verifier, + ) + return authorization_url, state, code_verifier + + +def exchange_authorization_code( + *, + provider: ProviderDefinition, + auth_code: str, + expected_state: str, + returned_state: str, + redirect_uri: str, + client_id: str, + client_secret: str | None, + code_verifier: str, +) -> dict[str, Any]: + """Exchange an authorization code for tokens using Authlib.""" + assert provider.oauth is not None + + session = OAuth2Session( + client_id=client_id, + client_secret=client_secret, + state=expected_state, + redirect_uri=redirect_uri, + token_endpoint_auth_method=_token_endpoint_auth_method(client_secret), + ) + authorization_response = _authorization_response_url( + redirect_uri=redirect_uri, + code=auth_code, + state=returned_state, + ) + + try: + token = session.fetch_token( + provider.oauth.token_url, + authorization_response=authorization_response, + code_verifier=code_verifier, + timeout=30, + ) + except (OAuthError, OAuth2Error, http_client.RequestException, ValueError) as exc: + raise AuthenticationFailedError(f"Token exchange failed: {exc}", provider=provider.name) from exc + + if "access_token" not in token: + error = token.get("error", "") + error_desc = token.get("error_description", "Unknown error") + raise AuthenticationFailedError(f"Token exchange error: {error} - {error_desc}", provider=provider.name) + + return dict(token) + + +def refresh_oauth_token( + *, + provider: ProviderDefinition, + refresh_token: str, + client_id: str, + client_secret: str | None, +) -> dict[str, Any]: + """Refresh an OAuth access token using Authlib.""" + assert provider.oauth is not None + + session = OAuth2Session( + client_id=client_id, + client_secret=client_secret, + token_endpoint_auth_method=_token_endpoint_auth_method(client_secret), + ) + try: + token = session.refresh_token( + provider.oauth.token_url, + refresh_token=refresh_token, + timeout=30, + ) + except (OAuthError, OAuth2Error, http_client.RequestException, ValueError) as exc: + raise RefreshFailedError(f"Token refresh failed: {exc}", provider=provider.name) from exc + + return dict(token) + + +def revoke_oauth_token( + *, + provider: ProviderDefinition, + token: str, + token_type_hint: str, + client_id: str | None, + client_secret: str | None, +) -> None: + """Revoke an OAuth token using Authlib.""" + assert provider.oauth is not None + assert provider.oauth.revocation_url is not None + + session = OAuth2Session( + client_id=client_id, + client_secret=client_secret, + revocation_endpoint_auth_method=_token_endpoint_auth_method(client_secret), + ) + session.revoke_token( + provider.oauth.revocation_url, + token=token, + token_type_hint=token_type_hint, + timeout=15, + ) + + +def _token_endpoint_auth_method(client_secret: str | None) -> str: + return "client_secret_post" if client_secret else "none" + + +def _authorization_response_url(*, redirect_uri: str, code: str, state: str) -> str: + parsed = urllib.parse.urlsplit(redirect_uri) + query = urllib.parse.urlencode({"code": code, "state": state}) + return urllib.parse.urlunsplit((parsed.scheme, parsed.netloc, parsed.path, query, parsed.fragment)) diff --git a/src/authsome/auth/flows/pkce.py b/src/authsome/auth/flows/pkce.py index 3c5b8d2..0ec26be 100644 --- a/src/authsome/auth/flows/pkce.py +++ b/src/authsome/auth/flows/pkce.py @@ -3,18 +3,15 @@ from __future__ import annotations import json -import secrets -import urllib.parse from datetime import timedelta from typing import TYPE_CHECKING, Any -import requests as http_client - from authsome.auth.flows.base import AuthFlow, FlowResult +from authsome.auth.flows.oauth2_client import create_pkce_authorization, exchange_authorization_code from authsome.auth.models.connection import AccountInfo, ConnectionRecord from authsome.auth.models.enums import AuthType, ConnectionStatus from authsome.auth.models.provider import ProviderDefinition -from authsome.auth.utils import generate_pkce, resolve_callback_url +from authsome.auth.utils import resolve_callback_url from authsome.errors import AuthenticationFailedError from authsome.utils import utc_now @@ -44,23 +41,16 @@ async def begin( raise AuthenticationFailedError("PKCE flow requires a client_id.", provider=provider.name) effective_scopes = scopes or provider.oauth.scopes or [] - code_verifier, code_challenge = generate_pkce() redirect_uri = resolve_callback_url(runtime_session) - state = secrets.token_urlsafe(32) - auth_params: dict[str, str] = { - "response_type": "code", - "client_id": client_id, - "redirect_uri": redirect_uri, - "state": state, - "code_challenge": code_challenge, - "code_challenge_method": "S256", - } - if effective_scopes: - auth_params["scope"] = " ".join(effective_scopes) - - auth_url = f"{provider.oauth.authorization_url}?{urllib.parse.urlencode(auth_params)}" + auth_url, state, code_verifier = create_pkce_authorization( + provider=provider, + client_id=client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + scopes=effective_scopes, + ) runtime_session.state = "waiting_for_user" runtime_session.payload["auth_url"] = auth_url @@ -92,18 +82,20 @@ async def resume( if not auth_code: raise AuthenticationFailedError("Authorization timed out or no code received", provider=provider.name) - returned_state = callback_data.get("state") + returned_state = callback_data.get("state", "") expected_state = runtime_session.payload.get("internal_state") - if returned_state != expected_state: - raise AuthenticationFailedError("OAuth state mismatch — potential CSRF attack", provider=provider.name) + if not expected_state: + raise AuthenticationFailedError("OAuth state missing from session", provider=provider.name) code_verifier = runtime_session.payload.get("internal_code_verifier", "") redirect_uri = runtime_session.payload.get("callback_url", "") effective_scopes = json.loads(runtime_session.payload.get("internal_scopes", "[]")) - token_data = await self._exchange_code( + token_data = exchange_authorization_code( provider=provider, auth_code=auth_code, + expected_state=expected_state, + returned_state=returned_state, redirect_uri=redirect_uri, client_id=client_id, client_secret=client_secret, @@ -135,47 +127,3 @@ async def resume( metadata=metadata, ) ) - - @staticmethod - async def _exchange_code( - *, - provider: ProviderDefinition, - auth_code: str, - redirect_uri: str, - client_id: str, - client_secret: str | None, - code_verifier: str, - ) -> dict[str, Any]: - assert provider.oauth is not None - payload: dict[str, str] = { - "grant_type": "authorization_code", - "code": auth_code, - "redirect_uri": redirect_uri, - "client_id": client_id, - "code_verifier": code_verifier, - } - if client_secret: - payload["client_secret"] = client_secret - - try: - resp = http_client.post( - provider.oauth.token_url, - data=payload, - headers={"Accept": "application/json"}, - timeout=30, - ) - resp.raise_for_status() - except http_client.RequestException as exc: - raise AuthenticationFailedError(f"Token exchange failed: {exc}", provider=provider.name) from exc - - try: - data = resp.json() - except json.JSONDecodeError as exc: - raise AuthenticationFailedError("Token response was not valid JSON", provider=provider.name) from exc - - if "access_token" not in data: - error = data.get("error", "") - error_desc = data.get("error_description", "Unknown error") - raise AuthenticationFailedError(f"Token exchange error: {error} — {error_desc}", provider=provider.name) - - return data diff --git a/src/authsome/auth/utils.py b/src/authsome/auth/utils.py index 01173d6..9ba8246 100644 --- a/src/authsome/auth/utils.py +++ b/src/authsome/auth/utils.py @@ -2,10 +2,7 @@ from __future__ import annotations -import hashlib import re -import secrets -from base64 import urlsafe_b64encode from typing import TYPE_CHECKING from urllib.parse import urlsplit, urlunsplit @@ -17,14 +14,6 @@ _DEFAULT_CALLBACK_URL = build_callback_url(DEFAULT_SERVER_BASE_URL) -def generate_pkce() -> tuple[str, str]: - """Generate code verifier and challenge for PKCE.""" - code_verifier = secrets.token_urlsafe(64)[:128] - digest = hashlib.sha256(code_verifier.encode("ascii")).digest() - code_challenge = urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") - return code_verifier, code_challenge - - def resolve_callback_url(runtime_session: AuthSession) -> str: """Resolve the callback URL.""" callback_override = runtime_session.payload.get("callback_url_override") diff --git a/tests/auth/test_oauth2_client.py b/tests/auth/test_oauth2_client.py new file mode 100644 index 0000000..f8507a5 --- /dev/null +++ b/tests/auth/test_oauth2_client.py @@ -0,0 +1,194 @@ +"""Tests for Authlib-backed OAuth2 flow helpers.""" + +from __future__ import annotations + +import urllib.parse +from typing import Any + +import pytest +from requests import Session + +from authsome.auth.flows.oauth2_client import ( + create_pkce_authorization, + exchange_authorization_code, + refresh_oauth_token, + revoke_oauth_token, +) +from authsome.auth.flows.pkce import PkceFlow +from authsome.auth.models.enums import AuthType, FlowType +from authsome.auth.models.provider import OAuthConfig, ProviderDefinition +from authsome.errors import AuthenticationFailedError + + +class _TokenResponse: + status_code = 200 + + def json(self) -> dict[str, Any]: + return { + "access_token": "access-token", + "refresh_token": "refresh-token", + "token_type": "Bearer", + "expires_in": 3600, + } + + +def _make_oauth_provider() -> ProviderDefinition: + return ProviderDefinition( + name="oauth-test", + display_name="OAuth Test", + auth_type=AuthType.OAUTH2, + flow=FlowType.PKCE, + oauth=OAuthConfig( + authorization_url="https://example.com/oauth/authorize", + token_url="https://example.com/oauth/token", + revocation_url="https://example.com/oauth/revoke", + scopes=["read"], + ), + ) + + +def test_create_pkce_authorization_uses_authlib_pkce_params() -> None: + provider = _make_oauth_provider() + + auth_url, state, code_verifier = create_pkce_authorization( + provider=provider, + client_id="client-id", + client_secret=None, + redirect_uri="http://127.0.0.1:7999/auth/callback", + scopes=["read", "write"], + ) + + parsed = urllib.parse.urlsplit(auth_url) + params = urllib.parse.parse_qs(parsed.query) + + assert parsed.scheme == "https" + assert parsed.netloc == "example.com" + assert parsed.path == "/oauth/authorize" + assert params["client_id"] == ["client-id"] + assert params["scope"] == ["read write"] + assert params["state"] == [state] + assert params["code_challenge_method"] == ["S256"] + assert params["code_challenge"][0] + assert code_verifier + + +def test_exchange_authorization_code_uses_authlib_token_request(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _make_oauth_provider() + captured: dict[str, Any] = {} + + def fake_post(self: Session, url: str, data: dict[str, str], **kwargs: Any) -> _TokenResponse: + captured["url"] = url + captured["data"] = data + captured["headers"] = kwargs.get("headers") + captured["auth"] = kwargs.get("auth") + return _TokenResponse() + + monkeypatch.setattr(Session, "post", fake_post) + + token = exchange_authorization_code( + provider=provider, + auth_code="auth-code", + expected_state="expected-state", + returned_state="expected-state", + redirect_uri="http://127.0.0.1:7999/auth/callback", + client_id="client-id", + client_secret=None, + code_verifier="verifier", + ) + + assert token["access_token"] == "access-token" + assert captured["url"] == "https://example.com/oauth/token" + assert captured["data"]["grant_type"] == "authorization_code" + assert captured["data"]["code"] == "auth-code" + assert captured["data"]["code_verifier"] == "verifier" + assert captured["auth"] is not None + + +def test_exchange_authorization_code_rejects_state_mismatch() -> None: + provider = _make_oauth_provider() + + with pytest.raises(AuthenticationFailedError, match="Token exchange failed"): + exchange_authorization_code( + provider=provider, + auth_code="auth-code", + expected_state="expected-state", + returned_state="wrong-state", + redirect_uri="http://127.0.0.1:7999/auth/callback", + client_id="client-id", + client_secret=None, + code_verifier="verifier", + ) + + +def test_refresh_oauth_token_uses_authlib_refresh_request(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _make_oauth_provider() + captured: dict[str, Any] = {} + + def fake_post(self: Session, url: str, data: dict[str, str], **kwargs: Any) -> _TokenResponse: + captured["url"] = url + captured["data"] = data + captured["auth"] = kwargs.get("auth") + return _TokenResponse() + + monkeypatch.setattr(Session, "post", fake_post) + + token = refresh_oauth_token( + provider=provider, + refresh_token="old-refresh-token", + client_id="client-id", + client_secret="client-secret", + ) + + assert token["access_token"] == "access-token" + assert captured["url"] == "https://example.com/oauth/token" + assert captured["data"]["grant_type"] == "refresh_token" + assert captured["data"]["refresh_token"] == "old-refresh-token" + assert captured["auth"] is not None + + +def test_revoke_oauth_token_uses_authlib_revocation_request(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _make_oauth_provider() + captured: dict[str, Any] = {} + + def fake_post(self: Session, url: str, data: dict[str, str], **kwargs: Any) -> _TokenResponse: + captured["url"] = url + captured["data"] = data + captured["auth"] = kwargs.get("auth") + return _TokenResponse() + + monkeypatch.setattr(Session, "post", fake_post) + + revoke_oauth_token( + provider=provider, + token="access-token", + token_type_hint="access_token", + client_id="client-id", + client_secret="client-secret", + ) + + assert captured["url"] == "https://example.com/oauth/revoke" + assert captured["data"]["token"] == "access-token" + assert captured["data"]["token_type_hint"] == "access_token" + assert captured["auth"] is not None + + +@pytest.mark.asyncio +async def test_pkce_flow_begin_stores_authlib_authorization_state() -> None: + from unittest.mock import Mock + + provider = _make_oauth_provider() + session = Mock() + session.payload = {} + + await PkceFlow().begin( + provider=provider, + profile="default", + connection_name="default", + runtime_session=session, + client_id="client-id", + ) + + assert session.state == "waiting_for_user" + assert session.payload["auth_url"].startswith("https://example.com/oauth/authorize?") + assert session.payload["internal_state"] + assert session.payload["internal_code_verifier"] diff --git a/uv.lock b/uv.lock index 05287dc..f8eb144 100644 --- a/uv.lock +++ b/uv.lock @@ -119,11 +119,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" }, ] +[[package]] +name = "authlib" +version = "1.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "joserfc" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/36/98/7d93f30d029643c0275dbc0bd6d5a6f670661ee6c9a94d93af7ab4887600/authlib-1.7.2.tar.gz", hash = "sha256:2cea25fefcd4e7173bdf1372c0afc265c8034b23a8cd5dcb6a9164b826c64231", size = 176511, upload-time = "2026-05-06T08:10:23.116Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/95/adcb68e20c34162e9135f370d6e31737719c2b6f94bc953fe7ed1f10fe21/authlib-1.7.2-py2.py3-none-any.whl", hash = "sha256:3e1faedc9d87e7d56a164eca3ccb6ace0d61b94abe83e92242f8dc8bba9b4a9f", size = 259548, upload-time = "2026-05-06T08:10:21.436Z" }, +] + [[package]] name = "authsome" version = "0.2.4" source = { editable = "." } dependencies = [ + { name = "authlib" }, { name = "click" }, { name = "cryptography" }, { name = "fastapi" }, @@ -149,6 +163,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "authlib", specifier = ">=1.7.2" }, { name = "click", specifier = ">=8.0" }, { name = "cryptography", specifier = ">=41.0" }, { name = "fastapi", specifier = ">=0.115" }, @@ -698,6 +713,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "joserfc" +version = "1.6.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3b/dc/5f768c2e391e9afabe5d18e3221346deb5fb6338565f1ccc9e7c6d7befdd/joserfc-1.6.5.tar.gz", hash = "sha256:1482a7db78fb4602e44ed89e51b599d052e091288c7c532c5b694e20149dec48", size = 231881, upload-time = "2026-05-06T04:58:13.408Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/3b/ad1cb22e75c963b1f07c8a2329bf47227ce7e4361df5eb2fb101b2ce33ef/joserfc-1.6.5-py3-none-any.whl", hash = "sha256:e9878a0f8243fe7b95e11fdda81374ca9f7a689e302751579d3dfdeec559675e", size = 70464, upload-time = "2026-05-06T04:58:11.668Z" }, +] + [[package]] name = "kaitaistruct" version = "0.11" From 23c5f2dd7a0b046865402822ae89ee107f97d44e Mon Sep 17 00:00:00 2001 From: Manoj Bajaj Date: Wed, 13 May 2026 16:18:29 +0530 Subject: [PATCH 2/2] refactor: use authlib for device token polling Entire-Checkpoint: 53df39be90c1 --- src/authsome/auth/flows/device_code.py | 65 +++++++++--------------- src/authsome/auth/flows/oauth2_client.py | 35 +++++++++++++ tests/auth/test_oauth2_client.py | 54 ++++++++++++++++++++ 3 files changed, 114 insertions(+), 40 deletions(-) diff --git a/src/authsome/auth/flows/device_code.py b/src/authsome/auth/flows/device_code.py index 7a3d0df..86a885f 100644 --- a/src/authsome/auth/flows/device_code.py +++ b/src/authsome/auth/flows/device_code.py @@ -8,9 +8,12 @@ from typing import TYPE_CHECKING, Any import requests +from authlib.integrations.base_client.errors import OAuthError +from authlib.oauth2 import OAuth2Error from loguru import logger from authsome.auth.flows.base import AuthFlow, FlowResult +from authsome.auth.flows.oauth2_client import fetch_device_token from authsome.auth.models.connection import AccountInfo, ConnectionRecord from authsome.auth.models.enums import AuthType, ConnectionStatus from authsome.auth.models.provider import ProviderDefinition @@ -183,57 +186,39 @@ async def poll_for_token( effective_expires_in = min(expires_in, 300) deadline = time.monotonic() + effective_expires_in - use_json = provider.oauth.device_token_request == "json" - while time.monotonic() < deadline: await asyncio.sleep(poll_interval) try: - if use_json: - resp = requests.post( - provider.oauth.token_url, - json={"device_code": device_code}, - headers={"Accept": "application/json", "Content-Type": "application/json"}, - timeout=30, - ) - else: - payload: dict[str, str] = { - "grant_type": "urn:ietf:params:oauth:grant-type:device_code", - "device_code": device_code, - } - if client_id: - payload["client_id"] = client_id - if client_secret: - payload["client_secret"] = client_secret - resp = requests.post( - provider.oauth.token_url, data=payload, headers={"Accept": "application/json"}, timeout=30 - ) + return fetch_device_token( + provider=provider, + device_code=device_code, + client_id=client_id, + client_secret=client_secret, + ) except requests.RequestException as exc: logger.warning("Token poll request failed: {}, retrying...", exc) continue - - try: - data = resp.json() except json.JSONDecodeError: logger.warning("Token poll response was not JSON, retrying...") continue - - if resp.status_code == 200 and "access_token" in data: - return data - - error = data.get("error", "") - if error == "authorization_pending": - continue - elif error == "slow_down": - poll_interval += 5 - elif error == "access_denied": - raise AuthenticationFailedError("User denied the authorization request", provider=provider.name) - elif error == "expired_token": - raise AuthenticationFailedError("Device code has expired. Please try again.", provider=provider.name) - else: + except (OAuthError, OAuth2Error) as exc: + error = getattr(exc, "error", "") + if error == "authorization_pending": + continue + if error == "slow_down": + poll_interval += 5 + continue + if error == "access_denied": + raise AuthenticationFailedError("User denied the authorization request", provider=provider.name) + if error == "expired_token": + raise AuthenticationFailedError( + "Device code has expired. Please try again.", provider=provider.name + ) + description = getattr(exc, "description", None) raise AuthenticationFailedError( - f"Token endpoint error: {data.get('error_description', error or 'Unknown error')}", + f"Token endpoint error: {description or error or 'Unknown error'}", provider=provider.name, - ) + ) from exc raise AuthenticationFailedError("Device authorization timed out.", provider=provider.name) diff --git a/src/authsome/auth/flows/oauth2_client.py b/src/authsome/auth/flows/oauth2_client.py index 4179eb2..d924d25 100644 --- a/src/authsome/auth/flows/oauth2_client.py +++ b/src/authsome/auth/flows/oauth2_client.py @@ -15,6 +15,7 @@ from authsome.errors import AuthenticationFailedError, RefreshFailedError _PKCE_VERIFIER_LENGTH = 64 +_DEVICE_CODE_GRANT = "urn:ietf:params:oauth:grant-type:device_code" def create_pkce_authorization( @@ -141,6 +142,40 @@ def revoke_oauth_token( ) +def fetch_device_token( + *, + provider: ProviderDefinition, + device_code: str, + client_id: str | None, + client_secret: str | None, +) -> dict[str, Any]: + """Poll a device-code token endpoint once using Authlib response handling.""" + assert provider.oauth is not None + + session = OAuth2Session( + client_id=client_id, + client_secret=client_secret, + token_endpoint_auth_method=_token_endpoint_auth_method(client_secret), + ) + + if provider.oauth.device_token_request == "json": + response = http_client.post( + provider.oauth.token_url, + json={"device_code": device_code}, + headers={"Accept": "application/json", "Content-Type": "application/json"}, + timeout=30, + ) + return dict(session.parse_response_token(response)) + + token = session.fetch_token( + provider.oauth.token_url, + grant_type=_DEVICE_CODE_GRANT, + device_code=device_code, + timeout=30, + ) + return dict(token) + + def _token_endpoint_auth_method(client_secret: str | None) -> str: return "client_secret_post" if client_secret else "none" diff --git a/tests/auth/test_oauth2_client.py b/tests/auth/test_oauth2_client.py index f8507a5..ef363f1 100644 --- a/tests/auth/test_oauth2_client.py +++ b/tests/auth/test_oauth2_client.py @@ -11,6 +11,7 @@ from authsome.auth.flows.oauth2_client import ( create_pkce_authorization, exchange_authorization_code, + fetch_device_token, refresh_oauth_token, revoke_oauth_token, ) @@ -172,6 +173,59 @@ def fake_post(self: Session, url: str, data: dict[str, str], **kwargs: Any) -> _ assert captured["auth"] is not None +def test_fetch_device_token_uses_authlib_form_request(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _make_oauth_provider() + captured: dict[str, Any] = {} + + def fake_post(self: Session, url: str, data: dict[str, str], **kwargs: Any) -> _TokenResponse: + captured["url"] = url + captured["data"] = data + captured["auth"] = kwargs.get("auth") + return _TokenResponse() + + monkeypatch.setattr(Session, "post", fake_post) + + token = fetch_device_token( + provider=provider, + device_code="device-code", + client_id="client-id", + client_secret=None, + ) + + assert token["access_token"] == "access-token" + assert captured["url"] == "https://example.com/oauth/token" + assert captured["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:device_code" + assert captured["data"]["device_code"] == "device-code" + assert captured["auth"] is not None + + +def test_fetch_device_token_keeps_json_variant(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _make_oauth_provider() + assert provider.oauth is not None + provider.oauth.device_token_request = "json" + captured: dict[str, Any] = {} + + def fake_post(url: str, json: dict[str, str], **kwargs: Any) -> _TokenResponse: + captured["url"] = url + captured["json"] = json + captured["headers"] = kwargs.get("headers") + return _TokenResponse() + + monkeypatch.setattr("authsome.auth.flows.oauth2_client.http_client.post", fake_post) + + token = fetch_device_token( + provider=provider, + device_code="device-code", + client_id="client-id", + client_secret="client-secret", + ) + + assert token["access_token"] == "access-token" + assert captured["url"] == "https://example.com/oauth/token" + assert captured["json"] == {"device_code": "device-code"} + assert captured["headers"] == {"Accept": "application/json", "Content-Type": "application/json"} + + @pytest.mark.asyncio async def test_pkce_flow_begin_stores_authlib_authorization_state() -> None: from unittest.mock import Mock