diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index f701291..ca7fdd5 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -34,7 +34,7 @@ jobs: run: uv run mypy - name: Run Pytest with Coverage - run: uv run pytest --cov + run: uv run pytest --cov -m "not e2e" - name: Creating coverage folder run: mkdir -p coverage diff --git a/pyproject.toml b/pyproject.toml index 274456c..6fe8f1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "authapigateway" -version = "0.0.1" -description = "Cookie-based JWT Auth API" +name = "Auth-API" +version = "0.0.3" +description = "RBAC JWT Cookies-based API" readme = "README.md" requires-python = ">=3.12" dependencies = [ @@ -59,7 +59,7 @@ lint.ignore = [ [tool.ruff.lint.per-file-ignores] "src/tests/*" = ["D103", "S101"] "src/app/user/exceptions.py" = ["D101"] -"src/alembic/versions/*" = ["D103", "D400", "D415", "INP001"] +"src/alembic/*" = ["D103", "D400", "D415", "INP001", "I001"] [tool.mypy] python_version = "3.12" diff --git a/src/alembic/env.py b/src/alembic/env.py index eb5a3ea..cdf07e0 100644 --- a/src/alembic/env.py +++ b/src/alembic/env.py @@ -1,11 +1,11 @@ -import asyncio # noqa: INP001 +import asyncio from logging.config import fileConfig +from alembic import context from sqlalchemy.engine.base import Connection from sqlalchemy.ext.asyncio import async_engine_from_config from sqlalchemy.pool import NullPool -from alembic import context from app.main.settings import settings from app.user.models import User diff --git a/src/app/admin/views.py b/src/app/admin/views.py index abf3bb3..9efcb81 100644 --- a/src/app/admin/views.py +++ b/src/app/admin/views.py @@ -4,9 +4,9 @@ from sqladmin import ModelView from wtforms.fields import EmailField, Field, PasswordField, SelectField -from app.user.auth import generate_password_hash from app.user.constants import UserRole from app.user.models import User +from app.user.security.passwords import generate_password_hash if TYPE_CHECKING: from sqladmin._types import MODEL_ATTR diff --git a/src/app/main/app.py b/src/app/main/app.py index 27c4691..a30738a 100644 --- a/src/app/main/app.py +++ b/src/app/main/app.py @@ -2,21 +2,16 @@ from contextlib import asynccontextmanager from fastapi import FastAPI -from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from sqlalchemy.exc import IntegrityError +from redis.asyncio import Redis from app.admin.app import init_admin_app -from app.user.errors import ( - db_integrity_error_handler, - unexpected_error_handler, - validation_error_handler, -) +from app.user.errors import error_handlers as user_error_handlers from app.user.routes import router as user_router from .db import engine +from .errors import error_handlers as main_error_handlers from .logger import configure_logging -from .redis import redis from .routes import router as default_router from .settings import settings @@ -26,39 +21,44 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: """Init and release objects on app startup and shutdown. On startup: - - init Redis client with connection pool; - - init Admin app with the specified database engine; + - add logger object to the application state; + - add Redis client object to the application state; + - init Admin app with the provided database engine; On shutdown: - - dispose Redis client with connection pool; + - dispose Redis client including its connection pool; - dispose database engine. - NOTE: during testing state's objects are redeclared to adapt the + NOTE: during testing state objects are redeclared to adapt the application to the testing environment. """ - redis_client = getattr(app.state, "redis", redis) - database_engine = getattr(app.state, "engine", engine) + db = getattr(app.state, "engine", engine) + logger = getattr( + app.state, + "logger", + configure_logging(settings.log_name, options=settings.logging_kwargs), + ) + redis = getattr(app.state, "redis", Redis(**settings.redis_kwargs)) - init_admin_app(app, database_engine) + # Init admin app based on the provided database engine + init_admin_app(app, db) - # set application state - app.state.logger = configure_logging() - app.state.redis = redis_client + # Init state objects of the app + app.state.logger = logger + app.state.redis = redis yield - await redis_client.close() - await redis_client.connection_pool.disconnect() - await database_engine.dispose() + await redis.aclose() + await db.dispose() app = FastAPI( **settings.fastapi_kwargs, lifespan=lifespan, exception_handlers={ - IntegrityError: db_integrity_error_handler, - RequestValidationError: validation_error_handler, - Exception: unexpected_error_handler, + **user_error_handlers, + **main_error_handlers, }, ) app.add_middleware(CORSMiddleware, **settings.cors_kwargs) diff --git a/src/app/main/dependencies.py b/src/app/main/dependencies.py index 7dbaba4..79c858c 100644 --- a/src/app/main/dependencies.py +++ b/src/app/main/dependencies.py @@ -20,13 +20,14 @@ async def get_db() -> AsyncIterator["AsyncSession"]: async def get_redis(request: Request) -> "Redis": - """Return initialized Redis client to be used as a dependency. + """Return initialized in the lifespan client for Redis. :param request: request object providing access to the app state - :return: Redis client + :return: client for Redis """ return cast("Redis", request.app.state.redis) +# https://fastapi.tiangolo.com/tutorial/dependencies/#share-annotated-dependencies DbSession: TypeAlias = Annotated["AsyncSession", Depends(get_db)] RedisT: TypeAlias = Annotated["Redis", Depends(get_redis)] diff --git a/src/app/main/errors.py b/src/app/main/errors.py new file mode 100644 index 0000000..07348b6 --- /dev/null +++ b/src/app/main/errors.py @@ -0,0 +1,42 @@ +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any, cast + +from fastapi import Request, Response, status +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse + +if TYPE_CHECKING: + from logging import Logger + + +async def validation_error_handler( + _: Request, + e: RequestValidationError, +) -> JSONResponse: + """Pydantic validation error handler.""" + error = e.errors()[0] + field = error["loc"][1] + error_message = f"{error['msg']}. Field: {field}" + client_message = {"detail": error_message} + return JSONResponse(client_message, status.HTTP_422_UNPROCESSABLE_ENTITY) + + +async def unexpected_error_handler( + request: Request, + e: Exception, +) -> JSONResponse: + """Error handler for all uncaught exceptions.""" + logger = cast("Logger", request.app.state.logger) + logger.critical("Internal Server Error: %s", e) + client_message = {"error": "Service is temporarily unavailable"} + return JSONResponse(client_message, status.HTTP_500_INTERNAL_SERVER_ERROR) + + +# error handlers mapping to be registered in the main FastAPI app +error_handlers: dict[ + int | type[Exception], + Callable[[Request, Any], Coroutine[Any, Any, Response]], +] = { + RequestValidationError: validation_error_handler, + Exception: unexpected_error_handler, +} diff --git a/src/app/main/exceptions.py b/src/app/main/exceptions.py index 9b3d213..86facec 100644 --- a/src/app/main/exceptions.py +++ b/src/app/main/exceptions.py @@ -1,7 +1,5 @@ from fastapi import HTTPException, status -from app.main.settings import settings - class BaseHTTPException(HTTPException): """Base HTTP exception class.""" @@ -19,5 +17,5 @@ def __init__( super().__init__( status_code or self.status_code, detail or self.detail, - headers or {"WWW-Authenticate": settings.token_scheme}, + headers or {"WWW-Authenticate": "Bearer"}, ) diff --git a/src/app/main/logger.py b/src/app/main/logger.py index 9c51d94..016a38a 100644 --- a/src/app/main/logger.py +++ b/src/app/main/logger.py @@ -1,13 +1,14 @@ import logging from typing import TYPE_CHECKING -from .settings import settings - if TYPE_CHECKING: from logging import Logger + from .settings import LoggingKwargs + -def configure_logging() -> "Logger": +def configure_logging(name: str, options: "LoggingKwargs | None" = None) -> "Logger": """Configure app logging and return logger object.""" - logging.basicConfig(**settings.logging_kwargs) - return logging.getLogger(settings.log_name) + if options is not None: + logging.basicConfig(**options) + return logging.getLogger(name) diff --git a/src/app/main/redis.py b/src/app/main/redis.py deleted file mode 100644 index 3283d12..0000000 --- a/src/app/main/redis.py +++ /dev/null @@ -1,9 +0,0 @@ -from redis.asyncio import ConnectionPool, Redis - -from .settings import settings - -redis_connection_pool = ConnectionPool.from_url( - settings.redis_url, - decode_responses=True, -) -redis = Redis(connection_pool=redis_connection_pool) diff --git a/src/app/main/settings.py b/src/app/main/settings.py index 127c1da..f1f4b61 100644 --- a/src/app/main/settings.py +++ b/src/app/main/settings.py @@ -1,9 +1,14 @@ +import json import logging -from functools import lru_cache +from functools import cached_property, lru_cache from pathlib import Path -from typing import TypedDict +from typing import TypedDict, cast from pydantic_settings import BaseSettings +from redis.asyncio import ConnectionPool + +# type to describe content of the RBAC policy file +type RBACPolicyT = dict[str, dict[str, list[str]]] class FastAPIKwargs(TypedDict): @@ -42,30 +47,37 @@ class CORSMiddlewareKwargs(TypedDict): max_age: int +class RedisKwargs(TypedDict): + """Kwargs for Redis client.""" + + connection_pool: ConnectionPool + + class Settings(BaseSettings): """Main project settings.""" # FastAPI settings - title: str = "JWT Auth API Gateway" - description: str = "RBAC JWT Cookies-based API Gateway" - version: str = "0.0.1" - debug: bool = True + title: str = "Auth API" + description: str = "RBAC JWT Cookies-based API" + version: str = "0.0.3" + debug: bool = False docs_url: str = "/" # Admin app setting admin_base_url: str = "/admin" admin_title: str = "Auth API Admin" - admin_debug: bool = True + admin_debug: bool = debug # Logging settings log_name: str = "app" log_level: int = logging.INFO log_format: str = "%(levelname)s - %(name)s - %(asctime)s - %(message)s" log_datefmt: str = "%Y-%m-%d %H:%M:%S" - # name of the logger used in testing + # name of the logger used when running tests test_log_name: str = "test" host_server_domain: str = "localhost" + test_client_base_url: str = f"http://{host_server_domain}" # CORS settings cors_max_age: int = 600 # seconds @@ -83,9 +95,7 @@ class Settings(BaseSettings): # JWT settings jwt_algorithm: str = "HS256" - jwt_secret_key: str = "supersecret12345" # noqa: S105 - - # JWT Tokens settings + jwt_secret_key: str = "" access_token_expiration_time: int = 5 # minutes access_token_cookie_expiration_time: int = ( access_token_expiration_time * 60 @@ -94,16 +104,23 @@ class Settings(BaseSettings): refresh_token_cookie_expiration_time: int = ( refresh_token_expiration_time * 60 ) # seconds - token_scheme: str = "Bearer" # noqa: S105 - # size of the cache to store payload of the corresponding tokens + # size of the cache to store payload of JWTs token_payload_max_cache_hits: int = 10_000 + # specifies value of JWTs in Redis to avoid their reuse + jwt_blacklist_name: str = "blacklist" # Passwords settings max_password_length: int = 50 min_password_length: int = 5 # RBAC settings - user_policy_file: Path = Path(__file__).parents[2] / "policy.json" + rbac_policy_fp: Path = Path(__file__).parents[2] / "policy.json" + + @cached_property + def rbac_policy(self) -> RBACPolicyT: + """Read and return content of the policy.json file.""" + with self.rbac_policy_fp.open() as f: + return cast("RBACPolicyT", json.load(f)) @property def fastapi_kwargs(self) -> FastAPIKwargs: @@ -145,6 +162,12 @@ def cors_kwargs(self) -> CORSMiddlewareKwargs: max_age=self.cors_max_age, ) + @property + def redis_kwargs(self) -> RedisKwargs: + """Kwargs for Redis client.""" + connection_pool = ConnectionPool.from_url(self.redis_url, decode_responses=True) + return RedisKwargs(connection_pool=connection_pool) + @lru_cache def get_settings() -> Settings: diff --git a/src/app/user/auth.py b/src/app/user/auth.py deleted file mode 100644 index 6329982..0000000 --- a/src/app/user/auth.py +++ /dev/null @@ -1,96 +0,0 @@ -import json -import re -from functools import lru_cache -from typing import TYPE_CHECKING, cast - -from jose import JWTError, jwt -from passlib.context import CryptContext - -from app.main.settings import settings - -from .exceptions import InvalidCredentialsError -from .schemas import TokenPayload - -if TYPE_CHECKING: - from .constants import TokenType, UserRole - -type PolicyT = dict[str, dict[str, list[str]]] - -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - - -def generate_password_hash(plain_password: str) -> str: - """Generate hash of the plain password. - - :param plain_password: plain password - :return: hash of the password - """ - return pwd_context.hash(plain_password) - - -def verify_password(plain_password: str, hashed_password: str) -> bool: - """Verify plain and hashed password. - - :param plain_password: plain password - :param hashed_password: hashed password - :return: boolean result of the verification - """ - return pwd_context.verify(plain_password, hashed_password) - - -def generate_token(token_type: "TokenType", username: str, role: str) -> str: - """Generate a JWT token of the specific type and claims. - - :param token_type: type of the token (access or refresh) - :param username: username to generate the token for - :param role: role of the user (user, admin, etc.) - :return: JWT token - """ - payload = TokenPayload(sub=username, typ=token_type, role=role) - token = jwt.encode( - claims=payload.model_dump(), - key=settings.jwt_secret_key, - algorithm=settings.jwt_algorithm, - ) - return f"{settings.token_scheme} {token}" - - -@lru_cache(maxsize=settings.token_payload_max_cache_hits) -def get_token_payload(token: str) -> TokenPayload: - """Decode the JWT token and return its payload. - - :param token: JWT token - :raises InvalidCredentialsError: if the token can't be decoded - :return: payload of the token - """ - try: - payload = jwt.decode( - token, - key=settings.jwt_secret_key, - algorithms=[settings.jwt_algorithm], - ) - except JWTError as e: - raise InvalidCredentialsError from e - return TokenPayload.model_construct(**payload) - - -@lru_cache -def verify_access(role: "UserRole", url: str) -> bool: - """Verify if the user can access the URL according to one's role. - - :param role: role of the user - :param url: target URL to check access to - :return: boolean result of the verification - """ - policy = read_user_policy() - return any(re.match(pattern, url) for pattern in policy[role]["locations"]) - - -@lru_cache -def read_user_policy() -> PolicyT: - """Read user policy file and return its content as JSON. - - :return: user policy as JSON - """ - with settings.user_policy_file.open() as f: - return cast("PolicyT", json.load(f)) diff --git a/src/app/user/constants.py b/src/app/user/constants.py index a256985..aec5373 100644 --- a/src/app/user/constants.py +++ b/src/app/user/constants.py @@ -1,23 +1,16 @@ -from enum import StrEnum +from enum import StrEnum, auto -class TokenType(StrEnum): - """Types of JWT tokens.""" +class TT(StrEnum): + """TT stands for "Token Type" and specifies types of JWT.""" - access = "access" - refresh = "refresh" - - -class CookieType(StrEnum): - """Types of cookies.""" - - access_token = "access_token" # noqa: S105 - refresh_token = "refresh_token" # noqa: S105 + access_token = auto() + refresh_token = auto() class UserRole(StrEnum): """User roles of the app.""" - admin = "admin" - moderator = "moderator" - user = "user" + admin = auto() + moderator = auto() + user = auto() diff --git a/src/app/user/dependencies.py b/src/app/user/dependencies.py index 30235a6..d3d9e7b 100644 --- a/src/app/user/dependencies.py +++ b/src/app/user/dependencies.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, TypeAlias from fastapi import Depends, Request from fastapi.security import OAuth2PasswordRequestForm @@ -6,79 +6,107 @@ from app.main.dependencies import DbSession, RedisT from app.main.settings import settings -from .auth import get_token_payload, verify_password -from .exceptions import ExpiredTokenError, IncorrectPasswordError -from .middlewares import OAuth2PasswordBearerWithCookie -from .models import User +from .constants import TT +from .exceptions import ( + AuthenticationError, + ExpiredTokenError, + InactiveUserError, + IncorrectPasswordError, +) from .schemas import TokenPayload, UserInDB +from .security.passwords import verify_password +from .security.tokens import get_token_payload +from .services.users import UsersDbService -oauth2_scheme = OAuth2PasswordBearerWithCookie(tokenUrl="user/login") + +async def get_user_service(session: DbSession) -> UsersDbService: + """Return instance of the UserService class. + + :param session: database session object + :return: instance of UsersDbService class + """ + return UsersDbService(session) + + +UserServiceT: TypeAlias = Annotated[UsersDbService, Depends(get_user_service)] async def authenticate_user( form: Annotated[OAuth2PasswordRequestForm, Depends()], - db: DbSession, + service: UserServiceT, ) -> UserInDB: """Return info about user from the database verifying one's password. :param form: input info about user from the form - :param db: DB session - :raises IncorrectPasswordError: if the password isn't verified + :param service: instance of UsersDbService object + :raises InactiveUserError: unless the user is active + :raises IncorrectPasswordError: unless the password is verified :return: info about user from the database """ - user = await User.get(db, form.username) + user = await service.get(form.username) if not verify_password(form.password, user.password): raise IncorrectPasswordError + if not user.is_active: + raise InactiveUserError return user -async def decode_token(token: Annotated[str, Depends(oauth2_scheme)]) -> TokenPayload: - """Decode JWT token and return its payload. +UserT: TypeAlias = Annotated[UserInDB, Depends(authenticate_user)] - :param token: user's JWT token - :return: payload of the token + +async def decode_jwt(request: Request) -> TokenPayload: + """Decode JWT and return its payload. + + :param request: FastAPI's Request object + :raises AuthenticationError: if auth cookies are missing + :return: payload of the JWT """ - return get_token_payload(token) + if access_token := request.cookies.get(TT.access_token.name): + return get_token_payload(access_token) + if refresh_token := request.cookies.get(TT.refresh_token.name): + return get_token_payload(refresh_token) + raise AuthenticationError -async def verify_token_against_blacklist(request: Request, redis: RedisT) -> None: - """Check if JWT tokens aren't blacklisted. +async def is_jwt_blacklisted(request: Request, redis: RedisT) -> None: + """Check if JWTs are blacklisted. :param request: FastAPI's Request object - :raises ExpiredTokenError: if any token is blacklisted - :param redis: Redis client + :raises ExpiredTokenError: if any JWT is blacklisted + :param redis: client for Redis """ if not request.cookies: return async with redis.pipeline() as pipe: - if access_token := request.cookies.get("access_token"): + if access_token := request.cookies.get(TT.access_token.name): pipe.exists(access_token) - if refresh_token := request.cookies.get("refresh_token"): + if refresh_token := request.cookies.get(TT.refresh_token.name): pipe.exists(refresh_token) if any(await pipe.execute()): raise ExpiredTokenError -async def add_tokens_to_blacklist(request: Request, redis: RedisT) -> None: - """Add access and/or refresh tokens into black list. - - Tokens are blacklisted to avoid their reuse. +async def blacklist_jwt(request: Request, redis: RedisT) -> None: + """Add access and/or refresh JWT into the black list to avoid their reuse. - NOTE (1): tokens are blacklisted until their corresponding - cookies aren't expired. + NOTE (1): tokens are blacklisted until their corresponding cookies are expired. :param request: FastAPI's Request object - :param redis: Redis client + :param redis: client for Redis """ if not request.cookies: return async with redis.pipeline() as pipe: - if access_token := request.cookies.get("access_token"): + if access_token := request.cookies.get(TT.access_token.name): pipe.setex( - access_token, settings.access_token_cookie_expiration_time, "blacklist" + access_token, + settings.access_token_cookie_expiration_time, + settings.jwt_blacklist_name, ) - if refresh_token := request.cookies.get("refresh_token"): + if refresh_token := request.cookies.get(TT.refresh_token.name): pipe.setex( - refresh_token, settings.refresh_token_cookie_expiration_time, "blacklist" + refresh_token, + settings.refresh_token_cookie_expiration_time, + settings.jwt_blacklist_name, ) await pipe.execute() diff --git a/src/app/user/errors.py b/src/app/user/errors.py index 56291d5..0ca9186 100644 --- a/src/app/user/errors.py +++ b/src/app/user/errors.py @@ -1,7 +1,7 @@ -from typing import TYPE_CHECKING, cast +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any, cast -from fastapi import Request, status -from fastapi.exceptions import RequestValidationError +from fastapi import Request, Response, status from fastapi.responses import JSONResponse from sqlalchemy.exc import IntegrityError @@ -20,24 +20,10 @@ async def db_integrity_error_handler( return JSONResponse(client_message, status.HTTP_409_CONFLICT) -async def validation_error_handler( - _: Request, - e: RequestValidationError, -) -> JSONResponse: - """Pydantic validation error handler.""" - error = e.errors()[0] - field = error["loc"][1] - error_message = f"{error['msg']}. Field: {field}" - client_message = {"detail": error_message} - return JSONResponse(client_message, status.HTTP_422_UNPROCESSABLE_ENTITY) - - -async def unexpected_error_handler( - request: Request, - e: Exception, -) -> JSONResponse: - """Error handler for all uncaught exceptions.""" - logger = cast("Logger", request.app.state.logger) - logger.critical("Internal Server Error: %s", e) - client_message = {"error": "Service is temporarily unavailable"} - return JSONResponse(client_message, status.HTTP_500_INTERNAL_SERVER_ERROR) +# error handlers mapping to be registered in the main FastAPI app +error_handlers: dict[ + int | type[Exception], + Callable[[Request, Any], Coroutine[Any, Any, Response]], +] = { + IntegrityError: db_integrity_error_handler, +} diff --git a/src/app/user/middlewares.py b/src/app/user/middlewares.py deleted file mode 100644 index e5e29fe..0000000 --- a/src/app/user/middlewares.py +++ /dev/null @@ -1,54 +0,0 @@ -from fastapi.openapi.models import OAuthFlowPassword, OAuthFlows -from fastapi.security import OAuth2 -from fastapi.security.utils import get_authorization_scheme_param -from starlette.requests import Request - -from app.main.settings import settings - -from .exceptions import AuthenticationError - - -class OAuth2PasswordBearerWithCookie(OAuth2): - """OAuth2 flow for authentication based on cookies.""" - - def __init__( - self, - tokenUrl: str, # noqa: N803 - scheme_name: str | None = None, - scopes: dict[str, str] | None = None, - auto_error: bool = True, - ) -> None: - """Init OAuth flow.""" - flows = OAuthFlows( - password=OAuthFlowPassword(tokenUrl=tokenUrl, scopes=scopes or {}) - ) - super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error) - - async def __call__(self, request: Request) -> str | None: - """Validate JWT token taken from the access or refresh token. - - If access token is missing (e.g. cookie is expired), then - refresh token will be validated. - - if both tokens are missing, 'AuthenticationError' will be raised - (user must re-login). - - :param request: request object providing access to cookies - :raises AuthenticationError: if tokens are invalid or corrupted - :return: access or refresh token - """ - # Check if access token is valid - access_token_cookie = request.cookies.get("access_token") - access_token_scheme, access_token = get_authorization_scheme_param( - access_token_cookie - ) - if access_token or access_token_scheme == settings.token_scheme: - return access_token - # Check if refresh token is valid - refresh_token_cookie = request.cookies.get("refresh_token") - refresh_token_scheme, refresh_token = get_authorization_scheme_param( - refresh_token_cookie - ) - if refresh_token or refresh_token_scheme == settings.token_scheme: - return refresh_token - raise AuthenticationError diff --git a/src/app/user/models.py b/src/app/user/models.py index bae32e8..55319c4 100644 --- a/src/app/user/models.py +++ b/src/app/user/models.py @@ -1,20 +1,11 @@ from datetime import datetime -from typing import TYPE_CHECKING -from sqlalchemy import Boolean, func, select +from sqlalchemy import Boolean, func from sqlalchemy.orm import Mapped, mapped_column from app.main.db import Base -from .auth import generate_password_hash from .constants import UserRole -from .exceptions import UserNotFoundError -from .schemas import UserInDB - -if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncSession - - from .schemas import NewUser class User(Base): @@ -35,50 +26,3 @@ class User(Base): server_default=func.now(), onupdate=func.now(), ) - - @classmethod - async def get(cls, db: "AsyncSession", username: str) -> UserInDB: - """Return info about user. - - :param db: DB session - :param username: username of the user - :raises UserNotFoundError: if the user doesn't exist in the database - :return: info about user - """ - cols = ( - cls.first_name, - cls.last_name, - cls.username, - cls.email, - cls.password, - cls.is_active, - cls.role, - ) - query = select(*cols).where(cls.username == username) - if user := (await db.execute(query)).one_or_none(): - return UserInDB.model_construct(**user._asdict()) # pyright: ignore[reportPrivateUsage] - raise UserNotFoundError - - @classmethod - async def create( - cls, - db: "AsyncSession", - user: "NewUser", - **kwargs: str, - ) -> "NewUser": - """Save info about new user into the database. - - :param db: DB session - :param user: info about new user - :raises UserAlreadyExistsError: if such user already exists - :return: info about user from the database - """ - db.add( - cls( - **user.model_dump(exclude={"password"}), - password=generate_password_hash(user.password), - **kwargs, - ) - ) - await db.commit() - return user diff --git a/src/app/user/routes.py b/src/app/user/routes.py index db99f2e..e1a831a 100644 --- a/src/app/user/routes.py +++ b/src/app/user/routes.py @@ -2,19 +2,18 @@ from fastapi import APIRouter, Depends, Header, Response, status -from app.main.dependencies import DbSession - -from .auth import generate_token, verify_access -from .constants import CookieType, TokenType +from .constants import TT from .dependencies import ( - add_tokens_to_blacklist, - authenticate_user, - decode_token, - verify_token_against_blacklist, + UserServiceT, + UserT, + blacklist_jwt, + decode_jwt, + is_jwt_blacklisted, ) -from .exceptions import InactiveUserError, PermissionDenied -from .models import User as DBUser -from .schemas import NewUser, Token, TokenPayload, User, UserCookie, UserInDB +from .exceptions import PermissionDenied +from .schemas import NewUser, Token, TokenPayload, User, UserCookie +from .security.rbac import verify_access +from .security.tokens import generate_token router = APIRouter(tags=["user"]) @@ -23,53 +22,52 @@ "/logout", status_code=status.HTTP_205_RESET_CONTENT, response_class=Response, - dependencies=[Depends(add_tokens_to_blacklist)], + dependencies=[Depends(blacklist_jwt)], ) async def logout(response: Response) -> None: - """Log out the authenticated user. + """Log out authenticated user. Once the user is logged out, the authorization cookies are deleted and the tokens are blacklisted. """ - for cookie in CookieType: - response.delete_cookie(cookie.name) + for token_type in TT: + response.delete_cookie(token_type.name) -@router.post("/login") -async def login( - response: Response, - user: Annotated[UserInDB, Depends(authenticate_user)], -) -> Token: - """Authenticate the user generating authorization cookies.""" - if not user.is_active: - raise InactiveUserError - access_token = generate_token(TokenType.access, user.username, user.role) - refresh_token = generate_token(TokenType.refresh, user.username, user.role) - access_token_cookie = UserCookie(key=CookieType.access_token, value=access_token) - refresh_token_cookie = UserCookie(key=CookieType.refresh_token, value=refresh_token) - response.set_cookie(**access_token_cookie.model_dump()) - response.set_cookie(**refresh_token_cookie.model_dump()) - return Token.model_construct(access_token=access_token, refresh_token=refresh_token) +@router.post("/login", response_model=Token) +async def login(response: Response, user: UserT) -> dict[str, str]: + """Authenticate user generating authorization cookies.""" + tokens: dict[str, str] = {} + for token_type in TT: + token = generate_token(token_type, user.username, user.role) + cookie = UserCookie(key=token_type, value=token) + response.set_cookie(**cookie.model_dump()) + tokens[token_type.name] = token + return tokens -@router.post("/signup", response_model=User) -async def signup(user: NewUser, db: DbSession) -> NewUser: +@router.post( + "/signup", + response_model=User, + status_code=status.HTTP_201_CREATED, +) +async def signup(user: NewUser, service: UserServiceT) -> NewUser: """Sign up a user.""" - return await DBUser.create(db, user) + return await service.create(user) @router.get( "/auth", - dependencies=[Depends(verify_token_against_blacklist)], + dependencies=[Depends(is_jwt_blacklisted)], status_code=status.HTTP_204_NO_CONTENT, include_in_schema=False, ) async def auth( response: Response, - payload: Annotated[TokenPayload, Depends(decode_token)], + payload: Annotated[TokenPayload, Depends(decode_jwt)], x_original_uri: Annotated[str | None, Header()] = None, ) -> None: - """Verify authorization tokens (access or refresh). + """Verify authorization cookies. Note: 1) it's used by Nginx Subrequest module to allow/disallow @@ -82,7 +80,7 @@ async def auth( """ if x_original_uri and not verify_access(payload.role, x_original_uri): raise PermissionDenied - if payload.typ == TokenType.refresh: - access_token = generate_token(TokenType.access, payload.sub, payload.role) - access_token_cookie = UserCookie(key=CookieType.access_token, value=access_token) - response.set_cookie(**access_token_cookie.model_dump()) + if payload.typ == TT.refresh_token: + token = generate_token(TT.access_token, payload.sub, payload.role) + cookie = UserCookie(key=TT.access_token, value=token) + response.set_cookie(**cookie.model_dump()) diff --git a/src/app/user/schemas.py b/src/app/user/schemas.py index 38942de..1a106ed 100644 --- a/src/app/user/schemas.py +++ b/src/app/user/schemas.py @@ -15,7 +15,7 @@ from app.main.settings import settings -from .constants import CookieType, TokenType +from .constants import TT from .exceptions import PasswordsDontMatchError @@ -26,50 +26,50 @@ class BaseCustomModel(BaseModel): class Token(BaseCustomModel): - """Schema to provide info about access and refresh JWT tokens.""" + """Schema to provide info about access and refresh JWTs.""" - access_token: str = Field(description="JWT access token") - refresh_token: str = Field(description="JWT refresh token") + access_token: str = Field(description="access JWT") + refresh_token: str = Field(description="refresh JWT") class TokenPayload(BaseModel): - """Schema to provide info about payload of a JWT token.""" + """Schema to provide info about payload of a JWT.""" - sub: str = Field(description="Subject of the token (user's username)") - typ: TokenType = Field(description="Type of the token (access or refresh)") + sub: str = Field(description="Subject of the JWT (user's username)") + typ: TT = Field(description="Type of the JWT (access or refresh)") role: str = Field(description="Role of the user (admin, user, etc.)") jti: str = Field( default_factory=lambda: str(uuid4()), - description="JWT ID (unique identifier of the token)", + description="JWT ID (unique identifier of the JWT)", ) exp: datetime = Field( default_factory=lambda: datetime.now(UTC), validate_default=True, - description="Expiration date and time of the token", + description="Expiration date and time of the JWT", ) @field_validator("exp") @classmethod def set_exp(cls, now: datetime, info: ValidationInfo) -> datetime: - """Set expiration date of the token based on its type. + """Set expiration date and time of the JWT based on its type. :param now: value of the attribute :param info: all schema values - :return: expiration date of the token based on its type + :return: expiration date and time of the JWT based on its type """ exp = ( settings.access_token_expiration_time - if info.data["typ"] == TokenType.access + if info.data["typ"] == TT.access_token else settings.refresh_token_expiration_time ) return now + timedelta(minutes=exp) class UserCookie(BaseModel): - """Schema to set a cookie to store a JWT token.""" + """Schema to set a cookie to store a JWT.""" - key: CookieType | str = Field(description="Key (name) of the cookie") - value: str = Field(description="Value of the cookie (JWT token)") + key: TT | str = Field(description="Key (name) of the cookie") + value: str = Field(description="Value of the cookie (JWT)") domain: str = Field( default=settings.host_server_domain, description="Domain of the server the cookie is associated with", @@ -89,7 +89,7 @@ class UserCookie(BaseModel): @field_validator("key") @classmethod - def key_to_str(cls, key: CookieType) -> str: + def key_to_str(cls, key: TT) -> str: """Convert key from Enum to str. :param key: enum key @@ -100,15 +100,15 @@ def key_to_str(cls, key: CookieType) -> str: @field_validator("expires", "max_age", mode="before") @classmethod def set_cookie_expiration_time(cls, _: None, info: ValidationInfo) -> int: - """Set expiration time of the cookie (in seconds) based on its type. + """Set expiration date and time of the cookie (in seconds) based on its type. :param _: value of the attribute :param info: all schema values - :return: expiration time of the cookie based on its type + :return: expiration date and time of the cookie based on its type """ return ( settings.access_token_cookie_expiration_time - if info.data["key"] == CookieType.access_token + if info.data["key"] == TT.access_token else settings.refresh_token_cookie_expiration_time ) @@ -145,6 +145,12 @@ class NewUser(BaseUser): password: str = PasswordField repeat_password: str = PasswordField + @property + def credentials(self) -> dict[str, str]: + """Read-only property returning user credentials for login.""" + return {"username": self.username, "password": self.password} + + # del user.repeat_password @model_validator(mode="after") def check_passwords_match(self) -> Self: """Check if the original and repeated passwords match.""" diff --git a/src/app/user/security/__init__.py b/src/app/user/security/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/user/security/passwords.py b/src/app/user/security/passwords.py new file mode 100644 index 0000000..2df0948 --- /dev/null +++ b/src/app/user/security/passwords.py @@ -0,0 +1,34 @@ +from functools import lru_cache + +from passlib.context import CryptContext + + +@lru_cache +def get_pwd_context() -> CryptContext: + """Return cached helper for hashing & verifying passwords. + + :return: helper for hashing & verifying passwords + """ + return CryptContext(schemes=["bcrypt"], deprecated="auto") + + +pwd_context = get_pwd_context() + + +def generate_password_hash(plain_password: str) -> str: + """Generate hash of the plain password. + + :param plain_password: plain password + :return: hash of the password + """ + return pwd_context.hash(plain_password) + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify plain and hashed password. + + :param plain_password: plain password + :param hashed_password: hashed password + :return: boolean result of the verification + """ + return pwd_context.verify(plain_password, hashed_password) diff --git a/src/app/user/security/rbac.py b/src/app/user/security/rbac.py new file mode 100644 index 0000000..4862f33 --- /dev/null +++ b/src/app/user/security/rbac.py @@ -0,0 +1,21 @@ +import re +from functools import lru_cache +from typing import TYPE_CHECKING + +from app.main.settings import settings + +if TYPE_CHECKING: + from app.user.constants import UserRole + + +@lru_cache +def verify_access(role: "UserRole", url: str) -> bool: + """Verify if the user can access the URL according to one's role. + + :param role: role of the user + :param url: target URL to check access to + :return: boolean result of the verification + """ + return any( + re.match(pattern, url) for pattern in settings.rbac_policy[role]["locations"] + ) diff --git a/src/app/user/security/tokens.py b/src/app/user/security/tokens.py new file mode 100644 index 0000000..780cfcd --- /dev/null +++ b/src/app/user/security/tokens.py @@ -0,0 +1,46 @@ +from functools import lru_cache +from typing import TYPE_CHECKING + +from jose import JWTError, jwt + +from app.main.settings import settings +from app.user.exceptions import InvalidCredentialsError +from app.user.schemas import TokenPayload + +if TYPE_CHECKING: + from app.user.constants import TT + + +def generate_token(token_type: "TT", username: str, role: str) -> str: + """Generate JWT of the specified types and claims. + + :param token_type: type of the token (access or refresh) + :param username: username to generate the token for + :param role: role of the user (user, admin, etc.) + :return: JWT token + """ + payload = TokenPayload(sub=username, typ=token_type, role=role) + return jwt.encode( + claims=payload.model_dump(), + key=settings.jwt_secret_key, + algorithm=settings.jwt_algorithm, + ) + + +@lru_cache(maxsize=settings.token_payload_max_cache_hits) +def get_token_payload(token: str) -> TokenPayload: + """Decode the JWT token and return its payload. + + :param token: JWT token + :raises InvalidCredentialsError: if the token can't be decoded + :return: payload of the token + """ + try: + payload = jwt.decode( + token, + key=settings.jwt_secret_key, + algorithms=[settings.jwt_algorithm], + ) + except JWTError as e: + raise InvalidCredentialsError from e + return TokenPayload.model_construct(**payload) diff --git a/src/app/user/services/__init__.py b/src/app/user/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/user/services/users.py b/src/app/user/services/users.py new file mode 100644 index 0000000..3b45dbf --- /dev/null +++ b/src/app/user/services/users.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING + +from sqlmodel import select + +from app.user.exceptions import UserNotFoundError +from app.user.models import User +from app.user.schemas import UserInDB +from app.user.security.passwords import generate_password_hash + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + from app.user.schemas import NewUser + + +class UsersDbService: + """Service for managing users in the database.""" + + def __init__(self, session: "AsyncSession") -> None: + """Define session object to perform database operations.""" + self.session = session + + async def get(self, username: str) -> UserInDB: + """Retrieve a user by username.""" + stmt = select(User).where(User.username == username) + if user := (await self.session.execute(stmt)).scalar(): + return UserInDB(**user.__dict__) + raise UserNotFoundError + + async def create(self, user: "NewUser", **kwargs: str) -> "NewUser": + """Save info about user into the database.""" + db_user = User( + **user.model_dump(exclude={"password"}), + password=generate_password_hash(user.password), + **kwargs, + ) + self.session.add(db_user) + await self.session.commit() + return user diff --git a/src/scripts/create_user.py b/src/scripts/create_user.py index 4ffaf01..d1df9ae 100644 --- a/src/scripts/create_user.py +++ b/src/scripts/create_user.py @@ -4,8 +4,8 @@ from app.main.db import async_session from app.user.constants import UserRole -from app.user.models import User from app.user.schemas import NewUser +from app.user.services.users import UsersDbService app = typer.Typer() @@ -36,8 +36,9 @@ async def save_user(user: NewUser, role: UserRole) -> None: :param user: info about user to be saved :return: None """ - async with async_session() as db: - await User.create(db, user, role=role.name) + async with async_session() as session: + service = UsersDbService(session) + await service.create(user, role=role.name) if __name__ == "__main__": diff --git a/src/tests/app/admin/test_views.py b/src/tests/app/admin/test_views.py index eadaab7..505c629 100644 --- a/src/tests/app/admin/test_views.py +++ b/src/tests/app/admin/test_views.py @@ -4,8 +4,6 @@ from fastapi import status from httpx import AsyncClient -from tests.conftest import E2E_MODE_DISABLED - if TYPE_CHECKING: from httpx import AsyncClient @@ -23,13 +21,13 @@ async def test_admin_users_index(client: "AsyncClient") -> None: assert '