From f54583baf492662ba128c1b22511d413250907f1 Mon Sep 17 00:00:00 2001 From: Viktar Taustyka Date: Mon, 17 Apr 2023 09:57:01 +0200 Subject: [PATCH 1/5] C1: structure, reorganization, interfaces. --- src/business_logic/endsession/__init__.py | 0 src/business_logic/endsession/dto/__init__.py | 3 + src/business_logic/endsession/dto/request.py | 15 +++++ .../endsession/endsession_service.py | 65 +++++++++++++++++++ src/business_logic/endsession/errors.py | 0 src/business_logic/endsession/interfaces.py | 18 +++++ .../endsession/service_impls/__init__.py | 0 .../endsession_validation_service.py | 0 .../endsession/validators/__init__.py | 0 .../validate_end_session_request.py | 0 src/presentation/api/routes/endsession.py | 4 +- 11 files changed, 103 insertions(+), 2 deletions(-) create mode 100644 src/business_logic/endsession/__init__.py create mode 100644 src/business_logic/endsession/dto/__init__.py create mode 100644 src/business_logic/endsession/dto/request.py create mode 100644 src/business_logic/endsession/endsession_service.py create mode 100644 src/business_logic/endsession/errors.py create mode 100644 src/business_logic/endsession/interfaces.py create mode 100644 src/business_logic/endsession/service_impls/__init__.py create mode 100644 src/business_logic/endsession/service_impls/endsession_validation_service.py create mode 100644 src/business_logic/endsession/validators/__init__.py create mode 100644 src/business_logic/endsession/validators/validate_end_session_request.py diff --git a/src/business_logic/endsession/__init__.py b/src/business_logic/endsession/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/endsession/dto/__init__.py b/src/business_logic/endsession/dto/__init__.py new file mode 100644 index 00000000..1b4efa0b --- /dev/null +++ b/src/business_logic/endsession/dto/__init__.py @@ -0,0 +1,3 @@ +from .request import RequestEndSessionModel + +__all__ = ["RequestEndSessionModel"] \ No newline at end of file diff --git a/src/business_logic/endsession/dto/request.py b/src/business_logic/endsession/dto/request.py new file mode 100644 index 00000000..d45da234 --- /dev/null +++ b/src/business_logic/endsession/dto/request.py @@ -0,0 +1,15 @@ +from typing import Optional + +from pydantic import BaseModel + + +class RequestEndSessionModel(BaseModel): + id_token_hint: str + post_logout_redirect_uri: Optional[str] + state: Optional[str] + + class Config: + orm_mode = True + + def __repr__(self) -> str: # pragma: no cover + return f"Model {self.__class__.__name__}" \ No newline at end of file diff --git a/src/business_logic/endsession/endsession_service.py b/src/business_logic/endsession/endsession_service.py new file mode 100644 index 00000000..982af88d --- /dev/null +++ b/src/business_logic/endsession/endsession_service.py @@ -0,0 +1,65 @@ +from src.data_access.postgresql.repositories.client import ClientRepository +from src.data_access.postgresql.repositories.persistent_grant import PersistentGrantRepository +from src.business_logic.dependencies.database import get_repository_no_depends +from src.business_logic.services.jwt_token import JWTService + +from .dto.request import RequestEndSessionModel + +from typing import Union, Optional, Any + + +class EndSessionService: + """ + Service for endsession endpoint ...... + """ + def __init__( + self, + client_repo: ClientRepository, + persistent_grant_repo: PersistentGrantRepository, + jwt_service: JWTService + ) -> None: + self.client_repo = client_repo + self.persistent_grant_repo = persistent_grant_repo + self.jwt_service = jwt_service + self._request_model: Optional[RequestEndSessionModel]= None + + async def end_session(self) -> Optional[str]: + if self.request_model is not None: + decoded_id_token_hint = await self._decode_id_token_hint(id_token_hint=self.request_model.id_token_hint) + await self._logout( + client_id=decoded_id_token_hint['client_id'], + user_id=decoded_id_token_hint['sub'] + ) + if self.request_model.post_logout_redirect_uri: + if await self._validate_logout_redirect_uri( + logout_redirect_uri=self.request_model.post_logout_redirect_uri, + client_id=decoded_id_token_hint["client_id"] + ): + logout_redirect_uri = self.request_model.post_logout_redirect_uri + if self.request_model.state: + logout_redirect_uri += f"&state={self.request_model.state}" + return logout_redirect_uri + return None + + async def _decode_id_token_hint(self, id_token_hint: str) -> dict[str, Any]: + decoded_data = await self.jwt_service.decode_token(token=id_token_hint) + return decoded_data + + async def _logout(self, client_id: str, user_id: int) -> None: + await self.persistent_grant_repo.delete_persistent_grant_by_client_and_user_id( + client_id=client_id, + user_id=user_id + ) + + + async def _validate_logout_redirect_uri(self, client_id: str, logout_redirect_uri: str) -> bool: + result = await self.client_repo.validate_post_logout_redirect_uri(client_id, logout_redirect_uri) + return result + + @property + def request_model(self) -> Optional[RequestEndSessionModel]: + return self._request_model + + @request_model.setter + def request_model(self, request_model: RequestEndSessionModel) -> None: + self._request_model = request_model \ No newline at end of file diff --git a/src/business_logic/endsession/errors.py b/src/business_logic/endsession/errors.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/endsession/interfaces.py b/src/business_logic/endsession/interfaces.py new file mode 100644 index 00000000..ee3d2a0e --- /dev/null +++ b/src/business_logic/endsession/interfaces.py @@ -0,0 +1,18 @@ +from typing import Protocol, Optional, Any +from typing import TYPE_CHECKING +# if TYPE_CHECKING: +from src.business_logic.endsession.dto import RequestEndSessionModel + +class EndSessionServiceProtocol(Protocol): + + async def end_session(self) -> Optional[str]: + raise NotImplementedError + + @property + def request_model(self) -> Optional[RequestEndSessionModel]: + raise NotImplementedError + + @request_model.setter + def request_model(self, request_model: RequestEndSessionModel) -> None: + raise NotImplementedError + diff --git a/src/business_logic/endsession/service_impls/__init__.py b/src/business_logic/endsession/service_impls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/endsession/service_impls/endsession_validation_service.py b/src/business_logic/endsession/service_impls/endsession_validation_service.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/endsession/validators/__init__.py b/src/business_logic/endsession/validators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/endsession/validators/validate_end_session_request.py b/src/business_logic/endsession/validators/validate_end_session_request.py new file mode 100644 index 00000000..e69de29b diff --git a/src/presentation/api/routes/endsession.py b/src/presentation/api/routes/endsession.py index 71a5fd74..6cd39bae 100644 --- a/src/presentation/api/routes/endsession.py +++ b/src/presentation/api/routes/endsession.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse, RedirectResponse -from src.business_logic.services.endsession import EndSessionService +from src.business_logic.endsession.interfaces import EndSessionServiceProtocol from src.data_access.postgresql.errors.client import ( ClientPostLogoutRedirectUriError, ) @@ -24,7 +24,7 @@ @endsession_router.get("/", status_code=status.HTTP_204_NO_CONTENT) async def end_session( request_model: RequestEndSessionModel = Depends(), - service_class: EndSessionService = Depends( + service_class: EndSessionServiceProtocol = Depends( provide_endsession_service_stub ), ) -> Union[int, RedirectResponse, JSONResponse]: From 9ffa60ebf8ecdda3ad204966418edd575137750d Mon Sep 17 00:00:00 2001 From: Viktar Taustyka Date: Mon, 17 Apr 2023 17:48:19 +0200 Subject: [PATCH 2/5] C2: api, endsession_service.py, change tests for api --- src/business_logic/common/__init__.py | 0 src/business_logic/common/errors.py | 2 + src/business_logic/common/interfaces.py | 6 ++ src/business_logic/common/validators.py | 17 +++++ .../endsession/endsession_service.py | 54 ++++++++-------- src/business_logic/endsession/interfaces.py | 7 --- .../validators/validate_id_token_hint.py | 17 +++++ .../validate_logout_redirect_uri.py | 16 +++++ src/di/providers/services.py | 3 +- src/presentation/api/routes/endsession.py | 4 +- tests/test_api/test_endsession_endpoint.py | 63 +++++++++---------- 11 files changed, 116 insertions(+), 73 deletions(-) create mode 100644 src/business_logic/common/__init__.py create mode 100644 src/business_logic/common/errors.py create mode 100644 src/business_logic/common/interfaces.py create mode 100644 src/business_logic/common/validators.py create mode 100644 src/business_logic/endsession/validators/validate_id_token_hint.py create mode 100644 src/business_logic/endsession/validators/validate_logout_redirect_uri.py diff --git a/src/business_logic/common/__init__.py b/src/business_logic/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/common/errors.py b/src/business_logic/common/errors.py new file mode 100644 index 00000000..f8fefd1f --- /dev/null +++ b/src/business_logic/common/errors.py @@ -0,0 +1,2 @@ +class InvalidClientIdError(Exception): + ... diff --git a/src/business_logic/common/interfaces.py b/src/business_logic/common/interfaces.py new file mode 100644 index 00000000..647e2997 --- /dev/null +++ b/src/business_logic/common/interfaces.py @@ -0,0 +1,6 @@ +from typing import Protocol, Any + + +class ValidatorProtocol(Protocol): + async def __call__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError diff --git a/src/business_logic/common/validators.py b/src/business_logic/common/validators.py new file mode 100644 index 00000000..c8a3715e --- /dev/null +++ b/src/business_logic/common/validators.py @@ -0,0 +1,17 @@ +from src.business_logic.common.errors import InvalidClientIdError + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.data_access.postgresql.repositories import ClientRepository + + +class ValidateClient: + def __init__( + self, + client_repo: 'ClientRepository' + ) -> None: + self._client_repo = client_repo + + async def __call__(self, client_id: str) -> None: + if not await self._client_repo.exists(client_id=client_id): + raise InvalidClientIdError('Client with specified client_id doesn\'t exists') diff --git a/src/business_logic/endsession/endsession_service.py b/src/business_logic/endsession/endsession_service.py index 982af88d..75e9c2a1 100644 --- a/src/business_logic/endsession/endsession_service.py +++ b/src/business_logic/endsession/endsession_service.py @@ -6,6 +6,7 @@ from .dto.request import RequestEndSessionModel from typing import Union, Optional, Any +# from src.business_logic.common.interfaces import ValidatorProtocol class EndSessionService: @@ -16,29 +17,37 @@ def __init__( self, client_repo: ClientRepository, persistent_grant_repo: PersistentGrantRepository, - jwt_service: JWTService + jwt_service: JWTService, + # id_token_hint_validator: ValidatorProtocol, + # logout_redirect_uri_validator: ValidatorProtocol, ) -> None: self.client_repo = client_repo self.persistent_grant_repo = persistent_grant_repo self.jwt_service = jwt_service - self._request_model: Optional[RequestEndSessionModel]= None + # self._request_model: Optional[RequestEndSessionModel]= None - async def end_session(self) -> Optional[str]: - if self.request_model is not None: - decoded_id_token_hint = await self._decode_id_token_hint(id_token_hint=self.request_model.id_token_hint) - await self._logout( - client_id=decoded_id_token_hint['client_id'], - user_id=decoded_id_token_hint['sub'] - ) - if self.request_model.post_logout_redirect_uri: - if await self._validate_logout_redirect_uri( - logout_redirect_uri=self.request_model.post_logout_redirect_uri, - client_id=decoded_id_token_hint["client_id"] - ): - logout_redirect_uri = self.request_model.post_logout_redirect_uri - if self.request_model.state: - logout_redirect_uri += f"&state={self.request_model.state}" - return logout_redirect_uri + async def end_session(self, request_model: RequestEndSessionModel) -> Optional[str]: + + decoded_id_token_hint = await self._decode_id_token_hint(id_token_hint=request_model.id_token_hint) + # validate decoded_id_token_hint + + await self._logout( + client_id=decoded_id_token_hint['client_id'], + user_id=decoded_id_token_hint['sub'] + ) + # await ValidateLogoutRedirectUri() + # logout_redirect_uri = request_model.post_logout_redirect_uri + + # + if request_model.post_logout_redirect_uri: + if await self._validate_logout_redirect_uri( + logout_redirect_uri=request_model.post_logout_redirect_uri, + client_id=decoded_id_token_hint["client_id"] + ): + logout_redirect_uri = request_model.post_logout_redirect_uri + if request_model.state: + logout_redirect_uri += f"&state={request_model.state}" + return logout_redirect_uri return None async def _decode_id_token_hint(self, id_token_hint: str) -> dict[str, Any]: @@ -51,15 +60,6 @@ async def _logout(self, client_id: str, user_id: int) -> None: user_id=user_id ) - async def _validate_logout_redirect_uri(self, client_id: str, logout_redirect_uri: str) -> bool: result = await self.client_repo.validate_post_logout_redirect_uri(client_id, logout_redirect_uri) return result - - @property - def request_model(self) -> Optional[RequestEndSessionModel]: - return self._request_model - - @request_model.setter - def request_model(self, request_model: RequestEndSessionModel) -> None: - self._request_model = request_model \ No newline at end of file diff --git a/src/business_logic/endsession/interfaces.py b/src/business_logic/endsession/interfaces.py index ee3d2a0e..b1658a58 100644 --- a/src/business_logic/endsession/interfaces.py +++ b/src/business_logic/endsession/interfaces.py @@ -8,11 +8,4 @@ class EndSessionServiceProtocol(Protocol): async def end_session(self) -> Optional[str]: raise NotImplementedError - @property - def request_model(self) -> Optional[RequestEndSessionModel]: - raise NotImplementedError - - @request_model.setter - def request_model(self, request_model: RequestEndSessionModel) -> None: - raise NotImplementedError diff --git a/src/business_logic/endsession/validators/validate_id_token_hint.py b/src/business_logic/endsession/validators/validate_id_token_hint.py new file mode 100644 index 00000000..05f7a74f --- /dev/null +++ b/src/business_logic/endsession/validators/validate_id_token_hint.py @@ -0,0 +1,17 @@ +from business_logic.endsession.dto import RequestEndSessionModel +from src.data_access.postgresql.repositories import PersistentGrantRepository +from typing import Any + +class ValidateIdTokenHint: + """ + Checks that id_token_hint exists. + """ + def __init__( + self, + persistent_grant_repo: PersistentGrantRepository + ): + self._persistant_grant_repo = persistent_grant_repo + + async def __call__(self, decoded_id_token_hint: dict[str, Any]): + pass + # if await request.id_token_hint == self._persistant_grant_repo diff --git a/src/business_logic/endsession/validators/validate_logout_redirect_uri.py b/src/business_logic/endsession/validators/validate_logout_redirect_uri.py new file mode 100644 index 00000000..b6b6fe22 --- /dev/null +++ b/src/business_logic/endsession/validators/validate_logout_redirect_uri.py @@ -0,0 +1,16 @@ +from src.data_access.postgresql.repositories.client import ClientRepository +from src.data_access.postgresql.repositories import PersistentGrantRepository +from typing import Any + +class ValidateLogoutRedirectUri: + """ + Checks that id_token_hint exists. + """ + def __init__( + self, + client_repo: ClientRepository + ): + self._client_repo = client_repo + + async def __call__(self, client_id: str, logout_redirect_uri: str): + await self._client_repo.validate_post_logout_redirect_uri(client_id, logout_redirect_uri) diff --git a/src/di/providers/services.py b/src/di/providers/services.py index 51d14675..f29e80fa 100644 --- a/src/di/providers/services.py +++ b/src/di/providers/services.py @@ -13,7 +13,7 @@ ThirdPartyGitLabService, ThirdPartyMicrosoftService, DeviceService, - EndSessionService, + # EndSessionService, IntrospectionServies, JWTService, LoginFormService, @@ -26,6 +26,7 @@ UserInfoServices, WellKnownServices, ) +from src.business_logic.endsession.endsession_service import EndSessionService from src.data_access.postgresql.repositories import ( ClientRepository, DeviceRepository, diff --git a/src/presentation/api/routes/endsession.py b/src/presentation/api/routes/endsession.py index 6cd39bae..8d0c9c9a 100644 --- a/src/presentation/api/routes/endsession.py +++ b/src/presentation/api/routes/endsession.py @@ -29,9 +29,7 @@ async def end_session( ), ) -> Union[int, RedirectResponse, JSONResponse]: try: - service_class = service_class - service_class.request_model = request_model - logout_redirect_uri = await service_class.end_session() + logout_redirect_uri = await service_class.end_session(request_model) if logout_redirect_uri is None: return status.HTTP_204_NO_CONTENT diff --git a/tests/test_api/test_endsession_endpoint.py b/tests/test_api/test_endsession_endpoint.py index d6fd94c5..77876a5c 100644 --- a/tests/test_api/test_endsession_endpoint.py +++ b/tests/test_api/test_endsession_endpoint.py @@ -1,12 +1,13 @@ import jwt import pytest -from fastapi import status +from fastapi import status, FastAPI from httpx import AsyncClient from sqlalchemy import delete, insert from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.orm import sessionmaker +import src.di.providers as prov from src.data_access.postgresql.tables.persistent_grant import PersistentGrant from tests.test_unit.fixtures import ( end_session_request_model, @@ -14,7 +15,7 @@ TokenHint, ) from src.presentation.api.models.endsession import RequestEndSessionModel -from src.business_logic.services.endsession import EndSessionService +from src.business_logic.endsession.endsession_service import EndSessionService from tests.test_unit.fixtures import ( end_session_request_model, TOKEN_HINT_DATA, @@ -28,11 +29,24 @@ async def test_successful_authorize_request( self, engine: AsyncEngine, client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, + app: FastAPI, + # end_session_service: EndSessionService, + # end_session_request_model: RequestEndSessionModel, ) -> None: - service = end_session_service - service.request_model = end_session_request_model + # service = end_session_service + # service.request_model = end_session_request_model + # service = client.app.dependency_cache.get(provide_endsession_service_stub)() + # provide_endsession_service = app.dependency_overrides[provide_endsession_service_stub] + # service = provide_endsession_service_stub() + # service = provide_endsession_service() + # service = app.dependency_overrides[provide_endsession_service_stub]() + + service = prov.provide_endsession_service( + client_repo=prov.provide_client_repo(engine), + persistent_grant_repo=prov.provide_persistent_grant_repo(engine), + jwt_service=prov.provide_jwt_service(), + ) + secret = await service.client_repo.get_client_secrete_by_client_id( client_id=TOKEN_HINT_DATA["client_id"] ) @@ -71,14 +85,10 @@ async def test_successful_authorize_request( await session.commit() async def test_successful_authorize_request_without_uri( - self, - engine: AsyncEngine, - client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, + self, + engine: AsyncEngine, + client: AsyncClient, ) -> None: - service = end_session_service - service.request_model = end_session_request_model session_factory = sessionmaker( engine, expire_on_commit=False, class_=AsyncSession ) @@ -105,14 +115,10 @@ async def test_successful_authorize_request_without_uri( assert response.status_code == status.HTTP_204_NO_CONTENT async def test_successful_authorize_request_wrong_uri( - self, - engine: AsyncEngine, - client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, + self, + engine: AsyncEngine, + client: AsyncClient, ) -> None: - service = end_session_service - service.request_model = end_session_request_model session_factory = sessionmaker( engine, expire_on_commit=False, class_=AsyncSession ) @@ -145,11 +151,7 @@ async def test_successful_authorize_request_wrong_uri( async def test_end_session_bad_token( self, client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, ) -> None: - service = end_session_service - service.request_model = end_session_request_model expected_content = '{"message":"Bad id_token_hint"}' params = { "id_token_hint": "id_token_hint", @@ -160,16 +162,11 @@ async def test_end_session_bad_token( assert response.content.decode("UTF-8") == expected_content async def test_end_session_not_full_token( - self, - client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, + self, + client: AsyncClient ) -> None: hint = TokenHint() short_token_hint = await hint.get_short_token_hint() - service = end_session_service - service.request_model = end_session_request_model - service.request_model.id_token_hint = short_token_hint expected_content = ( '{"message":"The id_token_hint is missing something"}' @@ -185,13 +182,9 @@ async def test_end_session_not_full_token( async def test_end_session_no_persistent_grant( self, client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, ) -> None: hint = TokenHint() token_hint = await hint.get_token_hint() - service = end_session_service - service.request_model = end_session_request_model expected_content = '{"message":"You are not logged in"}' params = { "id_token_hint": token_hint, From fe6510cfe433c6ba1e82610327a53f28d9664efd Mon Sep 17 00:00:00 2001 From: Viktar Taustyka Date: Tue, 18 Apr 2023 18:15:14 +0200 Subject: [PATCH 3/5] C3: validators --- .../endsession/endsession_service.py | 18 ++-- .../endsession/validators/__init__.py | 3 + .../validators/validate_id_token_hint.py | 16 +++ .../validate_logout_redirect_uri.py | 5 +- src/business_logic/get_tokens/__init__.py | 5 + src/business_logic/get_tokens/dto/__init__.py | 5 + src/business_logic/get_tokens/dto/request.py | 39 +++++++ src/business_logic/get_tokens/dto/response.py | 17 +++ src/business_logic/get_tokens/errors.py | 10 ++ src/business_logic/get_tokens/factory.py | 67 ++++++++++++ src/business_logic/get_tokens/interfaces.py | 12 +++ .../get_tokens/service_impls/__init__.py | 5 + .../get_tokens/service_impls/auth_code.py | 102 ++++++++++++++++++ .../service_impls/client_credentials.py} | 0 .../get_tokens/service_impls/device_code.py | 0 .../get_tokens/service_impls/refresh_token.py | 37 +++++++ .../get_tokens/validators/__init__.py | 7 ++ .../validators/validate_grant_and_client.py | 20 ++++ .../validators/validate_grant_expiration.py | 6 ++ .../validators/validate_persistent_grant.py | 17 +++ .../validators/validate_redirect_uri.py | 18 ++++ src/business_logic/jwt_manager/__init__.py | 4 + .../jwt_manager/dto/__init__.py | 8 ++ src/business_logic/jwt_manager/dto/input.py | 33 ++++++ src/business_logic/jwt_manager/dto/output.py | 0 src/business_logic/jwt_manager/interfaces.py | 17 +++ .../jwt_manager/service_impls/__init__.py | 0 .../jwt_manager/service_impls/jwt_service.py | 44 ++++++++ src/di/providers/jwt_manager.py | 10 ++ src/di/providers/services_factory.py | 31 ++++++ .../api/exception_handlers/__init__.py | 10 ++ .../http400_invalid_client.py | 17 +++ .../http400_invalid_grant.py | 21 ++++ .../http400_unsupported_grant_type.py | 17 +++ 34 files changed, 612 insertions(+), 9 deletions(-) create mode 100644 src/business_logic/get_tokens/__init__.py create mode 100644 src/business_logic/get_tokens/dto/__init__.py create mode 100644 src/business_logic/get_tokens/dto/request.py create mode 100644 src/business_logic/get_tokens/dto/response.py create mode 100644 src/business_logic/get_tokens/errors.py create mode 100644 src/business_logic/get_tokens/factory.py create mode 100644 src/business_logic/get_tokens/interfaces.py create mode 100644 src/business_logic/get_tokens/service_impls/__init__.py create mode 100644 src/business_logic/get_tokens/service_impls/auth_code.py rename src/business_logic/{endsession/validators/validate_end_session_request.py => get_tokens/service_impls/client_credentials.py} (100%) create mode 100644 src/business_logic/get_tokens/service_impls/device_code.py create mode 100644 src/business_logic/get_tokens/service_impls/refresh_token.py create mode 100644 src/business_logic/get_tokens/validators/__init__.py create mode 100644 src/business_logic/get_tokens/validators/validate_grant_and_client.py create mode 100644 src/business_logic/get_tokens/validators/validate_grant_expiration.py create mode 100644 src/business_logic/get_tokens/validators/validate_persistent_grant.py create mode 100644 src/business_logic/get_tokens/validators/validate_redirect_uri.py create mode 100644 src/business_logic/jwt_manager/__init__.py create mode 100644 src/business_logic/jwt_manager/dto/__init__.py create mode 100644 src/business_logic/jwt_manager/dto/input.py create mode 100644 src/business_logic/jwt_manager/dto/output.py create mode 100644 src/business_logic/jwt_manager/interfaces.py create mode 100644 src/business_logic/jwt_manager/service_impls/__init__.py create mode 100644 src/business_logic/jwt_manager/service_impls/jwt_service.py create mode 100644 src/di/providers/jwt_manager.py create mode 100644 src/di/providers/services_factory.py create mode 100644 src/presentation/api/exception_handlers/__init__.py create mode 100644 src/presentation/api/exception_handlers/http400_invalid_client.py create mode 100644 src/presentation/api/exception_handlers/http400_invalid_grant.py create mode 100644 src/presentation/api/exception_handlers/http400_unsupported_grant_type.py diff --git a/src/business_logic/endsession/endsession_service.py b/src/business_logic/endsession/endsession_service.py index 75e9c2a1..92c8eeb2 100644 --- a/src/business_logic/endsession/endsession_service.py +++ b/src/business_logic/endsession/endsession_service.py @@ -4,6 +4,11 @@ from src.business_logic.services.jwt_token import JWTService from .dto.request import RequestEndSessionModel +from .validators import ( + ValidateDecodedIdTokenHint, + ValidateLogoutRedirectUri, + ValidateIdTokenHint + ) from typing import Union, Optional, Any # from src.business_logic.common.interfaces import ValidatorProtocol @@ -18,28 +23,27 @@ def __init__( client_repo: ClientRepository, persistent_grant_repo: PersistentGrantRepository, jwt_service: JWTService, - # id_token_hint_validator: ValidatorProtocol, - # logout_redirect_uri_validator: ValidatorProtocol, ) -> None: self.client_repo = client_repo self.persistent_grant_repo = persistent_grant_repo self.jwt_service = jwt_service # self._request_model: Optional[RequestEndSessionModel]= None + # id_token_hint_validator: ValidatorProtocol = ValidateIdTokenHint + # decoded_id_token_hint_validator: ValidatorProtocol = ValidateDecodedIdTokenHint + # logout_redirect_uri_validator: ValidatorProtocol = ValidateLogoutRedirectUri async def end_session(self, request_model: RequestEndSessionModel) -> Optional[str]: - + # await id_token_hint_validator(request_model) decoded_id_token_hint = await self._decode_id_token_hint(id_token_hint=request_model.id_token_hint) - # validate decoded_id_token_hint + # await decoded_id_token_hint_validator(decoded_id_token_hint: dict[str, Any]) await self._logout( client_id=decoded_id_token_hint['client_id'], user_id=decoded_id_token_hint['sub'] ) - # await ValidateLogoutRedirectUri() - # logout_redirect_uri = request_model.post_logout_redirect_uri - # if request_model.post_logout_redirect_uri: + # ? await logout_redirect_uri_validator(request_model, decoded_id_token_hint["client_id"]: str) if await self._validate_logout_redirect_uri( logout_redirect_uri=request_model.post_logout_redirect_uri, client_id=decoded_id_token_hint["client_id"] diff --git a/src/business_logic/endsession/validators/__init__.py b/src/business_logic/endsession/validators/__init__.py index e69de29b..544c045a 100644 --- a/src/business_logic/endsession/validators/__init__.py +++ b/src/business_logic/endsession/validators/__init__.py @@ -0,0 +1,3 @@ +from .validate_id_token_hint import (ValidateDecodedIdTokenHint, + ValidateIdTokenHint) +from .validate_logout_redirect_uri import ValidateLogoutRedirectUri \ No newline at end of file diff --git a/src/business_logic/endsession/validators/validate_id_token_hint.py b/src/business_logic/endsession/validators/validate_id_token_hint.py index 05f7a74f..232e1993 100644 --- a/src/business_logic/endsession/validators/validate_id_token_hint.py +++ b/src/business_logic/endsession/validators/validate_id_token_hint.py @@ -1,8 +1,24 @@ from business_logic.endsession.dto import RequestEndSessionModel +from business_logic.services import JWTService from src.data_access.postgresql.repositories import PersistentGrantRepository from typing import Any class ValidateIdTokenHint: + """ + Checks that id_token_hint exists. + """ + def __init__( + self, + jwt_service: JWTService + ): + self._jwt_service = jwt_service + + async def __call__(self, request_model: RequestEndSessionModel): + if not await self._jwt_service.verify_token(token=request_model.id_token_hint, aud="admin"): + raise ### + + +class ValidateDecodedIdTokenHint: """ Checks that id_token_hint exists. """ diff --git a/src/business_logic/endsession/validators/validate_logout_redirect_uri.py b/src/business_logic/endsession/validators/validate_logout_redirect_uri.py index b6b6fe22..8e3382d4 100644 --- a/src/business_logic/endsession/validators/validate_logout_redirect_uri.py +++ b/src/business_logic/endsession/validators/validate_logout_redirect_uri.py @@ -12,5 +12,6 @@ def __init__( ): self._client_repo = client_repo - async def __call__(self, client_id: str, logout_redirect_uri: str): - await self._client_repo.validate_post_logout_redirect_uri(client_id, logout_redirect_uri) + async def __call__(self, request_model, client_id): + if request_model.post_logout_redirect_uri: + await self._client_repo.validate_post_logout_redirect_uri(client_id, logout_redirect_uri) diff --git a/src/business_logic/get_tokens/__init__.py b/src/business_logic/get_tokens/__init__.py new file mode 100644 index 00000000..fc2ffe7c --- /dev/null +++ b/src/business_logic/get_tokens/__init__.py @@ -0,0 +1,5 @@ +from .factory import TokenServiceFactory +from .interfaces import TokenServiceProtocol + + +__all__ = ['TokenServiceFactory', 'TokenServiceProtocol'] diff --git a/src/business_logic/get_tokens/dto/__init__.py b/src/business_logic/get_tokens/dto/__init__.py new file mode 100644 index 00000000..eeb3537c --- /dev/null +++ b/src/business_logic/get_tokens/dto/__init__.py @@ -0,0 +1,5 @@ +from .request import RequestTokenModel +from .response import ResponseTokenModel + + +__all__ = ['RequestTokenModel', 'ResponseTokenModel'] diff --git a/src/business_logic/get_tokens/dto/request.py b/src/business_logic/get_tokens/dto/request.py new file mode 100644 index 00000000..f9f1075a --- /dev/null +++ b/src/business_logic/get_tokens/dto/request.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import Optional +from fastapi import Form +from pydantic import BaseModel + + +class RequestTokenModel(BaseModel): + client_id: Optional[str] + client_secret: Optional[str] + grant_type: str + scope: Optional[str] + redirect_uri: Optional[str] + code: Optional[str] + code_verifier: Optional[str] + username: Optional[str] + password: Optional[str] + acr_values: Optional[str] + refresh_token: Optional[str] + device_code: Optional[str] + + @classmethod + def as_form( + cls, + client_id: Optional[str] = Form(...), + client_secret: Optional[str] = Form(None), + grant_type: str = Form(...), + scope: str = Form(None), + redirect_uri: Optional[str] = Form(None), + code: Optional[str] = Form(None), + code_verifier: Optional[str] = Form(None), + username: Optional[str] = Form(None), + password: Optional[str] = Form(None), + acr_values: Optional[str] = Form(None), + refresh_token: Optional[str] = Form(None), + device_code: Optional[str] = Form(None) + ) -> 'RequestTokenModel': + return cls(client_id=client_id, client_secret=client_secret, grant_type=grant_type, scope=scope, + redirect_uri=redirect_uri, code=code, code_verifier=code_verifier, username=username, + password=password, acr_values=acr_values, refresh_token=refresh_token, device_code=device_code) diff --git a/src/business_logic/get_tokens/dto/response.py b/src/business_logic/get_tokens/dto/response.py new file mode 100644 index 00000000..76f5556c --- /dev/null +++ b/src/business_logic/get_tokens/dto/response.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel +from typing import Optional + + +class ResponseTokenModel(BaseModel): + + access_token: Optional[str] + token_type: Optional[str] + refresh_token: Optional[str] + expires_in: Optional[int] + id_token: Optional[str] + refresh_expires_in : Optional[int] + not_before_policy : Optional[int] = None + scope : Optional[str] = None + + class Config: + orm_mode = True diff --git a/src/business_logic/get_tokens/errors.py b/src/business_logic/get_tokens/errors.py new file mode 100644 index 00000000..87200919 --- /dev/null +++ b/src/business_logic/get_tokens/errors.py @@ -0,0 +1,10 @@ +class InvalidGrantError(Exception): + ... + + +class InvalidRedirectUriError(Exception): + ... + + +class UnsupportedGrantTypeError(Exception): + ... diff --git a/src/business_logic/get_tokens/factory.py b/src/business_logic/get_tokens/factory.py new file mode 100644 index 00000000..b1cd2d5a --- /dev/null +++ b/src/business_logic/get_tokens/factory.py @@ -0,0 +1,67 @@ +from __future__ import annotations +from src.business_logic.get_tokens.service_impls import ( + AuthorizationCodeTokenService, + RefreshTokenGrantService, +) +from src.business_logic.get_tokens.validators import ( + ValidatePersistentGrant, + ValidateRedirectUri, + ValidateGrantByClient, + ValidateGrantExpired, +) +from src.business_logic.get_tokens.errors import UnsupportedGrantTypeError +from src.business_logic.common.validators import ValidateClient + + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .interfaces import TokenServiceProtocol + from src.data_access.postgresql.repositories import ( + BlacklistedTokenRepository, + ClientRepository, + DeviceRepository, + PersistentGrantRepository, + UserRepository, + ) + from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + + +class TokenServiceFactory: + def __init__( + self, + client_repo: ClientRepository, + persistent_grant_repo: PersistentGrantRepository, + user_repo: UserRepository, + device_repo: DeviceRepository, + jwt_manager: JWTManagerProtocol, + blacklisted_repo: BlacklistedTokenRepository, + ) -> None: + self._client_repo = client_repo + self._persistent_grant_repo = persistent_grant_repo + self._user_repo = user_repo + self._device_repo = device_repo + self._jwt_manager = jwt_manager + self._blacklisted_repo = blacklisted_repo + + def get_service_impl(self, grant_type: str) -> TokenServiceProtocol: + if grant_type == 'authorization_code': + return AuthorizationCodeTokenService( + grant_validator=ValidatePersistentGrant(persistent_grant_repo=self._persistent_grant_repo), + redirect_uri_validator=ValidateRedirectUri(client_repo=self._client_repo), + client_validator=ValidateClient(client_repo=self._client_repo), + code_validator=ValidateGrantByClient(persistent_grant_repo=self._persistent_grant_repo), + grant_exp_validator=ValidateGrantExpired(), + jwt_manager=self._jwt_manager, + persistent_grant_repo=self._persistent_grant_repo + ) + elif grant_type == 'refresh_token': + return RefreshTokenGrantService( + grant_validator=ValidatePersistentGrant(persistent_grant_repo=self._persistent_grant_repo), + redirect_uri_validator=ValidateRedirectUri(client_repo=self._client_repo), + client_validator=ValidateClient(client_repo=self._client_repo), + code_validator=ValidateGrantByClient(persistent_grant_repo=self._persistent_grant_repo), + jwt_manager=self._jwt_manager, + persistent_grant_repo=self._persistent_grant_repo + ) + else: + raise UnsupportedGrantTypeError diff --git a/src/business_logic/get_tokens/interfaces.py b/src/business_logic/get_tokens/interfaces.py new file mode 100644 index 00000000..efb0495c --- /dev/null +++ b/src/business_logic/get_tokens/interfaces.py @@ -0,0 +1,12 @@ +from typing import Protocol +from typing import TYPE_CHECKING, Any +if TYPE_CHECKING: + from src.business_logic.get_tokens.dto import ( + RequestTokenModel, + ResponseTokenModel, + ) + + +class TokenServiceProtocol(Protocol): + async def get_tokens(self, request_data: 'RequestTokenModel') -> 'ResponseTokenModel': + raise NotImplementedError diff --git a/src/business_logic/get_tokens/service_impls/__init__.py b/src/business_logic/get_tokens/service_impls/__init__.py new file mode 100644 index 00000000..56fad48d --- /dev/null +++ b/src/business_logic/get_tokens/service_impls/__init__.py @@ -0,0 +1,5 @@ +from .auth_code import AuthorizationCodeTokenService +from .refresh_token import RefreshTokenGrantService + + +__all__ = ['AuthorizationCodeTokenService', 'RefreshTokenGrantService'] diff --git a/src/business_logic/get_tokens/service_impls/auth_code.py b/src/business_logic/get_tokens/service_impls/auth_code.py new file mode 100644 index 00000000..af91e664 --- /dev/null +++ b/src/business_logic/get_tokens/service_impls/auth_code.py @@ -0,0 +1,102 @@ +from __future__ import annotations +import time +import uuid +from src.business_logic.get_tokens.dto import RequestTokenModel, ResponseTokenModel +from src.business_logic.jwt_manager.dto import AccessTokenPayload, RefreshTokenPayload, IdTokenPayload +from src.dyna_config import DOMAIN_NAME +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.business_logic.common.interfaces import ValidatorProtocol + from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + from src.data_access.postgresql.repositories import PersistentGrantRepository + + +class AuthorizationCodeTokenService: + """ + Service for switching authorization code that was generated by + authorization endpoint in auth code flow and sent to the client + as a response. + """ + def __init__( + self, + grant_validator: ValidatorProtocol, + redirect_uri_validator: ValidatorProtocol, + client_validator: ValidatorProtocol, + code_validator: ValidatorProtocol, + grant_exp_validator: ValidatorProtocol, + jwt_manager: JWTManagerProtocol, + persistent_grant_repo: PersistentGrantRepository + ): + self._grant_validator = grant_validator + self._redirect_uri_validator = redirect_uri_validator + self._client_validator = client_validator + self._code_validator = code_validator + self._grant_expiration_validator = grant_exp_validator + self._jwt_manager = jwt_manager + self._persistent_grant_repo = persistent_grant_repo + + async def get_tokens(self, request_data: RequestTokenModel) -> ResponseTokenModel: + await self._client_validator(request_data.client_id) + await self._code_validator(request_data.code, request_data.client_id, request_data.grant_type) + await self._grant_validator(request_data.code, request_data.grant_type) + await self._redirect_uri_validator(request_data.redirect_uri, request_data.client_id) + + grant = await self._persistent_grant_repo.get_grant(grant_type=request_data.grant_type, grant_data=request_data.code) + print(grant) + user_id = grant.user_id + current_unix_time = int(time.time()) + + access_token = await self._get_access_token(request_data=request_data, user_id=user_id, unix_time=current_unix_time) + refresh_token = await self._get_refresh_token(request_data=request_data) + id_token = await self._get_id_token(request_data=request_data, user_id=user_id, unix_time=current_unix_time) + + + await self._persistent_grant_repo.delete_grant(grant=grant) + await self._persistent_grant_repo.create_grant( + client_id=grant.client_id, + grant_data=refresh_token, + user_id=user_id, + grant_type_id=2, + expiration_time=84700 + ) + + return ResponseTokenModel( + access_token=access_token, + refresh_token=refresh_token, + id_token=id_token, + token_type='Bearer', + expires_in=600, + refresh_expires_in=1800 + ) + + async def _get_access_token(self, request_data: RequestTokenModel, user_id: str, unix_time: int) -> str: + payload = AccessTokenPayload( + sub=user_id, + iss=DOMAIN_NAME, + client_id=request_data.client_id, + iat=unix_time, + exp=unix_time + 600, + aud=request_data.client_id, + jti=str(uuid.uuid4()), + acr=0, + ) + return self._jwt_manager.encode(payload=payload, algorithm='RS256') + + async def _get_refresh_token(self, request_data: RequestTokenModel) -> str: + payload = RefreshTokenPayload( + jti=str(uuid.uuid4()) + ) + return self._jwt_manager.encode(payload=payload, algorithm='RS256') + + async def _get_id_token(self, request_data: RequestTokenModel, user_id: str, unix_time: int) -> str: + payload = IdTokenPayload( + sub=user_id, + iss=DOMAIN_NAME, + client_id=request_data.client_id, + iat=unix_time, + exp=unix_time + 600, + aud=request_data.client_id, + jti=str(uuid.uuid4()), + acr=0, + ) + return self._jwt_manager.encode(payload=payload, algorithm='RS256') diff --git a/src/business_logic/endsession/validators/validate_end_session_request.py b/src/business_logic/get_tokens/service_impls/client_credentials.py similarity index 100% rename from src/business_logic/endsession/validators/validate_end_session_request.py rename to src/business_logic/get_tokens/service_impls/client_credentials.py diff --git a/src/business_logic/get_tokens/service_impls/device_code.py b/src/business_logic/get_tokens/service_impls/device_code.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/get_tokens/service_impls/refresh_token.py b/src/business_logic/get_tokens/service_impls/refresh_token.py new file mode 100644 index 00000000..2dac4bf3 --- /dev/null +++ b/src/business_logic/get_tokens/service_impls/refresh_token.py @@ -0,0 +1,37 @@ +from __future__ import annotations +import time +import uuid +from src.business_logic.get_tokens.dto import RequestTokenModel, ResponseTokenModel +from src.business_logic.jwt_manager.dto import AccessTokenPayload, RefreshTokenPayload, IdTokenPayload +from src.dyna_config import DOMAIN_NAME +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.business_logic.common.interfaces import ValidatorProtocol + from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + from src.data_access.postgresql.repositories import PersistentGrantRepository + + +class RefreshTokenGrantService: + def __init__( + self, + grant_validator: ValidatorProtocol, + redirect_uri_validator: ValidatorProtocol, + client_validator: ValidatorProtocol, + refresh_token_validator: ValidatorProtocol, + grant_exp_validator: ValidatorProtocol, + jwt_manager: JWTManagerProtocol, + persistent_grant_repo: PersistentGrantRepository + ): + self._grant_validator = grant_validator + self._redirect_uri_validator = redirect_uri_validator + self._client_validator = client_validator + self._code_validator = refresh_token_validator + self._grant_expiration_validator = grant_exp_validator + self._jwt_manager = jwt_manager + self._persistent_grant_repo = persistent_grant_repo + + async def get_tokens(self, request_data: RequestTokenModel) -> ResponseTokenModel: + await self._client_validator(request_data.client_id) + await self._code_validator(request_data.refresh_token, request_data.client_id, request_data.grant_type) + await self._grant_validator(request_data.refresh_token, request_data.grant_type) + await self._redirect_uri_validator(request_data.redirect_uri, request_data.client_id) diff --git a/src/business_logic/get_tokens/validators/__init__.py b/src/business_logic/get_tokens/validators/__init__.py new file mode 100644 index 00000000..e144df7c --- /dev/null +++ b/src/business_logic/get_tokens/validators/__init__.py @@ -0,0 +1,7 @@ +from .validate_persistent_grant import ValidatePersistentGrant +from .validate_redirect_uri import ValidateRedirectUri +from .validate_grant_and_client import ValidateGrantByClient +from .validate_grant_expiration import ValidateGrantExpired + + +__all__ = ['ValidatePersistentGrant', 'ValidateRedirectUri', 'ValidateGrantByClient', 'ValidateGrantExpired'] diff --git a/src/business_logic/get_tokens/validators/validate_grant_and_client.py b/src/business_logic/get_tokens/validators/validate_grant_and_client.py new file mode 100644 index 00000000..a875c200 --- /dev/null +++ b/src/business_logic/get_tokens/validators/validate_grant_and_client.py @@ -0,0 +1,20 @@ +from src.business_logic.get_tokens.errors import InvalidGrantError + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.data_access.postgresql.repositories import PersistentGrantRepository + + +class ValidateGrantByClient: + """ + Checks that authorization code was issued to the authenticated Client. + """ + def __init__( + self, + persistent_grant_repo: 'PersistentGrantRepository' + ): + self._persistent_grant_repo = persistent_grant_repo + + async def __call__(self, authorization_code: str, client_id: str, grant_type: str) -> None: + if not await self._persistent_grant_repo.exists_grant_for_client(authorization_code, client_id, grant_type): + raise InvalidGrantError('Invalid data provided.') diff --git a/src/business_logic/get_tokens/validators/validate_grant_expiration.py b/src/business_logic/get_tokens/validators/validate_grant_expiration.py new file mode 100644 index 00000000..28014de2 --- /dev/null +++ b/src/business_logic/get_tokens/validators/validate_grant_expiration.py @@ -0,0 +1,6 @@ +from src.business_logic.get_tokens.errors import InvalidGrantError + + +class ValidateGrantExpired: + def __call__(self, grant_created_datetime: str, grant_expiration: int) -> None: + pass diff --git a/src/business_logic/get_tokens/validators/validate_persistent_grant.py b/src/business_logic/get_tokens/validators/validate_persistent_grant.py new file mode 100644 index 00000000..aaec560b --- /dev/null +++ b/src/business_logic/get_tokens/validators/validate_persistent_grant.py @@ -0,0 +1,17 @@ +from src.business_logic.get_tokens.errors import InvalidGrantError + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.data_access.postgresql.repositories import PersistentGrantRepository + + +class ValidatePersistentGrant: + def __init__( + self, + persistent_grant_repo: 'PersistentGrantRepository' + ): + self._persistent_grant_repo = persistent_grant_repo + + async def __call__(self, code_to_validate: str, grant_type: str) -> None: + if not await self._persistent_grant_repo.exists(code_to_validate, grant_type): + raise InvalidGrantError('Invalid grant value or grant type.') diff --git a/src/business_logic/get_tokens/validators/validate_redirect_uri.py b/src/business_logic/get_tokens/validators/validate_redirect_uri.py new file mode 100644 index 00000000..4bf0307a --- /dev/null +++ b/src/business_logic/get_tokens/validators/validate_redirect_uri.py @@ -0,0 +1,18 @@ +from src.business_logic.get_tokens.errors import InvalidRedirectUriError + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.data_access.postgresql.repositories import ClientRepository + + +class ValidateRedirectUri: + def __init__( + self, + client_repo: 'ClientRepository' + ): + self._client_repo = client_repo + + async def __call__(self, redirect_uri: str, client_id: str) -> None: + uris_list = await self._client_repo.list_all_redirect_uris_by_client(client_id=client_id) + if redirect_uri not in uris_list: + raise InvalidRedirectUriError diff --git a/src/business_logic/jwt_manager/__init__.py b/src/business_logic/jwt_manager/__init__.py new file mode 100644 index 00000000..e72b7b67 --- /dev/null +++ b/src/business_logic/jwt_manager/__init__.py @@ -0,0 +1,4 @@ +from .service_impls.jwt_service import JWTManager + + +__all__ = ['JWTManager'] diff --git a/src/business_logic/jwt_manager/dto/__init__.py b/src/business_logic/jwt_manager/dto/__init__.py new file mode 100644 index 00000000..c937fe95 --- /dev/null +++ b/src/business_logic/jwt_manager/dto/__init__.py @@ -0,0 +1,8 @@ +from .input import ( + AccessTokenPayload, + RefreshTokenPayload, + IdTokenPayload, +) + + +__all__ = ['AccessTokenPayload', 'RefreshTokenPayload', 'IdTokenPayload'] diff --git a/src/business_logic/jwt_manager/dto/input.py b/src/business_logic/jwt_manager/dto/input.py new file mode 100644 index 00000000..5f048939 --- /dev/null +++ b/src/business_logic/jwt_manager/dto/input.py @@ -0,0 +1,33 @@ +from pydantic import BaseModel +from typing import Optional, Any + + +class BaseJWTPayload(BaseModel): + sub: str # user id + iss: str # auth service uri + iat: int # time of creation + exp: int # time when token will expire + aud: str # name for whom token was generated + client_id: str # id of the client who issued a token + jti: str # uniques identifier for token, UUID4 + acr: Optional[int] # default 0 + + +class AccessTokenPayload(BaseJWTPayload): + typ: str = 'Bearer' + + +class RefreshTokenPayload(BaseModel): + jti: str + + +class IdTokenPayload(BaseJWTPayload): + typ: str = 'ID' + email: Optional[str] = None + email_verified: Optional[str] = None + given_name: Optional[str] = None + last_name: Optional[str] = None + preferred_username: Optional[str] = None + picture: Optional[str] = None + zoneinfo: Optional[str] = None + locale: Optional[str] = None diff --git a/src/business_logic/jwt_manager/dto/output.py b/src/business_logic/jwt_manager/dto/output.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/jwt_manager/interfaces.py b/src/business_logic/jwt_manager/interfaces.py new file mode 100644 index 00000000..f1848b47 --- /dev/null +++ b/src/business_logic/jwt_manager/interfaces.py @@ -0,0 +1,17 @@ +from src.business_logic.jwt_manager.dto import ( + AccessTokenPayload, + RefreshTokenPayload, + IdTokenPayload, +) +from typing import Protocol, Any, Union + + +Payload = Union[AccessTokenPayload, RefreshTokenPayload, IdTokenPayload] + + +class JWTManagerProtocol(Protocol): + def encode(self, payload: Payload, algorithm: str) -> str: + raise NotImplementedError + + def decode(self, token: str, audience: str,**kwargs: Any) -> dict[str, Any]: + raise NotImplementedError diff --git a/src/business_logic/jwt_manager/service_impls/__init__.py b/src/business_logic/jwt_manager/service_impls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/jwt_manager/service_impls/jwt_service.py b/src/business_logic/jwt_manager/service_impls/jwt_service.py new file mode 100644 index 00000000..1fb6ca39 --- /dev/null +++ b/src/business_logic/jwt_manager/service_impls/jwt_service.py @@ -0,0 +1,44 @@ +from __future__ import annotations +import logging +import jwt +from src.config.rsa_keys import RSAKeypair +from src.di import Container +from src.business_logic.jwt_manager.dto import ( + AccessTokenPayload, + RefreshTokenPayload, + IdTokenPayload, +) +from typing import Any, Optional, Union + + +logger = logging.getLogger(__name__) + + +Payload = Union[AccessTokenPayload, RefreshTokenPayload, IdTokenPayload] + + +class JWTManager: + def __init__(self, keys: RSAKeypair = Container().config().keys) -> None: + self.keys = keys + + def encode(self, payload: Payload, algorithm: str, secret: Optional[str] = None) -> str: + if secret: + key = secret + else: + key = self.keys.private_key + + token = jwt.encode( + payload=payload.dict(exclude_none=True), key=key, algorithm=algorithm + ) + return token + + def decode(self, token: str, audience: Optional[str] = None, **kwargs: Any) -> dict[str, Any]: + token = token.replace("Bearer ", "") + if audience: + decoded_info = jwt.decode(token, key=self.keys.public_key, algorithms=self.algorithms, + audience=audience, **kwargs,) + else: + decoded_info = jwt.decode(token, key=self.keys.public_key, algorithms=self.algorithms, + **kwargs,) + + return decoded_info diff --git a/src/di/providers/jwt_manager.py b/src/di/providers/jwt_manager.py new file mode 100644 index 00000000..fb0c4951 --- /dev/null +++ b/src/di/providers/jwt_manager.py @@ -0,0 +1,10 @@ +from src.business_logic.jwt_manager import JWTManager +from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + + +def provide_jwt_manager_stub() -> None: + ... + + +def provide_jwt_manager() -> JWTManagerProtocol: + return JWTManager() diff --git a/src/di/providers/services_factory.py b/src/di/providers/services_factory.py new file mode 100644 index 00000000..a1793111 --- /dev/null +++ b/src/di/providers/services_factory.py @@ -0,0 +1,31 @@ +from src.business_logic.get_tokens import TokenServiceFactory +from src.data_access.postgresql.repositories import ( + BlacklistedTokenRepository, + ClientRepository, + DeviceRepository, + PersistentGrantRepository, + UserRepository, +) +from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + + +def provide_token_service_factory_stub() -> None: + ... + + +def provide_token_service_factory( + client_repo: ClientRepository, + persistent_grant_repo: PersistentGrantRepository, + user_repo: UserRepository, + device_repo: DeviceRepository, + jwt_service: JWTManagerProtocol, + blacklisted_repo: BlacklistedTokenRepository, +) -> TokenServiceFactory: + return TokenServiceFactory( + client_repo=client_repo, + persistent_grant_repo=persistent_grant_repo, + user_repo=user_repo, + device_repo=device_repo, + jwt_manager=jwt_service, + blacklisted_repo=blacklisted_repo + ) diff --git a/src/presentation/api/exception_handlers/__init__.py b/src/presentation/api/exception_handlers/__init__.py new file mode 100644 index 00000000..e2f53c69 --- /dev/null +++ b/src/presentation/api/exception_handlers/__init__.py @@ -0,0 +1,10 @@ +from .http400_invalid_grant import http400_invalid_grant_handler +from .http400_invalid_client import http400_invalid_client_handler +from .http400_unsupported_grant_type import http400_unsupported_grant_type_handler + + +__all__ = [ + 'http400_invalid_grant_handler', + 'http400_invalid_client_handler', + 'http400_unsupported_grant_type_handler' +] diff --git a/src/presentation/api/exception_handlers/http400_invalid_client.py b/src/presentation/api/exception_handlers/http400_invalid_client.py new file mode 100644 index 00000000..43f5741a --- /dev/null +++ b/src/presentation/api/exception_handlers/http400_invalid_client.py @@ -0,0 +1,17 @@ +from src.business_logic.common.errors import InvalidClientIdError +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_400_BAD_REQUEST + + +async def http400_invalid_client_handler( + _: Request, + exc: InvalidClientIdError +) -> JSONResponse: + headers = {"Cache-Control": "no-store", "Pragma": "no-cache"} + content = {"error": "invalid_client"} + return JSONResponse( + content=content, + headers=headers, + status_code=HTTP_400_BAD_REQUEST + ) diff --git a/src/presentation/api/exception_handlers/http400_invalid_grant.py b/src/presentation/api/exception_handlers/http400_invalid_grant.py new file mode 100644 index 00000000..f9faa5ea --- /dev/null +++ b/src/presentation/api/exception_handlers/http400_invalid_grant.py @@ -0,0 +1,21 @@ +from typing import Union +from src.business_logic.get_tokens.errors import InvalidGrantError, InvalidRedirectUriError +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_400_BAD_REQUEST + + +ExceptionsToHandle = Union[InvalidGrantError, InvalidRedirectUriError] + + +async def http400_invalid_grant_handler( + _: Request, + exc: ExceptionsToHandle +) -> JSONResponse: + headers = {"Cache-Control": "no-store", "Pragma": "no-cache"} + content = {"error": "invalid_grant"} + return JSONResponse( + content=content, + headers=headers, + status_code=HTTP_400_BAD_REQUEST + ) diff --git a/src/presentation/api/exception_handlers/http400_unsupported_grant_type.py b/src/presentation/api/exception_handlers/http400_unsupported_grant_type.py new file mode 100644 index 00000000..d5d327ab --- /dev/null +++ b/src/presentation/api/exception_handlers/http400_unsupported_grant_type.py @@ -0,0 +1,17 @@ +from src.business_logic.get_tokens.errors import UnsupportedGrantTypeError +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_400_BAD_REQUEST + + +async def http400_unsupported_grant_type_handler( + _: Request, + exc: UnsupportedGrantTypeError +) -> JSONResponse: + headers = {"Cache-Control": "no-store", "Pragma": "no-cache"} + content = {"error": "unsupported_grant_type"} + return JSONResponse( + content=content, + headers=headers, + status_code=HTTP_400_BAD_REQUEST + ) From 294418d36c7e0372c45ba590bb721bb408bcea5f Mon Sep 17 00:00:00 2001 From: Viktar Taustyka Date: Thu, 20 Apr 2023 10:22:20 +0200 Subject: [PATCH 4/5] C4: done, !except one test (test_validate_logout_redirect_uri_invalid_uri). --- .../endsession/endsession_service.py | 33 ++---- src/business_logic/endsession/errors.py | 2 + .../endsession/validators/__init__.py | 2 - .../validators/validate_id_token_hint.py | 33 ------ .../validate_logout_redirect_uri.py | 12 +- src/presentation/api/routes/endsession.py | 6 + tests/conftest.py | 3 +- tests/test_api/test_endsession_endpoint.py | 11 +- .../test_services/test_end_session_service.py | 108 +++++++++++------- 9 files changed, 96 insertions(+), 114 deletions(-) delete mode 100644 src/business_logic/endsession/validators/validate_id_token_hint.py diff --git a/src/business_logic/endsession/endsession_service.py b/src/business_logic/endsession/endsession_service.py index 92c8eeb2..85213653 100644 --- a/src/business_logic/endsession/endsession_service.py +++ b/src/business_logic/endsession/endsession_service.py @@ -4,14 +4,10 @@ from src.business_logic.services.jwt_token import JWTService from .dto.request import RequestEndSessionModel -from .validators import ( - ValidateDecodedIdTokenHint, - ValidateLogoutRedirectUri, - ValidateIdTokenHint - ) +from .validators import ValidateLogoutRedirectUri from typing import Union, Optional, Any -# from src.business_logic.common.interfaces import ValidatorProtocol +from src.business_logic.common.interfaces import ValidatorProtocol class EndSessionService: @@ -27,15 +23,10 @@ def __init__( self.client_repo = client_repo self.persistent_grant_repo = persistent_grant_repo self.jwt_service = jwt_service - # self._request_model: Optional[RequestEndSessionModel]= None - # id_token_hint_validator: ValidatorProtocol = ValidateIdTokenHint - # decoded_id_token_hint_validator: ValidatorProtocol = ValidateDecodedIdTokenHint - # logout_redirect_uri_validator: ValidatorProtocol = ValidateLogoutRedirectUri + self.logout_redirect_uri_validator: ValidatorProtocol = ValidateLogoutRedirectUri(self.client_repo) async def end_session(self, request_model: RequestEndSessionModel) -> Optional[str]: - # await id_token_hint_validator(request_model) decoded_id_token_hint = await self._decode_id_token_hint(id_token_hint=request_model.id_token_hint) - # await decoded_id_token_hint_validator(decoded_id_token_hint: dict[str, Any]) await self._logout( client_id=decoded_id_token_hint['client_id'], @@ -43,15 +34,12 @@ async def end_session(self, request_model: RequestEndSessionModel) -> Optional[s ) if request_model.post_logout_redirect_uri: - # ? await logout_redirect_uri_validator(request_model, decoded_id_token_hint["client_id"]: str) - if await self._validate_logout_redirect_uri( - logout_redirect_uri=request_model.post_logout_redirect_uri, - client_id=decoded_id_token_hint["client_id"] - ): - logout_redirect_uri = request_model.post_logout_redirect_uri - if request_model.state: - logout_redirect_uri += f"&state={request_model.state}" - return logout_redirect_uri + await self.logout_redirect_uri_validator(request_model=request_model, + client_id=decoded_id_token_hint["client_id"]) + logout_redirect_uri = request_model.post_logout_redirect_uri + if request_model.state: + logout_redirect_uri += f"&state={request_model.state}" + return logout_redirect_uri return None async def _decode_id_token_hint(self, id_token_hint: str) -> dict[str, Any]: @@ -64,6 +52,3 @@ async def _logout(self, client_id: str, user_id: int) -> None: user_id=user_id ) - async def _validate_logout_redirect_uri(self, client_id: str, logout_redirect_uri: str) -> bool: - result = await self.client_repo.validate_post_logout_redirect_uri(client_id, logout_redirect_uri) - return result diff --git a/src/business_logic/endsession/errors.py b/src/business_logic/endsession/errors.py index e69de29b..6cb30a23 100644 --- a/src/business_logic/endsession/errors.py +++ b/src/business_logic/endsession/errors.py @@ -0,0 +1,2 @@ +class InvalidLogoutRedirectUriError(Exception): + ... diff --git a/src/business_logic/endsession/validators/__init__.py b/src/business_logic/endsession/validators/__init__.py index 544c045a..a34c9651 100644 --- a/src/business_logic/endsession/validators/__init__.py +++ b/src/business_logic/endsession/validators/__init__.py @@ -1,3 +1 @@ -from .validate_id_token_hint import (ValidateDecodedIdTokenHint, - ValidateIdTokenHint) from .validate_logout_redirect_uri import ValidateLogoutRedirectUri \ No newline at end of file diff --git a/src/business_logic/endsession/validators/validate_id_token_hint.py b/src/business_logic/endsession/validators/validate_id_token_hint.py deleted file mode 100644 index 232e1993..00000000 --- a/src/business_logic/endsession/validators/validate_id_token_hint.py +++ /dev/null @@ -1,33 +0,0 @@ -from business_logic.endsession.dto import RequestEndSessionModel -from business_logic.services import JWTService -from src.data_access.postgresql.repositories import PersistentGrantRepository -from typing import Any - -class ValidateIdTokenHint: - """ - Checks that id_token_hint exists. - """ - def __init__( - self, - jwt_service: JWTService - ): - self._jwt_service = jwt_service - - async def __call__(self, request_model: RequestEndSessionModel): - if not await self._jwt_service.verify_token(token=request_model.id_token_hint, aud="admin"): - raise ### - - -class ValidateDecodedIdTokenHint: - """ - Checks that id_token_hint exists. - """ - def __init__( - self, - persistent_grant_repo: PersistentGrantRepository - ): - self._persistant_grant_repo = persistent_grant_repo - - async def __call__(self, decoded_id_token_hint: dict[str, Any]): - pass - # if await request.id_token_hint == self._persistant_grant_repo diff --git a/src/business_logic/endsession/validators/validate_logout_redirect_uri.py b/src/business_logic/endsession/validators/validate_logout_redirect_uri.py index 8e3382d4..08abb9ba 100644 --- a/src/business_logic/endsession/validators/validate_logout_redirect_uri.py +++ b/src/business_logic/endsession/validators/validate_logout_redirect_uri.py @@ -1,5 +1,7 @@ from src.data_access.postgresql.repositories.client import ClientRepository +from src.business_logic.endsession.dto import RequestEndSessionModel from src.data_access.postgresql.repositories import PersistentGrantRepository +from src.business_logic.endsession.errors import InvalidLogoutRedirectUriError from typing import Any class ValidateLogoutRedirectUri: @@ -12,6 +14,10 @@ def __init__( ): self._client_repo = client_repo - async def __call__(self, request_model, client_id): - if request_model.post_logout_redirect_uri: - await self._client_repo.validate_post_logout_redirect_uri(client_id, logout_redirect_uri) + async def __call__(self, request_model: RequestEndSessionModel, + client_id: str) -> None: + if not await self._client_repo.validate_post_logout_redirect_uri( + client_id, + request_model.post_logout_redirect_uri, + ): + raise InvalidLogoutRedirectUriError("Invalid post logout redirect uri") diff --git a/src/presentation/api/routes/endsession.py b/src/presentation/api/routes/endsession.py index 8d0c9c9a..412921de 100644 --- a/src/presentation/api/routes/endsession.py +++ b/src/presentation/api/routes/endsession.py @@ -50,6 +50,12 @@ async def end_session( status_code=status.HTTP_404_NOT_FOUND, content={"message": "You are not logged in"}, ) + # except InvalidIdTokenHintError as exception: + # logger.exception(exception) + # return JSONResponse( + # status_code=status.HTTP_404_NOT_FOUND, + # content={"message": "Invalid id token hint provided"}, + # ) except jwt.exceptions.DecodeError as exception: logger.exception(exception) return JSONResponse( diff --git a/tests/conftest.py b/tests/conftest.py index 5360388d..6125f485 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,8 @@ import asyncio from src.main import get_application from src.business_logic.services.authorization import AuthorizationService -from src.business_logic.services.endsession import EndSessionService +# from src.business_logic.services.endsession import EndSessionService +from src.business_logic.endsession.endsession_service import EndSessionService from src.business_logic.services.userinfo import UserInfoServices from src.business_logic.services import DeviceService, WellKnownServices diff --git a/tests/test_api/test_endsession_endpoint.py b/tests/test_api/test_endsession_endpoint.py index 77876a5c..ebeb91de 100644 --- a/tests/test_api/test_endsession_endpoint.py +++ b/tests/test_api/test_endsession_endpoint.py @@ -29,17 +29,8 @@ async def test_successful_authorize_request( self, engine: AsyncEngine, client: AsyncClient, - app: FastAPI, - # end_session_service: EndSessionService, - # end_session_request_model: RequestEndSessionModel, + ) -> None: - # service = end_session_service - # service.request_model = end_session_request_model - # service = client.app.dependency_cache.get(provide_endsession_service_stub)() - # provide_endsession_service = app.dependency_overrides[provide_endsession_service_stub] - # service = provide_endsession_service_stub() - # service = provide_endsession_service() - # service = app.dependency_overrides[provide_endsession_service_stub]() service = prov.provide_endsession_service( client_repo=prov.provide_client_repo(engine), diff --git a/tests/test_unit/test_services/test_end_session_service.py b/tests/test_unit/test_services/test_end_session_service.py index 23055b20..cfe43338 100644 --- a/tests/test_unit/test_services/test_end_session_service.py +++ b/tests/test_unit/test_services/test_end_session_service.py @@ -1,7 +1,11 @@ import jwt import pytest +from unittest.mock import AsyncMock, MagicMock from sqlalchemy import insert, select, delete +from business_logic.endsession.errors import InvalidLogoutRedirectUriError +from business_logic.endsession.validators import ValidateLogoutRedirectUri +from data_access.postgresql.repositories import ClientRepository from src.data_access.postgresql.errors import ClientPostLogoutRedirectUriError from src.data_access.postgresql.tables.persistent_grant import PersistentGrant from src.data_access.postgresql.errors.persistent_grant import ( @@ -9,44 +13,15 @@ ) from tests.test_unit.fixtures import end_session_request_model, TOKEN_HINT_DATA -from src.business_logic.services.endsession import EndSessionService +# from src.business_logic.services.endsession import EndSessionService +from src.business_logic.endsession.endsession_service import EndSessionService + from src.presentation.api.models.endsession import RequestEndSessionModel from sqlalchemy.ext.asyncio.engine import AsyncEngine @pytest.mark.asyncio class TestEndSessionService: - async def test_validate_logout_redirect_uri( - self, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, - ) -> None: - service = end_session_service - service.request_model = end_session_request_model - if not service.request_model.post_logout_redirect_uri: - raise AssertionError - result = await service._validate_logout_redirect_uri( - client_id="test_client", - logout_redirect_uri=service.request_model.post_logout_redirect_uri, - ) - - assert result is True - - async def test_validate_logout_redirect_uri_error( - self, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, - ) -> None: - service = end_session_service - service.request_model = end_session_request_model - service.request_model.post_logout_redirect_uri = "not_exist_uri" - with pytest.raises(ClientPostLogoutRedirectUriError): - await service._validate_logout_redirect_uri( - client_id="client_not_exist", - logout_redirect_uri=service.request_model.post_logout_redirect_uri, - ) - a = 1 - async def test_logout( self, end_session_service: EndSessionService, @@ -114,8 +89,7 @@ async def test_end_session( connection: AsyncEngine, ) -> None: service = end_session_service - service.request_model = end_session_request_model - service.request_model.post_logout_redirect_uri = "https://www.cole.com/" + end_session_request_model.post_logout_redirect_uri = "https://www.cole.com/" await connection.execute( insert(PersistentGrant).values( key="test_key", @@ -127,7 +101,7 @@ async def test_end_session( ) ) await connection.commit() - redirect_uri = await service.end_session() + redirect_uri = await service.end_session(end_session_request_model) assert redirect_uri == "https://www.cole.com/&state=test_state" await connection.execute( @@ -142,9 +116,8 @@ async def test_end_session_without_state( connection: AsyncEngine, ) -> None: service = end_session_service - service.request_model = end_session_request_model - service.request_model.post_logout_redirect_uri = "https://www.cole.com/" - service.request_model.state = None + end_session_request_model.post_logout_redirect_uri = "https://www.cole.com/" + end_session_request_model.state = None await connection.execute( insert(PersistentGrant).values( key="test_key", @@ -156,7 +129,7 @@ async def test_end_session_without_state( ) ) await connection.commit() - redirect_uri = await service.end_session() + redirect_uri = await service.end_session(end_session_request_model) assert redirect_uri == "https://www.cole.com/" await connection.execute( @@ -171,7 +144,7 @@ async def test_end_session_wrong_uri( connection: AsyncEngine, ) -> None: service = end_session_service - service.request_model = end_session_request_model + # service.request_model = end_session_request_model await connection.execute( insert(PersistentGrant).values( key="test_key", @@ -185,9 +158,62 @@ async def test_end_session_wrong_uri( await connection.commit() with pytest.raises(ClientPostLogoutRedirectUriError): - await service.end_session() + await service.end_session(end_session_request_model) await connection.execute( delete(PersistentGrant).where(PersistentGrant.client_id == 3) ) await connection.commit() + +@pytest.mark.asyncio +class TestEndSessionServiceValidators: + async def test_validate_logout_redirect_uri_valid_uri(self): + request_model = RequestEndSessionModel( + id_token_hint="some_id_token_hint", + post_logout_redirect_uri="https://valid-logout-uri.com", + ) + + client_repo_mock = AsyncMock() + client_repo_mock.validate_post_logout_redirect_uri.return_value = True + + validator = ValidateLogoutRedirectUri(client_repo=client_repo_mock) + await validator(request_model=request_model, client_id="client_id") + + client_repo_mock.validate_post_logout_redirect_uri.assert_called_once_with( + "client_id", "https://valid-logout-uri.com" + ) + + async def test_validate_logout_redirect_uri_invalid_uri(self): + + client_repo_mock = AsyncMock() + client_repo_mock.validate_post_logout_redirect_uri.return_value = False + validator = ValidateLogoutRedirectUri(client_repo=client_repo_mock) + + request_model = RequestEndSessionModel( + id_token_hint="some_id_token_hint", + post_logout_redirect_uri="invalid_uri", + state="some_state" + ) + + with pytest.raises(InvalidLogoutRedirectUriError, match="Invalid post logout redirect uri"): + await validator(request_model=request_model, client_id="client_id") + + # client_repo_mock.validate_post_logout_redirect_uri.assert_called_once_with( + # "client_id", "https://invalid-logout-uri.com" + # ) + + # async def test_validate_logout_redirect_uri_error( + # self, + # end_session_service: EndSessionService, + # end_session_request_model: RequestEndSessionModel, + # ) -> None: + # service = end_session_service + # service.request_model = end_session_request_model + # service.request_model.post_logout_redirect_uri = "not_exist_uri" + # with pytest.raises(ClientPostLogoutRedirectUriError): + # await service._validate_logout_redirect_uri( + # client_id="client_not_exist", + # logout_redirect_uri=service.request_model.post_logout_redirect_uri, + # ) + + From 19d682999f59d76fb23d11763127f925b04143b8 Mon Sep 17 00:00:00 2001 From: Viktar Taustyka Date: Thu, 20 Apr 2023 11:46:54 +0200 Subject: [PATCH 5/5] C5: done --- src/presentation/api/routes/endsession.py | 6 ----- .../test_services/test_end_session_service.py | 25 +++---------------- 2 files changed, 3 insertions(+), 28 deletions(-) diff --git a/src/presentation/api/routes/endsession.py b/src/presentation/api/routes/endsession.py index 412921de..8d0c9c9a 100644 --- a/src/presentation/api/routes/endsession.py +++ b/src/presentation/api/routes/endsession.py @@ -50,12 +50,6 @@ async def end_session( status_code=status.HTTP_404_NOT_FOUND, content={"message": "You are not logged in"}, ) - # except InvalidIdTokenHintError as exception: - # logger.exception(exception) - # return JSONResponse( - # status_code=status.HTTP_404_NOT_FOUND, - # content={"message": "Invalid id token hint provided"}, - # ) except jwt.exceptions.DecodeError as exception: logger.exception(exception) return JSONResponse( diff --git a/tests/test_unit/test_services/test_end_session_service.py b/tests/test_unit/test_services/test_end_session_service.py index cfe43338..4a379380 100644 --- a/tests/test_unit/test_services/test_end_session_service.py +++ b/tests/test_unit/test_services/test_end_session_service.py @@ -3,17 +3,15 @@ from unittest.mock import AsyncMock, MagicMock from sqlalchemy import insert, select, delete -from business_logic.endsession.errors import InvalidLogoutRedirectUriError -from business_logic.endsession.validators import ValidateLogoutRedirectUri -from data_access.postgresql.repositories import ClientRepository +from tests.test_unit.fixtures import end_session_request_model, TOKEN_HINT_DATA +from src.business_logic.endsession.errors import InvalidLogoutRedirectUriError +from src.business_logic.endsession.validators import ValidateLogoutRedirectUri from src.data_access.postgresql.errors import ClientPostLogoutRedirectUriError from src.data_access.postgresql.tables.persistent_grant import PersistentGrant from src.data_access.postgresql.errors.persistent_grant import ( PersistentGrantNotFoundError, ) -from tests.test_unit.fixtures import end_session_request_model, TOKEN_HINT_DATA -# from src.business_logic.services.endsession import EndSessionService from src.business_logic.endsession.endsession_service import EndSessionService from src.presentation.api.models.endsession import RequestEndSessionModel @@ -198,22 +196,5 @@ async def test_validate_logout_redirect_uri_invalid_uri(self): with pytest.raises(InvalidLogoutRedirectUriError, match="Invalid post logout redirect uri"): await validator(request_model=request_model, client_id="client_id") - # client_repo_mock.validate_post_logout_redirect_uri.assert_called_once_with( - # "client_id", "https://invalid-logout-uri.com" - # ) - - # async def test_validate_logout_redirect_uri_error( - # self, - # end_session_service: EndSessionService, - # end_session_request_model: RequestEndSessionModel, - # ) -> None: - # service = end_session_service - # service.request_model = end_session_request_model - # service.request_model.post_logout_redirect_uri = "not_exist_uri" - # with pytest.raises(ClientPostLogoutRedirectUriError): - # await service._validate_logout_redirect_uri( - # client_id="client_not_exist", - # logout_redirect_uri=service.request_model.post_logout_redirect_uri, - # )