Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"python-multipart>=0.0.27",
"jinja2>=3.1",
"py-key-value-aio[disk]",
"authlib>=1.7.2",
"pyjwt>=2.12.1",
"base58>=2.1.1",
]
Expand Down
44 changes: 15 additions & 29 deletions src/authsome/auth/flows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any

import requests as http_client
from loguru import logger

from authsome.auth.flows.oauth2_client import refresh_oauth_token, revoke_oauth_token
from authsome.auth.models.connection import ConnectionRecord, ProviderClientRecord
from authsome.auth.models.enums import ConnectionStatus
from authsome.auth.models.provider import ProviderDefinition
Expand Down Expand Up @@ -92,27 +92,23 @@ async def revoke(
if not revocation_url:
return

def _do_revoke(token: str, token_type: str) -> None:
payload = {"token": token}
if client_id:
payload["client_id"] = client_id
if client_secret:
payload["client_secret"] = client_secret

def _do_revoke(token: str, token_type: str, token_type_hint: str) -> None:
try:
http_client.post(
revocation_url,
data=payload,
timeout=15,
revoke_oauth_token(
provider=provider,
token=token,
token_type_hint=token_type_hint,
client_id=client_id,
client_secret=client_secret,
)
except Exception as exc:
logger.warning(f"{token_type.capitalize()} token revocation failed (continuing): {{}}", exc)

if record.access_token:
_do_revoke(record.access_token, "access")
_do_revoke(record.access_token, "access", "access_token")

if record.refresh_token:
_do_revoke(record.refresh_token, "refresh")
_do_revoke(record.refresh_token, "refresh", "refresh_token")

def refresh(
self,
Expand All @@ -134,22 +130,12 @@ def refresh(
if not client_id:
raise RefreshFailedError("No client_id available for refresh", provider=provider.name)

payload: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": record.refresh_token,
"client_id": client_id,
}
if client_secret:
payload["client_secret"] = client_secret

resp = http_client.post(
provider.oauth.token_url,
data=payload,
headers={"Accept": "application/json"},
timeout=30,
token = refresh_oauth_token(
provider=provider,
refresh_token=record.refresh_token,
client_id=client_id,
client_secret=client_secret,
)
resp.raise_for_status()
token = resp.json()

now = utc_now()
record.access_token = token["access_token"]
Expand Down
75 changes: 15 additions & 60 deletions src/authsome/auth/flows/dcr_pkce.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from __future__ import annotations

import json
import secrets
import urllib.parse
from datetime import timedelta
from typing import TYPE_CHECKING, Any

import requests as http_client

from authsome.auth.flows.base import AuthFlow, FlowResult
from authsome.auth.flows.oauth2_client import create_pkce_authorization, exchange_authorization_code
from authsome.auth.models.connection import AccountInfo, ConnectionRecord, ProviderClientRecord
from authsome.auth.models.enums import AuthType, ConnectionStatus
from authsome.auth.models.provider import ProviderDefinition
from authsome.auth.utils import generate_pkce, resolve_callback_url
from authsome.auth.utils import resolve_callback_url
from authsome.errors import AuthenticationFailedError, DiscoveryError
from authsome.utils import utc_now

Expand Down Expand Up @@ -51,21 +51,13 @@ async def begin(

assert client_id is not None # guaranteed: either passed in or registered above

code_verifier, code_challenge = generate_pkce()

state = secrets.token_urlsafe(32)
auth_params: dict[str, str] = {
"response_type": "code",
"client_id": client_id,
"redirect_uri": redirect_uri,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}
if effective_scopes:
auth_params["scope"] = " ".join(effective_scopes)

auth_url = f"{provider.oauth.authorization_url}?{urllib.parse.urlencode(auth_params)}"
auth_url, state, code_verifier = create_pkce_authorization(
provider=provider,
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
scopes=effective_scopes,
)

runtime_session.state = "waiting_for_user"
runtime_session.payload["auth_url"] = auth_url
Expand Down Expand Up @@ -99,10 +91,10 @@ async def resume(
if not auth_code:
raise AuthenticationFailedError("Authorization timed out or no code received", provider=provider.name)

returned_state = callback_data.get("state")
returned_state = callback_data.get("state", "")
expected_state = runtime_session.payload.get("internal_state")
if returned_state != expected_state:
raise AuthenticationFailedError("OAuth state mismatch — potential CSRF attack", provider=provider.name)
if not expected_state:
raise AuthenticationFailedError("OAuth state missing from session", provider=provider.name)

# If DCR registered a client, it's stored in payload
if "internal_client_id" in runtime_session.payload:
Expand All @@ -119,9 +111,11 @@ async def resume(
redirect_uri = runtime_session.payload.get("callback_url", "")
effective_scopes = json.loads(runtime_session.payload.get("internal_scopes", "[]"))

token_data = await self._exchange_code(
token_data = exchange_authorization_code(
provider=provider,
auth_code=auth_code,
expected_state=expected_state,
returned_state=returned_state,
redirect_uri=redirect_uri,
client_id=client_id,
client_secret=client_secret,
Expand Down Expand Up @@ -223,42 +217,3 @@ async def _register_client(
if not client_id:
raise AuthenticationFailedError("DCR response missing client_id", provider=provider.name)
return client_id, reg_data.get("client_secret")

@staticmethod
async def _exchange_code(
*,
provider: ProviderDefinition,
auth_code: str,
redirect_uri: str,
client_id: str,
client_secret: str | None,
code_verifier: str,
) -> dict[str, Any]:
assert provider.oauth is not None
payload: dict[str, str] = {
"grant_type": "authorization_code",
"code": auth_code,
"redirect_uri": redirect_uri,
"client_id": client_id,
"code_verifier": code_verifier,
}
if client_secret:
payload["client_secret"] = client_secret

try:
resp = http_client.post(
provider.oauth.token_url, data=payload, headers={"Accept": "application/json"}, timeout=30
)
resp.raise_for_status()
data = resp.json()
except http_client.RequestException as exc:
raise AuthenticationFailedError(f"Token exchange failed: {exc}", provider=provider.name) from exc
except json.JSONDecodeError as exc:
raise AuthenticationFailedError("Token response was not valid JSON", provider=provider.name) from exc

if "access_token" not in data:
raise AuthenticationFailedError(
f"Token exchange error: {data.get('error', '')} — {data.get('error_description', 'Unknown error')}",
provider=provider.name,
)
return data
65 changes: 25 additions & 40 deletions src/authsome/auth/flows/device_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from typing import TYPE_CHECKING, Any

import requests
from authlib.integrations.base_client.errors import OAuthError
from authlib.oauth2 import OAuth2Error
from loguru import logger

from authsome.auth.flows.base import AuthFlow, FlowResult
from authsome.auth.flows.oauth2_client import fetch_device_token
from authsome.auth.models.connection import AccountInfo, ConnectionRecord
from authsome.auth.models.enums import AuthType, ConnectionStatus
from authsome.auth.models.provider import ProviderDefinition
Expand Down Expand Up @@ -183,57 +186,39 @@ async def poll_for_token(
effective_expires_in = min(expires_in, 300)
deadline = time.monotonic() + effective_expires_in

use_json = provider.oauth.device_token_request == "json"

while time.monotonic() < deadline:
await asyncio.sleep(poll_interval)

try:
if use_json:
resp = requests.post(
provider.oauth.token_url,
json={"device_code": device_code},
headers={"Accept": "application/json", "Content-Type": "application/json"},
timeout=30,
)
else:
payload: dict[str, str] = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code,
}
if client_id:
payload["client_id"] = client_id
if client_secret:
payload["client_secret"] = client_secret
resp = requests.post(
provider.oauth.token_url, data=payload, headers={"Accept": "application/json"}, timeout=30
)
return fetch_device_token(
provider=provider,
device_code=device_code,
client_id=client_id,
client_secret=client_secret,
)
except requests.RequestException as exc:
logger.warning("Token poll request failed: {}, retrying...", exc)
continue

try:
data = resp.json()
except json.JSONDecodeError:
logger.warning("Token poll response was not JSON, retrying...")
continue

if resp.status_code == 200 and "access_token" in data:
return data

error = data.get("error", "")
if error == "authorization_pending":
continue
elif error == "slow_down":
poll_interval += 5
elif error == "access_denied":
raise AuthenticationFailedError("User denied the authorization request", provider=provider.name)
elif error == "expired_token":
raise AuthenticationFailedError("Device code has expired. Please try again.", provider=provider.name)
else:
except (OAuthError, OAuth2Error) as exc:
error = getattr(exc, "error", "")
if error == "authorization_pending":
continue
if error == "slow_down":
poll_interval += 5
continue
if error == "access_denied":
raise AuthenticationFailedError("User denied the authorization request", provider=provider.name)
if error == "expired_token":
raise AuthenticationFailedError(
"Device code has expired. Please try again.", provider=provider.name
)
description = getattr(exc, "description", None)
raise AuthenticationFailedError(
f"Token endpoint error: {data.get('error_description', error or 'Unknown error')}",
f"Token endpoint error: {description or error or 'Unknown error'}",
provider=provider.name,
)
) from exc

raise AuthenticationFailedError("Device authorization timed out.", provider=provider.name)
Loading
Loading