diff --git a/propelauth_fastapi/__init__.py b/propelauth_fastapi/__init__.py index 77ced9e..9360db6 100644 --- a/propelauth_fastapi/__init__.py +++ b/propelauth_fastapi/__init__.py @@ -4,7 +4,11 @@ from fastapi import Depends, HTTPException from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from propelauth_py import TokenVerificationMetadata, init_base_auth, Auth -from propelauth_py.errors import ForbiddenException, UnauthorizedException +from propelauth_py.errors import ( + EndUserApiKeyException, + ForbiddenException, + UnauthorizedException, +) from propelauth_py.user import User _security = HTTPBearer(auto_error=False) @@ -50,6 +54,39 @@ def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(_security return None +@dataclass +class Org: + org_id: str + name: str + is_saml_configured: bool + + +class AccessTokenOrAPIKeyDependency: + def __init__(self, auth: Auth): + self.auth = auth + + def __call__( + self, credentials: HTTPAuthorizationCredentials = Depends(_security) + ) -> Org: + if credentials is None: + raise HTTPException(status_code=401) + user: User = self.auth.optional_user(credentials) + if user: + org_ids = list(user.org_id_to_org_member_info.keys()) + org_id = org_ids[0] # we will only have one org per user + fetch_output = self.auth.fetch_org(org_id) + return Org(**fetch_output) + + authorization_header = credentials.scheme + " " + credentials.credentials + + try: + metadata = self.auth.validate_api_key(authorization_header).get("org") + fetch_output = self.auth.fetch_org(metadata.get("org_id")) + return Org(**fetch_output) + except EndUserApiKeyException: + raise HTTPException(status_code=401) + + def _require_org_member_wrapper(auth: Auth, debug_mode: bool): def require_org_member(user: User, required_org_id: str): try: