diff --git a/src/business_logic/common/__init__.py b/src/business_logic/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/common/errors.py b/src/business_logic/common/errors.py new file mode 100644 index 00000000..f8fefd1f --- /dev/null +++ b/src/business_logic/common/errors.py @@ -0,0 +1,2 @@ +class InvalidClientIdError(Exception): + ... diff --git a/src/business_logic/common/interfaces.py b/src/business_logic/common/interfaces.py new file mode 100644 index 00000000..647e2997 --- /dev/null +++ b/src/business_logic/common/interfaces.py @@ -0,0 +1,6 @@ +from typing import Protocol, Any + + +class ValidatorProtocol(Protocol): + async def __call__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError diff --git a/src/business_logic/common/validators.py b/src/business_logic/common/validators.py new file mode 100644 index 00000000..c8a3715e --- /dev/null +++ b/src/business_logic/common/validators.py @@ -0,0 +1,17 @@ +from src.business_logic.common.errors import InvalidClientIdError + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.data_access.postgresql.repositories import ClientRepository + + +class ValidateClient: + def __init__( + self, + client_repo: 'ClientRepository' + ) -> None: + self._client_repo = client_repo + + async def __call__(self, client_id: str) -> None: + if not await self._client_repo.exists(client_id=client_id): + raise InvalidClientIdError('Client with specified client_id doesn\'t exists') diff --git a/src/business_logic/endsession/__init__.py b/src/business_logic/endsession/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/endsession/dto/__init__.py b/src/business_logic/endsession/dto/__init__.py new file mode 100644 index 00000000..1b4efa0b --- /dev/null +++ b/src/business_logic/endsession/dto/__init__.py @@ -0,0 +1,3 @@ +from .request import RequestEndSessionModel + +__all__ = ["RequestEndSessionModel"] \ No newline at end of file diff --git a/src/business_logic/endsession/dto/request.py b/src/business_logic/endsession/dto/request.py new file mode 100644 index 00000000..d45da234 --- /dev/null +++ b/src/business_logic/endsession/dto/request.py @@ -0,0 +1,15 @@ +from typing import Optional + +from pydantic import BaseModel + + +class RequestEndSessionModel(BaseModel): + id_token_hint: str + post_logout_redirect_uri: Optional[str] + state: Optional[str] + + class Config: + orm_mode = True + + def __repr__(self) -> str: # pragma: no cover + return f"Model {self.__class__.__name__}" \ No newline at end of file diff --git a/src/business_logic/endsession/endsession_service.py b/src/business_logic/endsession/endsession_service.py new file mode 100644 index 00000000..85213653 --- /dev/null +++ b/src/business_logic/endsession/endsession_service.py @@ -0,0 +1,54 @@ +from src.data_access.postgresql.repositories.client import ClientRepository +from src.data_access.postgresql.repositories.persistent_grant import PersistentGrantRepository +from src.business_logic.dependencies.database import get_repository_no_depends +from src.business_logic.services.jwt_token import JWTService + +from .dto.request import RequestEndSessionModel +from .validators import ValidateLogoutRedirectUri + +from typing import Union, Optional, Any +from src.business_logic.common.interfaces import ValidatorProtocol + + +class EndSessionService: + """ + Service for endsession endpoint ...... + """ + def __init__( + self, + client_repo: ClientRepository, + persistent_grant_repo: PersistentGrantRepository, + jwt_service: JWTService, + ) -> None: + self.client_repo = client_repo + self.persistent_grant_repo = persistent_grant_repo + self.jwt_service = jwt_service + self.logout_redirect_uri_validator: ValidatorProtocol = ValidateLogoutRedirectUri(self.client_repo) + + async def end_session(self, request_model: RequestEndSessionModel) -> Optional[str]: + decoded_id_token_hint = await self._decode_id_token_hint(id_token_hint=request_model.id_token_hint) + + await self._logout( + client_id=decoded_id_token_hint['client_id'], + user_id=decoded_id_token_hint['sub'] + ) + + if request_model.post_logout_redirect_uri: + await self.logout_redirect_uri_validator(request_model=request_model, + client_id=decoded_id_token_hint["client_id"]) + logout_redirect_uri = request_model.post_logout_redirect_uri + if request_model.state: + logout_redirect_uri += f"&state={request_model.state}" + return logout_redirect_uri + return None + + async def _decode_id_token_hint(self, id_token_hint: str) -> dict[str, Any]: + decoded_data = await self.jwt_service.decode_token(token=id_token_hint) + return decoded_data + + async def _logout(self, client_id: str, user_id: int) -> None: + await self.persistent_grant_repo.delete_persistent_grant_by_client_and_user_id( + client_id=client_id, + user_id=user_id + ) + diff --git a/src/business_logic/endsession/errors.py b/src/business_logic/endsession/errors.py new file mode 100644 index 00000000..6cb30a23 --- /dev/null +++ b/src/business_logic/endsession/errors.py @@ -0,0 +1,2 @@ +class InvalidLogoutRedirectUriError(Exception): + ... diff --git a/src/business_logic/endsession/interfaces.py b/src/business_logic/endsession/interfaces.py new file mode 100644 index 00000000..b1658a58 --- /dev/null +++ b/src/business_logic/endsession/interfaces.py @@ -0,0 +1,11 @@ +from typing import Protocol, Optional, Any +from typing import TYPE_CHECKING +# if TYPE_CHECKING: +from src.business_logic.endsession.dto import RequestEndSessionModel + +class EndSessionServiceProtocol(Protocol): + + async def end_session(self) -> Optional[str]: + raise NotImplementedError + + diff --git a/src/business_logic/endsession/service_impls/__init__.py b/src/business_logic/endsession/service_impls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/endsession/service_impls/endsession_validation_service.py b/src/business_logic/endsession/service_impls/endsession_validation_service.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/endsession/validators/__init__.py b/src/business_logic/endsession/validators/__init__.py new file mode 100644 index 00000000..a34c9651 --- /dev/null +++ b/src/business_logic/endsession/validators/__init__.py @@ -0,0 +1 @@ +from .validate_logout_redirect_uri import ValidateLogoutRedirectUri \ No newline at end of file diff --git a/src/business_logic/endsession/validators/validate_logout_redirect_uri.py b/src/business_logic/endsession/validators/validate_logout_redirect_uri.py new file mode 100644 index 00000000..08abb9ba --- /dev/null +++ b/src/business_logic/endsession/validators/validate_logout_redirect_uri.py @@ -0,0 +1,23 @@ +from src.data_access.postgresql.repositories.client import ClientRepository +from src.business_logic.endsession.dto import RequestEndSessionModel +from src.data_access.postgresql.repositories import PersistentGrantRepository +from src.business_logic.endsession.errors import InvalidLogoutRedirectUriError +from typing import Any + +class ValidateLogoutRedirectUri: + """ + Checks that id_token_hint exists. + """ + def __init__( + self, + client_repo: ClientRepository + ): + self._client_repo = client_repo + + async def __call__(self, request_model: RequestEndSessionModel, + client_id: str) -> None: + if not await self._client_repo.validate_post_logout_redirect_uri( + client_id, + request_model.post_logout_redirect_uri, + ): + raise InvalidLogoutRedirectUriError("Invalid post logout redirect uri") diff --git a/src/business_logic/get_tokens/__init__.py b/src/business_logic/get_tokens/__init__.py new file mode 100644 index 00000000..fc2ffe7c --- /dev/null +++ b/src/business_logic/get_tokens/__init__.py @@ -0,0 +1,5 @@ +from .factory import TokenServiceFactory +from .interfaces import TokenServiceProtocol + + +__all__ = ['TokenServiceFactory', 'TokenServiceProtocol'] diff --git a/src/business_logic/get_tokens/dto/__init__.py b/src/business_logic/get_tokens/dto/__init__.py new file mode 100644 index 00000000..eeb3537c --- /dev/null +++ b/src/business_logic/get_tokens/dto/__init__.py @@ -0,0 +1,5 @@ +from .request import RequestTokenModel +from .response import ResponseTokenModel + + +__all__ = ['RequestTokenModel', 'ResponseTokenModel'] diff --git a/src/business_logic/get_tokens/dto/request.py b/src/business_logic/get_tokens/dto/request.py new file mode 100644 index 00000000..f9f1075a --- /dev/null +++ b/src/business_logic/get_tokens/dto/request.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import Optional +from fastapi import Form +from pydantic import BaseModel + + +class RequestTokenModel(BaseModel): + client_id: Optional[str] + client_secret: Optional[str] + grant_type: str + scope: Optional[str] + redirect_uri: Optional[str] + code: Optional[str] + code_verifier: Optional[str] + username: Optional[str] + password: Optional[str] + acr_values: Optional[str] + refresh_token: Optional[str] + device_code: Optional[str] + + @classmethod + def as_form( + cls, + client_id: Optional[str] = Form(...), + client_secret: Optional[str] = Form(None), + grant_type: str = Form(...), + scope: str = Form(None), + redirect_uri: Optional[str] = Form(None), + code: Optional[str] = Form(None), + code_verifier: Optional[str] = Form(None), + username: Optional[str] = Form(None), + password: Optional[str] = Form(None), + acr_values: Optional[str] = Form(None), + refresh_token: Optional[str] = Form(None), + device_code: Optional[str] = Form(None) + ) -> 'RequestTokenModel': + return cls(client_id=client_id, client_secret=client_secret, grant_type=grant_type, scope=scope, + redirect_uri=redirect_uri, code=code, code_verifier=code_verifier, username=username, + password=password, acr_values=acr_values, refresh_token=refresh_token, device_code=device_code) diff --git a/src/business_logic/get_tokens/dto/response.py b/src/business_logic/get_tokens/dto/response.py new file mode 100644 index 00000000..76f5556c --- /dev/null +++ b/src/business_logic/get_tokens/dto/response.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel +from typing import Optional + + +class ResponseTokenModel(BaseModel): + + access_token: Optional[str] + token_type: Optional[str] + refresh_token: Optional[str] + expires_in: Optional[int] + id_token: Optional[str] + refresh_expires_in : Optional[int] + not_before_policy : Optional[int] = None + scope : Optional[str] = None + + class Config: + orm_mode = True diff --git a/src/business_logic/get_tokens/errors.py b/src/business_logic/get_tokens/errors.py new file mode 100644 index 00000000..87200919 --- /dev/null +++ b/src/business_logic/get_tokens/errors.py @@ -0,0 +1,10 @@ +class InvalidGrantError(Exception): + ... + + +class InvalidRedirectUriError(Exception): + ... + + +class UnsupportedGrantTypeError(Exception): + ... diff --git a/src/business_logic/get_tokens/factory.py b/src/business_logic/get_tokens/factory.py new file mode 100644 index 00000000..b1cd2d5a --- /dev/null +++ b/src/business_logic/get_tokens/factory.py @@ -0,0 +1,67 @@ +from __future__ import annotations +from src.business_logic.get_tokens.service_impls import ( + AuthorizationCodeTokenService, + RefreshTokenGrantService, +) +from src.business_logic.get_tokens.validators import ( + ValidatePersistentGrant, + ValidateRedirectUri, + ValidateGrantByClient, + ValidateGrantExpired, +) +from src.business_logic.get_tokens.errors import UnsupportedGrantTypeError +from src.business_logic.common.validators import ValidateClient + + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .interfaces import TokenServiceProtocol + from src.data_access.postgresql.repositories import ( + BlacklistedTokenRepository, + ClientRepository, + DeviceRepository, + PersistentGrantRepository, + UserRepository, + ) + from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + + +class TokenServiceFactory: + def __init__( + self, + client_repo: ClientRepository, + persistent_grant_repo: PersistentGrantRepository, + user_repo: UserRepository, + device_repo: DeviceRepository, + jwt_manager: JWTManagerProtocol, + blacklisted_repo: BlacklistedTokenRepository, + ) -> None: + self._client_repo = client_repo + self._persistent_grant_repo = persistent_grant_repo + self._user_repo = user_repo + self._device_repo = device_repo + self._jwt_manager = jwt_manager + self._blacklisted_repo = blacklisted_repo + + def get_service_impl(self, grant_type: str) -> TokenServiceProtocol: + if grant_type == 'authorization_code': + return AuthorizationCodeTokenService( + grant_validator=ValidatePersistentGrant(persistent_grant_repo=self._persistent_grant_repo), + redirect_uri_validator=ValidateRedirectUri(client_repo=self._client_repo), + client_validator=ValidateClient(client_repo=self._client_repo), + code_validator=ValidateGrantByClient(persistent_grant_repo=self._persistent_grant_repo), + grant_exp_validator=ValidateGrantExpired(), + jwt_manager=self._jwt_manager, + persistent_grant_repo=self._persistent_grant_repo + ) + elif grant_type == 'refresh_token': + return RefreshTokenGrantService( + grant_validator=ValidatePersistentGrant(persistent_grant_repo=self._persistent_grant_repo), + redirect_uri_validator=ValidateRedirectUri(client_repo=self._client_repo), + client_validator=ValidateClient(client_repo=self._client_repo), + code_validator=ValidateGrantByClient(persistent_grant_repo=self._persistent_grant_repo), + jwt_manager=self._jwt_manager, + persistent_grant_repo=self._persistent_grant_repo + ) + else: + raise UnsupportedGrantTypeError diff --git a/src/business_logic/get_tokens/interfaces.py b/src/business_logic/get_tokens/interfaces.py new file mode 100644 index 00000000..efb0495c --- /dev/null +++ b/src/business_logic/get_tokens/interfaces.py @@ -0,0 +1,12 @@ +from typing import Protocol +from typing import TYPE_CHECKING, Any +if TYPE_CHECKING: + from src.business_logic.get_tokens.dto import ( + RequestTokenModel, + ResponseTokenModel, + ) + + +class TokenServiceProtocol(Protocol): + async def get_tokens(self, request_data: 'RequestTokenModel') -> 'ResponseTokenModel': + raise NotImplementedError diff --git a/src/business_logic/get_tokens/service_impls/__init__.py b/src/business_logic/get_tokens/service_impls/__init__.py new file mode 100644 index 00000000..56fad48d --- /dev/null +++ b/src/business_logic/get_tokens/service_impls/__init__.py @@ -0,0 +1,5 @@ +from .auth_code import AuthorizationCodeTokenService +from .refresh_token import RefreshTokenGrantService + + +__all__ = ['AuthorizationCodeTokenService', 'RefreshTokenGrantService'] diff --git a/src/business_logic/get_tokens/service_impls/auth_code.py b/src/business_logic/get_tokens/service_impls/auth_code.py new file mode 100644 index 00000000..af91e664 --- /dev/null +++ b/src/business_logic/get_tokens/service_impls/auth_code.py @@ -0,0 +1,102 @@ +from __future__ import annotations +import time +import uuid +from src.business_logic.get_tokens.dto import RequestTokenModel, ResponseTokenModel +from src.business_logic.jwt_manager.dto import AccessTokenPayload, RefreshTokenPayload, IdTokenPayload +from src.dyna_config import DOMAIN_NAME +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.business_logic.common.interfaces import ValidatorProtocol + from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + from src.data_access.postgresql.repositories import PersistentGrantRepository + + +class AuthorizationCodeTokenService: + """ + Service for switching authorization code that was generated by + authorization endpoint in auth code flow and sent to the client + as a response. + """ + def __init__( + self, + grant_validator: ValidatorProtocol, + redirect_uri_validator: ValidatorProtocol, + client_validator: ValidatorProtocol, + code_validator: ValidatorProtocol, + grant_exp_validator: ValidatorProtocol, + jwt_manager: JWTManagerProtocol, + persistent_grant_repo: PersistentGrantRepository + ): + self._grant_validator = grant_validator + self._redirect_uri_validator = redirect_uri_validator + self._client_validator = client_validator + self._code_validator = code_validator + self._grant_expiration_validator = grant_exp_validator + self._jwt_manager = jwt_manager + self._persistent_grant_repo = persistent_grant_repo + + async def get_tokens(self, request_data: RequestTokenModel) -> ResponseTokenModel: + await self._client_validator(request_data.client_id) + await self._code_validator(request_data.code, request_data.client_id, request_data.grant_type) + await self._grant_validator(request_data.code, request_data.grant_type) + await self._redirect_uri_validator(request_data.redirect_uri, request_data.client_id) + + grant = await self._persistent_grant_repo.get_grant(grant_type=request_data.grant_type, grant_data=request_data.code) + print(grant) + user_id = grant.user_id + current_unix_time = int(time.time()) + + access_token = await self._get_access_token(request_data=request_data, user_id=user_id, unix_time=current_unix_time) + refresh_token = await self._get_refresh_token(request_data=request_data) + id_token = await self._get_id_token(request_data=request_data, user_id=user_id, unix_time=current_unix_time) + + + await self._persistent_grant_repo.delete_grant(grant=grant) + await self._persistent_grant_repo.create_grant( + client_id=grant.client_id, + grant_data=refresh_token, + user_id=user_id, + grant_type_id=2, + expiration_time=84700 + ) + + return ResponseTokenModel( + access_token=access_token, + refresh_token=refresh_token, + id_token=id_token, + token_type='Bearer', + expires_in=600, + refresh_expires_in=1800 + ) + + async def _get_access_token(self, request_data: RequestTokenModel, user_id: str, unix_time: int) -> str: + payload = AccessTokenPayload( + sub=user_id, + iss=DOMAIN_NAME, + client_id=request_data.client_id, + iat=unix_time, + exp=unix_time + 600, + aud=request_data.client_id, + jti=str(uuid.uuid4()), + acr=0, + ) + return self._jwt_manager.encode(payload=payload, algorithm='RS256') + + async def _get_refresh_token(self, request_data: RequestTokenModel) -> str: + payload = RefreshTokenPayload( + jti=str(uuid.uuid4()) + ) + return self._jwt_manager.encode(payload=payload, algorithm='RS256') + + async def _get_id_token(self, request_data: RequestTokenModel, user_id: str, unix_time: int) -> str: + payload = IdTokenPayload( + sub=user_id, + iss=DOMAIN_NAME, + client_id=request_data.client_id, + iat=unix_time, + exp=unix_time + 600, + aud=request_data.client_id, + jti=str(uuid.uuid4()), + acr=0, + ) + return self._jwt_manager.encode(payload=payload, algorithm='RS256') diff --git a/src/business_logic/get_tokens/service_impls/client_credentials.py b/src/business_logic/get_tokens/service_impls/client_credentials.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/get_tokens/service_impls/device_code.py b/src/business_logic/get_tokens/service_impls/device_code.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/get_tokens/service_impls/refresh_token.py b/src/business_logic/get_tokens/service_impls/refresh_token.py new file mode 100644 index 00000000..2dac4bf3 --- /dev/null +++ b/src/business_logic/get_tokens/service_impls/refresh_token.py @@ -0,0 +1,37 @@ +from __future__ import annotations +import time +import uuid +from src.business_logic.get_tokens.dto import RequestTokenModel, ResponseTokenModel +from src.business_logic.jwt_manager.dto import AccessTokenPayload, RefreshTokenPayload, IdTokenPayload +from src.dyna_config import DOMAIN_NAME +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.business_logic.common.interfaces import ValidatorProtocol + from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + from src.data_access.postgresql.repositories import PersistentGrantRepository + + +class RefreshTokenGrantService: + def __init__( + self, + grant_validator: ValidatorProtocol, + redirect_uri_validator: ValidatorProtocol, + client_validator: ValidatorProtocol, + refresh_token_validator: ValidatorProtocol, + grant_exp_validator: ValidatorProtocol, + jwt_manager: JWTManagerProtocol, + persistent_grant_repo: PersistentGrantRepository + ): + self._grant_validator = grant_validator + self._redirect_uri_validator = redirect_uri_validator + self._client_validator = client_validator + self._code_validator = refresh_token_validator + self._grant_expiration_validator = grant_exp_validator + self._jwt_manager = jwt_manager + self._persistent_grant_repo = persistent_grant_repo + + async def get_tokens(self, request_data: RequestTokenModel) -> ResponseTokenModel: + await self._client_validator(request_data.client_id) + await self._code_validator(request_data.refresh_token, request_data.client_id, request_data.grant_type) + await self._grant_validator(request_data.refresh_token, request_data.grant_type) + await self._redirect_uri_validator(request_data.redirect_uri, request_data.client_id) diff --git a/src/business_logic/get_tokens/validators/__init__.py b/src/business_logic/get_tokens/validators/__init__.py new file mode 100644 index 00000000..e144df7c --- /dev/null +++ b/src/business_logic/get_tokens/validators/__init__.py @@ -0,0 +1,7 @@ +from .validate_persistent_grant import ValidatePersistentGrant +from .validate_redirect_uri import ValidateRedirectUri +from .validate_grant_and_client import ValidateGrantByClient +from .validate_grant_expiration import ValidateGrantExpired + + +__all__ = ['ValidatePersistentGrant', 'ValidateRedirectUri', 'ValidateGrantByClient', 'ValidateGrantExpired'] diff --git a/src/business_logic/get_tokens/validators/validate_grant_and_client.py b/src/business_logic/get_tokens/validators/validate_grant_and_client.py new file mode 100644 index 00000000..a875c200 --- /dev/null +++ b/src/business_logic/get_tokens/validators/validate_grant_and_client.py @@ -0,0 +1,20 @@ +from src.business_logic.get_tokens.errors import InvalidGrantError + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.data_access.postgresql.repositories import PersistentGrantRepository + + +class ValidateGrantByClient: + """ + Checks that authorization code was issued to the authenticated Client. + """ + def __init__( + self, + persistent_grant_repo: 'PersistentGrantRepository' + ): + self._persistent_grant_repo = persistent_grant_repo + + async def __call__(self, authorization_code: str, client_id: str, grant_type: str) -> None: + if not await self._persistent_grant_repo.exists_grant_for_client(authorization_code, client_id, grant_type): + raise InvalidGrantError('Invalid data provided.') diff --git a/src/business_logic/get_tokens/validators/validate_grant_expiration.py b/src/business_logic/get_tokens/validators/validate_grant_expiration.py new file mode 100644 index 00000000..28014de2 --- /dev/null +++ b/src/business_logic/get_tokens/validators/validate_grant_expiration.py @@ -0,0 +1,6 @@ +from src.business_logic.get_tokens.errors import InvalidGrantError + + +class ValidateGrantExpired: + def __call__(self, grant_created_datetime: str, grant_expiration: int) -> None: + pass diff --git a/src/business_logic/get_tokens/validators/validate_persistent_grant.py b/src/business_logic/get_tokens/validators/validate_persistent_grant.py new file mode 100644 index 00000000..aaec560b --- /dev/null +++ b/src/business_logic/get_tokens/validators/validate_persistent_grant.py @@ -0,0 +1,17 @@ +from src.business_logic.get_tokens.errors import InvalidGrantError + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.data_access.postgresql.repositories import PersistentGrantRepository + + +class ValidatePersistentGrant: + def __init__( + self, + persistent_grant_repo: 'PersistentGrantRepository' + ): + self._persistent_grant_repo = persistent_grant_repo + + async def __call__(self, code_to_validate: str, grant_type: str) -> None: + if not await self._persistent_grant_repo.exists(code_to_validate, grant_type): + raise InvalidGrantError('Invalid grant value or grant type.') diff --git a/src/business_logic/get_tokens/validators/validate_redirect_uri.py b/src/business_logic/get_tokens/validators/validate_redirect_uri.py new file mode 100644 index 00000000..4bf0307a --- /dev/null +++ b/src/business_logic/get_tokens/validators/validate_redirect_uri.py @@ -0,0 +1,18 @@ +from src.business_logic.get_tokens.errors import InvalidRedirectUriError + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from src.data_access.postgresql.repositories import ClientRepository + + +class ValidateRedirectUri: + def __init__( + self, + client_repo: 'ClientRepository' + ): + self._client_repo = client_repo + + async def __call__(self, redirect_uri: str, client_id: str) -> None: + uris_list = await self._client_repo.list_all_redirect_uris_by_client(client_id=client_id) + if redirect_uri not in uris_list: + raise InvalidRedirectUriError diff --git a/src/business_logic/jwt_manager/__init__.py b/src/business_logic/jwt_manager/__init__.py new file mode 100644 index 00000000..e72b7b67 --- /dev/null +++ b/src/business_logic/jwt_manager/__init__.py @@ -0,0 +1,4 @@ +from .service_impls.jwt_service import JWTManager + + +__all__ = ['JWTManager'] diff --git a/src/business_logic/jwt_manager/dto/__init__.py b/src/business_logic/jwt_manager/dto/__init__.py new file mode 100644 index 00000000..c937fe95 --- /dev/null +++ b/src/business_logic/jwt_manager/dto/__init__.py @@ -0,0 +1,8 @@ +from .input import ( + AccessTokenPayload, + RefreshTokenPayload, + IdTokenPayload, +) + + +__all__ = ['AccessTokenPayload', 'RefreshTokenPayload', 'IdTokenPayload'] diff --git a/src/business_logic/jwt_manager/dto/input.py b/src/business_logic/jwt_manager/dto/input.py new file mode 100644 index 00000000..5f048939 --- /dev/null +++ b/src/business_logic/jwt_manager/dto/input.py @@ -0,0 +1,33 @@ +from pydantic import BaseModel +from typing import Optional, Any + + +class BaseJWTPayload(BaseModel): + sub: str # user id + iss: str # auth service uri + iat: int # time of creation + exp: int # time when token will expire + aud: str # name for whom token was generated + client_id: str # id of the client who issued a token + jti: str # uniques identifier for token, UUID4 + acr: Optional[int] # default 0 + + +class AccessTokenPayload(BaseJWTPayload): + typ: str = 'Bearer' + + +class RefreshTokenPayload(BaseModel): + jti: str + + +class IdTokenPayload(BaseJWTPayload): + typ: str = 'ID' + email: Optional[str] = None + email_verified: Optional[str] = None + given_name: Optional[str] = None + last_name: Optional[str] = None + preferred_username: Optional[str] = None + picture: Optional[str] = None + zoneinfo: Optional[str] = None + locale: Optional[str] = None diff --git a/src/business_logic/jwt_manager/dto/output.py b/src/business_logic/jwt_manager/dto/output.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/jwt_manager/interfaces.py b/src/business_logic/jwt_manager/interfaces.py new file mode 100644 index 00000000..f1848b47 --- /dev/null +++ b/src/business_logic/jwt_manager/interfaces.py @@ -0,0 +1,17 @@ +from src.business_logic.jwt_manager.dto import ( + AccessTokenPayload, + RefreshTokenPayload, + IdTokenPayload, +) +from typing import Protocol, Any, Union + + +Payload = Union[AccessTokenPayload, RefreshTokenPayload, IdTokenPayload] + + +class JWTManagerProtocol(Protocol): + def encode(self, payload: Payload, algorithm: str) -> str: + raise NotImplementedError + + def decode(self, token: str, audience: str,**kwargs: Any) -> dict[str, Any]: + raise NotImplementedError diff --git a/src/business_logic/jwt_manager/service_impls/__init__.py b/src/business_logic/jwt_manager/service_impls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/business_logic/jwt_manager/service_impls/jwt_service.py b/src/business_logic/jwt_manager/service_impls/jwt_service.py new file mode 100644 index 00000000..1fb6ca39 --- /dev/null +++ b/src/business_logic/jwt_manager/service_impls/jwt_service.py @@ -0,0 +1,44 @@ +from __future__ import annotations +import logging +import jwt +from src.config.rsa_keys import RSAKeypair +from src.di import Container +from src.business_logic.jwt_manager.dto import ( + AccessTokenPayload, + RefreshTokenPayload, + IdTokenPayload, +) +from typing import Any, Optional, Union + + +logger = logging.getLogger(__name__) + + +Payload = Union[AccessTokenPayload, RefreshTokenPayload, IdTokenPayload] + + +class JWTManager: + def __init__(self, keys: RSAKeypair = Container().config().keys) -> None: + self.keys = keys + + def encode(self, payload: Payload, algorithm: str, secret: Optional[str] = None) -> str: + if secret: + key = secret + else: + key = self.keys.private_key + + token = jwt.encode( + payload=payload.dict(exclude_none=True), key=key, algorithm=algorithm + ) + return token + + def decode(self, token: str, audience: Optional[str] = None, **kwargs: Any) -> dict[str, Any]: + token = token.replace("Bearer ", "") + if audience: + decoded_info = jwt.decode(token, key=self.keys.public_key, algorithms=self.algorithms, + audience=audience, **kwargs,) + else: + decoded_info = jwt.decode(token, key=self.keys.public_key, algorithms=self.algorithms, + **kwargs,) + + return decoded_info diff --git a/src/di/providers/jwt_manager.py b/src/di/providers/jwt_manager.py new file mode 100644 index 00000000..fb0c4951 --- /dev/null +++ b/src/di/providers/jwt_manager.py @@ -0,0 +1,10 @@ +from src.business_logic.jwt_manager import JWTManager +from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + + +def provide_jwt_manager_stub() -> None: + ... + + +def provide_jwt_manager() -> JWTManagerProtocol: + return JWTManager() diff --git a/src/di/providers/services.py b/src/di/providers/services.py index 51d14675..f29e80fa 100644 --- a/src/di/providers/services.py +++ b/src/di/providers/services.py @@ -13,7 +13,7 @@ ThirdPartyGitLabService, ThirdPartyMicrosoftService, DeviceService, - EndSessionService, + # EndSessionService, IntrospectionServies, JWTService, LoginFormService, @@ -26,6 +26,7 @@ UserInfoServices, WellKnownServices, ) +from src.business_logic.endsession.endsession_service import EndSessionService from src.data_access.postgresql.repositories import ( ClientRepository, DeviceRepository, diff --git a/src/di/providers/services_factory.py b/src/di/providers/services_factory.py new file mode 100644 index 00000000..a1793111 --- /dev/null +++ b/src/di/providers/services_factory.py @@ -0,0 +1,31 @@ +from src.business_logic.get_tokens import TokenServiceFactory +from src.data_access.postgresql.repositories import ( + BlacklistedTokenRepository, + ClientRepository, + DeviceRepository, + PersistentGrantRepository, + UserRepository, +) +from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol + + +def provide_token_service_factory_stub() -> None: + ... + + +def provide_token_service_factory( + client_repo: ClientRepository, + persistent_grant_repo: PersistentGrantRepository, + user_repo: UserRepository, + device_repo: DeviceRepository, + jwt_service: JWTManagerProtocol, + blacklisted_repo: BlacklistedTokenRepository, +) -> TokenServiceFactory: + return TokenServiceFactory( + client_repo=client_repo, + persistent_grant_repo=persistent_grant_repo, + user_repo=user_repo, + device_repo=device_repo, + jwt_manager=jwt_service, + blacklisted_repo=blacklisted_repo + ) diff --git a/src/presentation/api/exception_handlers/__init__.py b/src/presentation/api/exception_handlers/__init__.py new file mode 100644 index 00000000..e2f53c69 --- /dev/null +++ b/src/presentation/api/exception_handlers/__init__.py @@ -0,0 +1,10 @@ +from .http400_invalid_grant import http400_invalid_grant_handler +from .http400_invalid_client import http400_invalid_client_handler +from .http400_unsupported_grant_type import http400_unsupported_grant_type_handler + + +__all__ = [ + 'http400_invalid_grant_handler', + 'http400_invalid_client_handler', + 'http400_unsupported_grant_type_handler' +] diff --git a/src/presentation/api/exception_handlers/http400_invalid_client.py b/src/presentation/api/exception_handlers/http400_invalid_client.py new file mode 100644 index 00000000..43f5741a --- /dev/null +++ b/src/presentation/api/exception_handlers/http400_invalid_client.py @@ -0,0 +1,17 @@ +from src.business_logic.common.errors import InvalidClientIdError +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_400_BAD_REQUEST + + +async def http400_invalid_client_handler( + _: Request, + exc: InvalidClientIdError +) -> JSONResponse: + headers = {"Cache-Control": "no-store", "Pragma": "no-cache"} + content = {"error": "invalid_client"} + return JSONResponse( + content=content, + headers=headers, + status_code=HTTP_400_BAD_REQUEST + ) diff --git a/src/presentation/api/exception_handlers/http400_invalid_grant.py b/src/presentation/api/exception_handlers/http400_invalid_grant.py new file mode 100644 index 00000000..f9faa5ea --- /dev/null +++ b/src/presentation/api/exception_handlers/http400_invalid_grant.py @@ -0,0 +1,21 @@ +from typing import Union +from src.business_logic.get_tokens.errors import InvalidGrantError, InvalidRedirectUriError +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_400_BAD_REQUEST + + +ExceptionsToHandle = Union[InvalidGrantError, InvalidRedirectUriError] + + +async def http400_invalid_grant_handler( + _: Request, + exc: ExceptionsToHandle +) -> JSONResponse: + headers = {"Cache-Control": "no-store", "Pragma": "no-cache"} + content = {"error": "invalid_grant"} + return JSONResponse( + content=content, + headers=headers, + status_code=HTTP_400_BAD_REQUEST + ) diff --git a/src/presentation/api/exception_handlers/http400_unsupported_grant_type.py b/src/presentation/api/exception_handlers/http400_unsupported_grant_type.py new file mode 100644 index 00000000..d5d327ab --- /dev/null +++ b/src/presentation/api/exception_handlers/http400_unsupported_grant_type.py @@ -0,0 +1,17 @@ +from src.business_logic.get_tokens.errors import UnsupportedGrantTypeError +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_400_BAD_REQUEST + + +async def http400_unsupported_grant_type_handler( + _: Request, + exc: UnsupportedGrantTypeError +) -> JSONResponse: + headers = {"Cache-Control": "no-store", "Pragma": "no-cache"} + content = {"error": "unsupported_grant_type"} + return JSONResponse( + content=content, + headers=headers, + status_code=HTTP_400_BAD_REQUEST + ) diff --git a/src/presentation/api/routes/endsession.py b/src/presentation/api/routes/endsession.py index 71a5fd74..8d0c9c9a 100644 --- a/src/presentation/api/routes/endsession.py +++ b/src/presentation/api/routes/endsession.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse, RedirectResponse -from src.business_logic.services.endsession import EndSessionService +from src.business_logic.endsession.interfaces import EndSessionServiceProtocol from src.data_access.postgresql.errors.client import ( ClientPostLogoutRedirectUriError, ) @@ -24,14 +24,12 @@ @endsession_router.get("/", status_code=status.HTTP_204_NO_CONTENT) async def end_session( request_model: RequestEndSessionModel = Depends(), - service_class: EndSessionService = Depends( + service_class: EndSessionServiceProtocol = Depends( provide_endsession_service_stub ), ) -> Union[int, RedirectResponse, JSONResponse]: try: - service_class = service_class - service_class.request_model = request_model - logout_redirect_uri = await service_class.end_session() + logout_redirect_uri = await service_class.end_session(request_model) if logout_redirect_uri is None: return status.HTTP_204_NO_CONTENT diff --git a/tests/conftest.py b/tests/conftest.py index 5360388d..6125f485 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,8 @@ import asyncio from src.main import get_application from src.business_logic.services.authorization import AuthorizationService -from src.business_logic.services.endsession import EndSessionService +# from src.business_logic.services.endsession import EndSessionService +from src.business_logic.endsession.endsession_service import EndSessionService from src.business_logic.services.userinfo import UserInfoServices from src.business_logic.services import DeviceService, WellKnownServices diff --git a/tests/test_api/test_endsession_endpoint.py b/tests/test_api/test_endsession_endpoint.py index d6fd94c5..ebeb91de 100644 --- a/tests/test_api/test_endsession_endpoint.py +++ b/tests/test_api/test_endsession_endpoint.py @@ -1,12 +1,13 @@ import jwt import pytest -from fastapi import status +from fastapi import status, FastAPI from httpx import AsyncClient from sqlalchemy import delete, insert from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.orm import sessionmaker +import src.di.providers as prov from src.data_access.postgresql.tables.persistent_grant import PersistentGrant from tests.test_unit.fixtures import ( end_session_request_model, @@ -14,7 +15,7 @@ TokenHint, ) from src.presentation.api.models.endsession import RequestEndSessionModel -from src.business_logic.services.endsession import EndSessionService +from src.business_logic.endsession.endsession_service import EndSessionService from tests.test_unit.fixtures import ( end_session_request_model, TOKEN_HINT_DATA, @@ -28,11 +29,15 @@ async def test_successful_authorize_request( self, engine: AsyncEngine, client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, + ) -> None: - service = end_session_service - service.request_model = end_session_request_model + + service = prov.provide_endsession_service( + client_repo=prov.provide_client_repo(engine), + persistent_grant_repo=prov.provide_persistent_grant_repo(engine), + jwt_service=prov.provide_jwt_service(), + ) + secret = await service.client_repo.get_client_secrete_by_client_id( client_id=TOKEN_HINT_DATA["client_id"] ) @@ -71,14 +76,10 @@ async def test_successful_authorize_request( await session.commit() async def test_successful_authorize_request_without_uri( - self, - engine: AsyncEngine, - client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, + self, + engine: AsyncEngine, + client: AsyncClient, ) -> None: - service = end_session_service - service.request_model = end_session_request_model session_factory = sessionmaker( engine, expire_on_commit=False, class_=AsyncSession ) @@ -105,14 +106,10 @@ async def test_successful_authorize_request_without_uri( assert response.status_code == status.HTTP_204_NO_CONTENT async def test_successful_authorize_request_wrong_uri( - self, - engine: AsyncEngine, - client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, + self, + engine: AsyncEngine, + client: AsyncClient, ) -> None: - service = end_session_service - service.request_model = end_session_request_model session_factory = sessionmaker( engine, expire_on_commit=False, class_=AsyncSession ) @@ -145,11 +142,7 @@ async def test_successful_authorize_request_wrong_uri( async def test_end_session_bad_token( self, client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, ) -> None: - service = end_session_service - service.request_model = end_session_request_model expected_content = '{"message":"Bad id_token_hint"}' params = { "id_token_hint": "id_token_hint", @@ -160,16 +153,11 @@ async def test_end_session_bad_token( assert response.content.decode("UTF-8") == expected_content async def test_end_session_not_full_token( - self, - client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, + self, + client: AsyncClient ) -> None: hint = TokenHint() short_token_hint = await hint.get_short_token_hint() - service = end_session_service - service.request_model = end_session_request_model - service.request_model.id_token_hint = short_token_hint expected_content = ( '{"message":"The id_token_hint is missing something"}' @@ -185,13 +173,9 @@ async def test_end_session_not_full_token( async def test_end_session_no_persistent_grant( self, client: AsyncClient, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, ) -> None: hint = TokenHint() token_hint = await hint.get_token_hint() - service = end_session_service - service.request_model = end_session_request_model expected_content = '{"message":"You are not logged in"}' params = { "id_token_hint": token_hint, diff --git a/tests/test_unit/test_services/test_end_session_service.py b/tests/test_unit/test_services/test_end_session_service.py index 23055b20..4a379380 100644 --- a/tests/test_unit/test_services/test_end_session_service.py +++ b/tests/test_unit/test_services/test_end_session_service.py @@ -1,52 +1,25 @@ import jwt import pytest +from unittest.mock import AsyncMock, MagicMock from sqlalchemy import insert, select, delete +from tests.test_unit.fixtures import end_session_request_model, TOKEN_HINT_DATA +from src.business_logic.endsession.errors import InvalidLogoutRedirectUriError +from src.business_logic.endsession.validators import ValidateLogoutRedirectUri from src.data_access.postgresql.errors import ClientPostLogoutRedirectUriError from src.data_access.postgresql.tables.persistent_grant import PersistentGrant from src.data_access.postgresql.errors.persistent_grant import ( PersistentGrantNotFoundError, ) -from tests.test_unit.fixtures import end_session_request_model, TOKEN_HINT_DATA -from src.business_logic.services.endsession import EndSessionService +from src.business_logic.endsession.endsession_service import EndSessionService + from src.presentation.api.models.endsession import RequestEndSessionModel from sqlalchemy.ext.asyncio.engine import AsyncEngine @pytest.mark.asyncio class TestEndSessionService: - async def test_validate_logout_redirect_uri( - self, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, - ) -> None: - service = end_session_service - service.request_model = end_session_request_model - if not service.request_model.post_logout_redirect_uri: - raise AssertionError - result = await service._validate_logout_redirect_uri( - client_id="test_client", - logout_redirect_uri=service.request_model.post_logout_redirect_uri, - ) - - assert result is True - - async def test_validate_logout_redirect_uri_error( - self, - end_session_service: EndSessionService, - end_session_request_model: RequestEndSessionModel, - ) -> None: - service = end_session_service - service.request_model = end_session_request_model - service.request_model.post_logout_redirect_uri = "not_exist_uri" - with pytest.raises(ClientPostLogoutRedirectUriError): - await service._validate_logout_redirect_uri( - client_id="client_not_exist", - logout_redirect_uri=service.request_model.post_logout_redirect_uri, - ) - a = 1 - async def test_logout( self, end_session_service: EndSessionService, @@ -114,8 +87,7 @@ async def test_end_session( connection: AsyncEngine, ) -> None: service = end_session_service - service.request_model = end_session_request_model - service.request_model.post_logout_redirect_uri = "https://www.cole.com/" + end_session_request_model.post_logout_redirect_uri = "https://www.cole.com/" await connection.execute( insert(PersistentGrant).values( key="test_key", @@ -127,7 +99,7 @@ async def test_end_session( ) ) await connection.commit() - redirect_uri = await service.end_session() + redirect_uri = await service.end_session(end_session_request_model) assert redirect_uri == "https://www.cole.com/&state=test_state" await connection.execute( @@ -142,9 +114,8 @@ async def test_end_session_without_state( connection: AsyncEngine, ) -> None: service = end_session_service - service.request_model = end_session_request_model - service.request_model.post_logout_redirect_uri = "https://www.cole.com/" - service.request_model.state = None + end_session_request_model.post_logout_redirect_uri = "https://www.cole.com/" + end_session_request_model.state = None await connection.execute( insert(PersistentGrant).values( key="test_key", @@ -156,7 +127,7 @@ async def test_end_session_without_state( ) ) await connection.commit() - redirect_uri = await service.end_session() + redirect_uri = await service.end_session(end_session_request_model) assert redirect_uri == "https://www.cole.com/" await connection.execute( @@ -171,7 +142,7 @@ async def test_end_session_wrong_uri( connection: AsyncEngine, ) -> None: service = end_session_service - service.request_model = end_session_request_model + # service.request_model = end_session_request_model await connection.execute( insert(PersistentGrant).values( key="test_key", @@ -185,9 +156,45 @@ async def test_end_session_wrong_uri( await connection.commit() with pytest.raises(ClientPostLogoutRedirectUriError): - await service.end_session() + await service.end_session(end_session_request_model) await connection.execute( delete(PersistentGrant).where(PersistentGrant.client_id == 3) ) await connection.commit() + +@pytest.mark.asyncio +class TestEndSessionServiceValidators: + async def test_validate_logout_redirect_uri_valid_uri(self): + request_model = RequestEndSessionModel( + id_token_hint="some_id_token_hint", + post_logout_redirect_uri="https://valid-logout-uri.com", + ) + + client_repo_mock = AsyncMock() + client_repo_mock.validate_post_logout_redirect_uri.return_value = True + + validator = ValidateLogoutRedirectUri(client_repo=client_repo_mock) + await validator(request_model=request_model, client_id="client_id") + + client_repo_mock.validate_post_logout_redirect_uri.assert_called_once_with( + "client_id", "https://valid-logout-uri.com" + ) + + async def test_validate_logout_redirect_uri_invalid_uri(self): + + client_repo_mock = AsyncMock() + client_repo_mock.validate_post_logout_redirect_uri.return_value = False + validator = ValidateLogoutRedirectUri(client_repo=client_repo_mock) + + request_model = RequestEndSessionModel( + id_token_hint="some_id_token_hint", + post_logout_redirect_uri="invalid_uri", + state="some_state" + ) + + with pytest.raises(InvalidLogoutRedirectUriError, match="Invalid post logout redirect uri"): + await validator(request_model=request_model, client_id="client_id") + + +