From dd1fe08b6f1c1dad8fb3bbfced2371a0a2281534 Mon Sep 17 00:00:00 2001 From: Code With Me Date: Thu, 26 Feb 2026 18:27:53 +0300 Subject: [PATCH 1/4] test: add load testing infrastructure & harden critical paths --- .pre-commit-config.yaml | 2 +- Makefile | 22 +++++ app/core/security.py | 15 ++- app/services/inventory/models.py | 14 ++- app/services/inventory/routes.py | 6 +- app/services/inventory/schemas.py | 2 +- app/services/user/service.py | 4 +- docker-compose.yaml | 1 + locust_base.py | 49 ++++++++++ locustfile.py | 30 ++++++ locustfile_oversell.py | 30 ++++++ ...443a9c1a0a_make_datetime_timezone_aware.py | 93 +++++++++++++++++++ pyproject.toml | 2 + scripts/seed_oversell_product.py | 54 +++++++++++ tests/conftest.py | 41 ++++++++ tests/test_concurrency_inventory.py | 53 +++++++++++ 16 files changed, 405 insertions(+), 13 deletions(-) create mode 100644 Makefile create mode 100644 locust_base.py create mode 100644 locustfile.py create mode 100644 locustfile_oversell.py create mode 100644 migrations/versions/2e443a9c1a0a_make_datetime_timezone_aware.py create mode 100644 scripts/seed_oversell_product.py create mode 100644 tests/conftest.py create mode 100644 tests/test_concurrency_inventory.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6b7d6d2..15a5829 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: hooks: - id: pytest name: pytest - entry: pytest + entry: uv run pytest language: system pass_filenames: false args: [tests/] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..9e85708 --- /dev/null +++ b/Makefile @@ -0,0 +1,22 @@ +HOST ?= http://localhost:8080 + +.PHONY: stress-test oversell-test + +stress-test: + uv run locust \ + -f locustfile.py \ + --headless \ + --users 500 \ + --spawn-rate 100 \ + --run-time 60s \ + --host $(HOST) + +oversell-test: + DB_HOST=localhost uv run python scripts/seed_oversell_product.py + uv run locust \ + -f locustfile_oversell.py \ + --headless \ + --users 100 \ + --spawn-rate 50 \ + --run-time 60s \ + --host $(HOST) diff --git a/app/core/security.py b/app/core/security.py index 1b666e6..3dcca39 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -1,3 +1,4 @@ +import asyncio from datetime import UTC, datetime, timedelta from jose import jwt @@ -8,14 +9,24 @@ pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto') -def verify_password(plain_password: str, hashed_password: str) -> bool: +def verify_password_sync(plain_password: str, hashed_password: str) -> bool: return bool(pwd_context.verify(plain_password, hashed_password)) -def get_password_hash(password: str) -> str: +async def verify_password(plain_password: str, hashed_password: str) -> bool: + return await asyncio.to_thread( + verify_password_sync, plain_password, hashed_password + ) + + +def get_password_hash_sync(password: str) -> str: return str(pwd_context.hash(password)) +async def get_password_hash(password: str) -> str: + return await asyncio.to_thread(get_password_hash_sync, password) + + def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str: to_encode = data.copy() if expires_delta: diff --git a/app/services/inventory/models.py b/app/services/inventory/models.py index 3fe6919..00a9e04 100644 --- a/app/services/inventory/models.py +++ b/app/services/inventory/models.py @@ -25,9 +25,11 @@ class Product(Base): default=Decimal('0.10'), ) qty_available: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) updated_at: Mapped[datetime] = mapped_column( - DateTime, server_default=func.now(), onupdate=func.now() + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) __table_args__ = ( @@ -48,5 +50,9 @@ class Reservation(Base): order_id: Mapped[UUID | None] = mapped_column( ForeignKey('orders.id'), nullable=True ) - expires_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) + expires_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) diff --git a/app/services/inventory/routes.py b/app/services/inventory/routes.py index 78a0755..705db4a 100644 --- a/app/services/inventory/routes.py +++ b/app/services/inventory/routes.py @@ -2,7 +2,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import get_session -from app.services.inventory.models import Reservation from app.services.inventory.rate_limit import check_rate_limit from app.services.inventory.schemas import ReservationCreate, ReservationResponse from app.services.inventory.service import reserve_items @@ -21,15 +20,16 @@ async def reservation_data( x_idempotency_key: str = Header(...), session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), -) -> Reservation: +) -> ReservationResponse: await check_rate_limit( rate_limit_script=request.app.state.rate_limit_script, user_id=str(current_user.id), item_id=str(reservation_data.product_id), ) - return await reserve_items( + result = await reserve_items( session=session, user_id=current_user.id, idempotency_key=x_idempotency_key, reservation_data=reservation_data, ) + return ReservationResponse.model_validate(result) diff --git a/app/services/inventory/schemas.py b/app/services/inventory/schemas.py index a85efd4..717221e 100644 --- a/app/services/inventory/schemas.py +++ b/app/services/inventory/schemas.py @@ -14,6 +14,6 @@ class ReservationResponse(BaseModel): id: UUID product_id: UUID user_id: UUID - quantity: int + quantity: int = Field(validation_alias='qty_reserved') status: str expires_at: datetime diff --git a/app/services/user/service.py b/app/services/user/service.py index 081ad1a..c4f1a1a 100644 --- a/app/services/user/service.py +++ b/app/services/user/service.py @@ -23,7 +23,7 @@ async def create_user(session: AsyncSession, user_create: UserCreate) -> User: ) if result.scalar_one_or_none(): raise UserAlreadyExists - hashed_password = get_password_hash(user_create.password) + hashed_password = await get_password_hash(user_create.password) user = User(email=user_create.email, password_hash=hashed_password) session.add(user) await session.commit() @@ -36,7 +36,7 @@ async def authenticate_user( ) -> User | None: result = await session.execute(select(User).where(User.email == email)) user = result.scalar_one_or_none() - if not user or not verify_password(password, user.password_hash): + if not user or not await verify_password(password, user.password_hash): return None return user diff --git a/docker-compose.yaml b/docker-compose.yaml index 76d8318..a1291fa 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -11,6 +11,7 @@ services: postgres: image: postgres:15-alpine container_name: ${DB_HOST} + command: postgres -c max_connections=300 # switch to pgbouncer available healthcheck: test: ['CMD', 'pg_isready', '-d', '${POSTGRES_DB}', '-U', '${POSTGRES_USER}'] interval: 10s diff --git a/locust_base.py b/locust_base.py new file mode 100644 index 0000000..2188647 --- /dev/null +++ b/locust_base.py @@ -0,0 +1,49 @@ +from http import HTTPStatus +from uuid import uuid4 + +from locust import HttpUser + + +class BaseUser(HttpUser): + """ + Base class for all load testing scenarios. + Contains registration and authorization logic. + Descendants define wait_time and @task methods. + """ + + abstract = True + + def on_start(self) -> None: + self.access_token: str | None = None + email = f'locust_{uuid4()}@mail.com' + password = '12345' + + with self.client.post( + '/api/v1/users', + json={'email': email, 'password': password}, + catch_response=True, + ) as reg_res: + if reg_res.status_code not in ( + HTTPStatus.OK, + HTTPStatus.CREATED, + HTTPStatus.BAD_REQUEST, + ): + reg_res.failure(f'Registration failed: {reg_res.status_code}') + return + + with self.client.post( + '/api/v1/auth/token', + data={'username': email, 'password': password}, + catch_response=True, + ) as token_res: + if token_res.status_code == HTTPStatus.OK: + self.access_token = token_res.json().get('access_token') + else: + token_res.failure(f'Login failed: {token_res.status_code}') + + @property + def auth_headers(self) -> dict[str, str]: + return { + 'Authorization': f'Bearer {self.access_token}', + 'X-Idempotency-Key': str(uuid4()), + } diff --git a/locustfile.py b/locustfile.py new file mode 100644 index 0000000..30653f6 --- /dev/null +++ b/locustfile.py @@ -0,0 +1,30 @@ +from http import HTTPStatus + +from locust import between, task + +from locust_base import BaseUser + +TARGET_PRODUCT_ID = '5995fa75-07c7-4b55-82b7-6bfbb52948b8' + + +class HighLoadUser(BaseUser): + wait_time = between(0.5, 2.0) + + @task + def reserve_product(self) -> None: + if not self.access_token: + return + with self.client.post( + '/api/v1/inventory/reserve', + headers=self.auth_headers, + json={'product_id': TARGET_PRODUCT_ID, 'quantity': 1}, + catch_response=True, + ) as reserve_res: + if reserve_res.status_code not in ( + HTTPStatus.OK, + HTTPStatus.CREATED, + HTTPStatus.CONFLICT, + HTTPStatus.TOO_MANY_REQUESTS, + HTTPStatus.BAD_REQUEST, + ): + reserve_res.failure(f'Reserve failed: {reserve_res.status_code}') diff --git a/locustfile_oversell.py b/locustfile_oversell.py new file mode 100644 index 0000000..5ad38f1 --- /dev/null +++ b/locustfile_oversell.py @@ -0,0 +1,30 @@ +from http import HTTPStatus + +from locust import between, task + +from locust_base import BaseUser + +OVERSELL_PRODUCT_ID = '3fe44185-589a-4703-b640-40df8d7ea67f' + + +class OversellTestUser(BaseUser): + wait_time = between(0.1, 0.5) + + @task + def reserve_oversell_product(self) -> None: + if not self.access_token: + return + with self.client.post( + '/api/v1/inventory/reserve', + headers=self.auth_headers, + json={'product_id': OVERSELL_PRODUCT_ID, 'quantity': 1}, + catch_response=True, + ) as reserve_res: + if reserve_res.status_code not in ( + HTTPStatus.OK, + HTTPStatus.CREATED, + HTTPStatus.CONFLICT, + HTTPStatus.TOO_MANY_REQUESTS, + HTTPStatus.BAD_REQUEST, + ): + reserve_res.failure(f'Reserve failed: {reserve_res.status_code}') diff --git a/migrations/versions/2e443a9c1a0a_make_datetime_timezone_aware.py b/migrations/versions/2e443a9c1a0a_make_datetime_timezone_aware.py new file mode 100644 index 0000000..359d3bb --- /dev/null +++ b/migrations/versions/2e443a9c1a0a_make_datetime_timezone_aware.py @@ -0,0 +1,93 @@ +"""make datetime timezone aware + +Revision ID: 2e443a9c1a0a +Revises: 9ae77f428203 +Create Date: 2026-02-26 13:33:16.250487 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '2e443a9c1a0a' +down_revision: str | Sequence[str] | None = '9ae77f428203' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + 'products', + 'created_at', + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + existing_server_default=sa.text('now()'), + ) + op.alter_column( + 'products', + 'updated_at', + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + existing_server_default=sa.text('now()'), + ) + op.alter_column( + 'reservations', + 'expires_at', + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + ) + op.alter_column( + 'reservations', + 'created_at', + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + existing_server_default=sa.text('now()'), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + 'reservations', + 'created_at', + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + existing_server_default=sa.text('now()'), + ) + op.alter_column( + 'reservations', + 'expires_at', + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + ) + op.alter_column( + 'products', + 'updated_at', + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + existing_server_default=sa.text('now()'), + ) + op.alter_column( + 'products', + 'created_at', + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + existing_server_default=sa.text('now()'), + ) + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index 8b9f59f..0fbdeb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ python_files = ['test_*.py'] python_classes = ['Test*'] python_functions = ['test_*'] pythonpath = ['.'] +asyncio_mode = 'auto' +asyncio_default_fixture_loop_scope = 'session' [tool.mypy] python_version = '3.11' diff --git a/scripts/seed_oversell_product.py b/scripts/seed_oversell_product.py new file mode 100644 index 0000000..62e0f77 --- /dev/null +++ b/scripts/seed_oversell_product.py @@ -0,0 +1,54 @@ +""" +Script for creating/resetting a test product before OversellTestUser test. +Usage: + uv run python scripts/seed_oversell_product.py +""" + +import asyncio +import sys +from decimal import Decimal +from pathlib import Path +from uuid import UUID + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from app.core.config import settings +from app.services.inventory.models import Product + +OVERSELL_PRODUCT_ID = UUID('3fe44185-589a-4703-b640-40df8d7ea67f') +INITIAL_QTY = 50 + + +async def seed() -> None: + engine = create_async_engine(url=str(settings.database_url), echo=False) + session_factory = async_sessionmaker(bind=engine, expire_on_commit=False) + + async with session_factory() as session: + result = await session.execute( + select(Product).where(Product.id == OVERSELL_PRODUCT_ID) + ) + product = result.scalar_one_or_none() + + if product: + product.qty_available = INITIAL_QTY + print(f'Product found, qty_available reset to {INITIAL_QTY}') + else: + product = Product( + id=OVERSELL_PRODUCT_ID, + name='Oversell Test Product', + qty_available=INITIAL_QTY, + price=Decimal('1.00'), + ) + session.add(product) + print(f'Created new product with qty_available={INITIAL_QTY}') + + await session.commit() + print(f'Done. id={OVERSELL_PRODUCT_ID}, qty={INITIAL_QTY}') + + await engine.dispose() + + +asyncio.run(seed()) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0150e63 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,41 @@ +import os + +os.environ.setdefault('DB_HOST', 'localhost') + +from collections.abc import AsyncGenerator + +import pytest_asyncio +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.pool import NullPool + +from app.core.config import settings + + +@pytest_asyncio.fixture +async def db_engine() -> AsyncGenerator[AsyncEngine, None]: + url = str(settings.database_url) + engine = create_async_engine(url, echo=True, poolclass=NullPool) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def db_session_factory( + db_engine: AsyncEngine, +) -> async_sessionmaker[AsyncSession]: + return async_sessionmaker( + bind=db_engine, class_=AsyncSession, expire_on_commit=False + ) + + +@pytest_asyncio.fixture +async def db_session( + db_session_factory: async_sessionmaker[AsyncSession], +) -> AsyncGenerator[AsyncSession, None]: + async with db_session_factory() as session: + yield session diff --git a/tests/test_concurrency_inventory.py b/tests/test_concurrency_inventory.py new file mode 100644 index 0000000..fc73d5e --- /dev/null +++ b/tests/test_concurrency_inventory.py @@ -0,0 +1,53 @@ +import asyncio +from uuid import uuid4 + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from app.core.exceptions import InsufficientInventoryError +from app.services.inventory.models import Product +from app.services.inventory.schemas import ReservationCreate +from app.services.inventory.service import reserve_items +from app.services.orders.models import Order # noqa: F401 +from app.services.user.models import User + + +@pytest.mark.asyncio +async def test_concurrent_reservations_service_level( + db_session_factory: async_sessionmaker[AsyncSession], +) -> None: + async with db_session_factory() as setup_session: + user = User(id=uuid4(), email=f'test_{uuid4()}@mail.com', password_hash='foo') + product = Product( + id=uuid4(), name='Test Sneakers', price=100.0, qty_available=10 + ) + setup_session.add(user) + setup_session.add(product) + await setup_session.commit() + user_id = user.id + product_id = product.id + concurrency_level = 50 + + async def worker() -> bool | Exception: + async with db_session_factory() as session: + request = ReservationCreate(product_id=product_id, quantity=1) + idempotency_key = str(uuid4()) + try: + await reserve_items(session, user_id, idempotency_key, request) + return True + except InsufficientInventoryError: + return False + except Exception as e: + return e + + start_workers = await asyncio.gather(*(worker() for _ in range(concurrency_level))) + success_count = sum(1 for r in start_workers if r is True) + fail_count = sum(1 for r in start_workers if r is False) + error_count = sum(1 for r in start_workers if isinstance(r, Exception)) + assert success_count == 10 + assert fail_count == 40 + assert error_count == 0 + async with db_session_factory() as session: + final_product = await session.get(Product, product_id) + assert final_product is not None + assert final_product.qty_available == 0 From 0985b3aa8f67745d83a1e68b5e22577a0ce9038f Mon Sep 17 00:00:00 2001 From: Code With Me Date: Thu, 26 Feb 2026 18:39:21 +0300 Subject: [PATCH 2/4] test: add load testing infrastructure & harden critical paths --- .env | 41 +++++++++++++++++++++++++++++++++++++++ .github/workflows/ci.yaml | 21 +++++++++++++++++++- .gitignore | 1 - 3 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 .env diff --git a/.env b/.env new file mode 100644 index 0000000..d29a73e --- /dev/null +++ b/.env @@ -0,0 +1,41 @@ +#PSG_sql +DB_PORT=5432 +POSTGRES_DB=fairdrop_db +POSTGRES_USER=fairdrop_user +POSTGRES_PASSWORD=fairdrop_password +DB_HOST=db_fairdrop +#REDIS +REDIS_PREFIX=redis:// +REDIS_HOST=redis_fairdrop +REDIS_PORT=6379 +REDIS_URL=${REDIS_PREFIX}${REDIS_HOST}:${REDIS_PORT} +#s3 +S3_HOST=s3-fairdrop +S3_PORT=9000 +MINIO_ROOT_USER=s3_fairdrop_user +MINIO_ROOT_PASSWORD=s3_fairdrop_password +MINIO_BUCKET_NAME=s3_fairdrop-media +MINIO_URL=http://${S3_HOST}:${S3_PORT} +PRESIGNED_URL_EXPIRE_SECONDS=3600 +MIN_FILE_SIZE_BYTES=1 +MAX_FILE_SIZE_BYTES=5242880 +#gateway +GATEWAY_HOST=gateway_fairdrop +GATEWAY_PORT=8080 +#app +APP_PORT=8000 +#debug +DEBUG_MODE=True +#jwt +SECRET_KEY=change_me_super_secret +ACCESS_TOKEN_EXPIRE_MINUTES=30 +REFRESH_TOKEN_EXPIRE_DAYS=7 +#lua limiter +RATE_LIMIT_USER_RPS=10 +RATE_LIMIT_GLOBAL_RPS=1000 +RATE_LIMIT_TTL_SECONDS=1 +IDEMPOTENT_KEY_LIFETIME_SEC=86400 +RESERVE_TIMEOUT_MINUTES=15 +#db_engine_layer +POOL_SIZE=50 +MAX_OVERFLOW=100 \ No newline at end of file diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6a85409..bb99a5d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -9,13 +9,32 @@ on: jobs: build: runs-on: ubuntu-latest + + services: + postgres: + image: postgres:15-alpine + env: + POSTGRES_USER: fairdrop_user + POSTGRES_PASSWORD: fairdrop_pass + POSTGRES_DB: fairdrop_db + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + env: + DB_HOST: localhost # override docker container name from .env + steps: - name: Checkout uses: actions/checkout@v4 - name: Set up UV (without pip) uses: astral-sh/setup-uv@v5 - with: + with: enable-cache: true - name: Install dependencies diff --git a/.gitignore b/.gitignore index 181a5ee..858c163 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,3 @@ wheels/ .vscode/ # Local env -.env From 351d0c1944604d17171fad7ce329cad875f98138 Mon Sep 17 00:00:00 2001 From: Code With Me Date: Thu, 26 Feb 2026 18:45:29 +0300 Subject: [PATCH 3/4] test: add load testing infrastructure & harden critical paths --- .github/workflows/ci.yaml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index bb99a5d..9bfd3ba 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -9,13 +9,15 @@ on: jobs: build: runs-on: ubuntu-latest - +# ===== +# Hardcode will be replaced with secrets GitHub Repo +# ===== services: postgres: image: postgres:15-alpine env: POSTGRES_USER: fairdrop_user - POSTGRES_PASSWORD: fairdrop_pass + POSTGRES_PASSWORD: fairdrop_password POSTGRES_DB: fairdrop_db ports: - 5432:5432 @@ -27,7 +29,9 @@ jobs: env: DB_HOST: localhost # override docker container name from .env - +# ===== +# Hardcode will be replaced with secrets GitHub Repo +# ===== steps: - name: Checkout uses: actions/checkout@v4 From 366f60cf5efeb2fca3c971400efe2fdd601850a0 Mon Sep 17 00:00:00 2001 From: Code With Me Date: Thu, 26 Feb 2026 19:00:10 +0300 Subject: [PATCH 4/4] test: add load testing infrastructure & harden critical paths --- tests/conftest.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0150e63..09000f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,4 @@ import os - -os.environ.setdefault('DB_HOST', 'localhost') - from collections.abc import AsyncGenerator import pytest_asyncio @@ -13,13 +10,31 @@ ) from sqlalchemy.pool import NullPool +import app.services.inventory.models # noqa: F401 +import app.services.orders.models # noqa: F401 +import app.services.user.models # noqa: F401 from app.core.config import settings +from app.core.database import Base + + +def _test_db_url() -> str: + """ + Build test DB URL with DB_HOST override. + settings.database_url may contain docker container name (db_fairdrop). + Tests run locally or in CI where postgres is on localhost. + """ + host = os.environ.get('DB_HOST', 'localhost') + return ( + f'postgresql+asyncpg://{settings.db_user}:{settings.db_password}' + f'@{host}:{settings.db_port}/{settings.db_name}' + ) @pytest_asyncio.fixture async def db_engine() -> AsyncGenerator[AsyncEngine, None]: - url = str(settings.database_url) - engine = create_async_engine(url, echo=True, poolclass=NullPool) + engine = create_async_engine(_test_db_url(), echo=True, poolclass=NullPool) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) yield engine await engine.dispose()