diff --git a/.env.example b/.env.example index 5b52a4d..1051262 100644 --- a/.env.example +++ b/.env.example @@ -1,51 +1,64 @@ -APP_NAME=Specification Generator -ENV=development -DEBUG=False -VERSION=1.0.0 +APP_NAME= +ENV= +DEBUG= +VERSION= -HOST=0.0.0.0 -PORT=8000 +HOST= +PORT= +FASTAPI_HOST_PORT= +FASTAPI_CONTAINER_PORT= +REDIS_HOST_PORT= +REDIS_CONTAINER_PORT= DATABASE_URL=postgresql://:@:5432/?sslmode=disable -SECRET_KEY= -ACCESS_TOKEN_EXPIRE_MINUTES=60 -ACCESS_TOKEN_MINUTE_IN_SECONDS=60 +SECRET_KEY= +ACCESS_TOKEN_EXPIRE_MINUTES= +ACCESS_TOKEN_MINUTE_IN_SECONDS= -AUTH_PREFIX=/auth -AUTH_ME_PATH=/me -AUTH_TAG=auth -AUTH_BOOTSTRAP_ENABLED=False -AUTH_BOOTSTRAP_SUPERUSER=True +AUTH_PREFIX= +AUTH_ME_PATH= +AUTH_TAG= +AUTH_BOOTSTRAP_ENABLED= +AUTH_BOOTSTRAP_SUPERUSER= AUTH_USERNAME= AUTH_EMAIL= AUTH_PASSWORD= -AUTH_PASSWORD_HASH= +AUTH_PASSWORD_HASH= -ALLOWED_ORIGINS=["*"] -SECURITY_TRUSTED_HOSTS=["*"] -SECURITY_ENABLE_HTTPS_REDIRECT=False -SECURITY_CSP=default-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self' -SECURITY_REFERRER_POLICY=strict-origin-when-cross-origin -SECURITY_CORS_ALLOW_METHODS=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] -SECURITY_CORS_ALLOW_HEADERS=["Authorization", "Content-Type", "Accept", "Origin", "X-Requested-With", "X-Request-ID"] -SECURITY_REQUEST_ID_HEADER=X-Request-ID -SECURITY_LOG_SUSPICIOUS=True -SECURITY_RATE_LIMIT_ENABLED=True -SECURITY_RATE_LIMIT_REQUESTS=120 -SECURITY_RATE_LIMIT_WINDOW_SECONDS=60 -SECURITY_RATE_LIMIT_PATHS=["/api/v1/auth"] +ALLOWED_ORIGINS= +SECURITY_TRUSTED_HOSTS= +SECURITY_ENABLE_HTTPS_REDIRECT= +SECURITY_CSP= +SECURITY_REFERRER_POLICY= +SECURITY_CORS_ALLOW_METHODS= +SECURITY_CORS_ALLOW_HEADERS= +SECURITY_REQUEST_ID_HEADER= +SECURITY_LOG_SUSPICIOUS= +SECURITY_RATE_LIMIT_ENABLED= +SECURITY_RATE_LIMIT_REQUESTS= +SECURITY_RATE_LIMIT_WINDOW_SECONDS= +SECURITY_RATE_LIMIT_PATHS= -OLLAMA_BASE_URL=http://localhost:11434 -OLLAMA_MODEL=mistral -OLLAMA_TIMEOUT=120 -OLLAMA_SYSTEM_PROMPT=You are 
a helpful assistant that generates software specifications. +OLLAMA_BASE_URL= +OLLAMA_MODEL= +OLLAMA_TIMEOUT= +OLLAMA_CONNECT_TIMEOUT= +OLLAMA_MAX_RETRIES= +OLLAMA_RETRY_BACKOFF_SECONDS= +OLLAMA_SYSTEM_PROMPT= +LLM_PROMPT_MAX_LENGTH= +FEATURE_SPEC_HISTORY_DEFAULT_LIMIT= +CELERY_BROKER_URL= +CELERY_RESULT_BACKEND= +CELERY_TASK_MAX_RETRIES= +CELERY_TASK_RETRY_BASE_SECONDS= TEST_DB_HOST= -TEST_DB_PORT=5432 +TEST_DB_PORT= TEST_DB_NAME= TEST_DB_USER= TEST_DB_PASSWORD= -TEST_DB_SSLMODE=disable +TEST_DB_SSLMODE= TEST_DEFAULT_USERNAME= TEST_DEFAULT_EMAIL= TEST_DEFAULT_HASHED_PASSWORD= diff --git a/Dockerfile b/Dockerfile index 957ca20..f44b397 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,7 +17,7 @@ RUN sed -i 's/\r$//' /app/entrypoint.sh \ USER appuser -EXPOSE 8000 +EXPOSE 8001 ENTRYPOINT ["/app/entrypoint.sh"] -CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file +CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8001"] \ No newline at end of file diff --git a/README.md b/README.md index 0650822..c6fe7fe 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,12 @@
-# Specification Generator API +# Feature Specification Generator API -Production-ready FastAPI backend for authentication, LLM-powered specification generation, health checks, and secure API middleware baseline. +Production-ready FastAPI backend with async LLM generation via Celery + Redis, JWT auth, and Docker Compose orchestration. + +

+ FastAPI +

python @@ -11,6 +15,9 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g alembic tests docs + celery + redis + ollama

@@ -19,7 +26,9 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g - JWT authentication based on fastapi-users - User registration and user management endpoints -- LLM generation endpoint with multiple response formats (text, sections, json) +- Async feature specification generation with Celery tasks +- Redis as Celery broker/result backend +- Ollama integration for LLM responses - Readiness and health probes for runtime checks - Alembic database migrations - Security middleware baseline: @@ -38,10 +47,20 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g - SQLAlchemy 2.0 - Alembic - fastapi-users -- PostgreSQL (via DATABASE_URL) -- Ollama (LLM provider) +- Celery +- Redis +- PostgreSQL +- Ollama - Pytest +## Architecture (Async Flow) + +1. Client calls `POST /api/v1/feature-spec/generate`. +2. API stores a run row and enqueues Celery task to Redis. +3. API returns immediately with `task_id` and `processing` status. +4. Celery worker calls Ollama and persists success/error in DB. +5. Client polls `GET /api/v1/feature-spec/tasks/{task_id}` for result. 
+ ## Project Structure - app/main.py: FastAPI app bootstrap and router registration @@ -49,10 +68,11 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g - app/api/: health, readiness, OpenAPI customization - app/middlewares/: security middleware composition and implementations - app/modules/auth/: auth domain (models, schemas, dependencies, router) -- app/modules/llm/: LLM API, schemas, providers -- app/scripts/: utility scripts (admin bootstrap) +- app/modules/feature_spec/: API layer + application services for feature spec +- app/infrastructure/: Celery app, task workers, Ollama client +- app/scripts/: utility scripts (admin and prompt/model bootstrap) - alembic/: migration config and versions -- docker-compose.yml: containerized app run +- docker-compose.yml: app + celery-worker + redis + ollama ## Setup Guide @@ -60,8 +80,7 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g - Python 3.10+ - PostgreSQL database -- Optional: Docker + Docker Compose -- Optional: local Ollama instance for LLM generation +- Docker + Docker Compose (recommended) ### 1) Configure environment @@ -71,6 +90,8 @@ Required minimum: - DATABASE_URL - SECRET_KEY +- CELERY_BROKER_URL +- CELERY_RESULT_BACKEND Recommended auth bootstrap values: @@ -82,6 +103,20 @@ LLM values: - OLLAMA_BASE_URL - OLLAMA_MODEL +- OLLAMA_TIMEOUT + +Compose ports (host -> container): + +- FASTAPI_HOST_PORT=8005 +- FASTAPI_CONTAINER_PORT=8001 +- REDIS_HOST_PORT=6380 +- REDIS_CONTAINER_PORT=6379 + +For Docker Compose in this project use: + +- OLLAMA_BASE_URL=http://ollama:11434 +- CELERY_BROKER_URL=redis://redis:6379/0 +- CELERY_RESULT_BACKEND=redis://redis:6379/1 ### 2A) Run locally @@ -94,32 +129,58 @@ source venv/bin/activate pip install --upgrade pip pip install -r requirements.txt +pip install -r requirements-dev.txt python -m alembic upgrade head python -m app.scripts.bootstrap_admin -python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 
+python -m app.scripts.bootstrap_prompt_template +# terminal 1: api +python -m uvicorn app.main:app --host 0.0.0.0 --port 8005 + +# terminal 2: worker +celery -A app.infrastructure.celery_app:celery_app worker --loglevel=INFO ``` ### 2B) Run with Docker ```bash -docker compose up --build +docker compose up -d --build ``` Notes: - Container entrypoint automatically runs: - - alembic upgrade head - - admin bootstrap script + - migration + DB head check (`python -m app.scripts.migrate_and_check`) + - admin bootstrap script (`python -m app.scripts.bootstrap_admin`) + - prompt template bootstrap script (`python -m app.scripts.bootstrap_prompt_template`) + - Ollama model bootstrap (`python -m app.scripts.ensure_ollama_model`) - uvicorn app startup +- Celery worker runs in a dedicated container (`celery-worker`). +- Redis runs only in Docker Compose and is used internally by service name `redis`. +- On first deploy, startup may take longer while the configured `OLLAMA_MODEL` is downloaded. +- FastAPI container reaches Ollama via internal Docker network URL: http://ollama:11434 + +Verify services: + +```bash +docker compose ps +docker compose logs -f app +docker compose logs -f celery-worker +``` ## API Docs -When server is running: +When server is running locally (custom port): + +- Swagger UI: http://localhost:8005/docs +- ReDoc: http://localhost:8005/redoc (pinned ReDoc 2.x script) +- OpenAPI JSON: http://localhost:8005/openapi.json -- Swagger UI: http://localhost:8000/docs -- ReDoc: http://localhost:8000/redoc -- OpenAPI JSON: http://localhost:8000/openapi.json +For Docker Compose deployment (default host mapping): + +- Swagger UI: http://localhost:8005/docs +- ReDoc: http://localhost:8005/redoc +- OpenAPI JSON: http://localhost:8005/openapi.json ## API Endpoints @@ -133,6 +194,7 @@ Base API prefix: /api/v1 ### Auth (fastapi-users) - POST /api/v1/auth/jwt/login +- POST /api/v1/auth/jwt/refresh - POST /api/v1/auth/jwt/logout - POST /api/v1/auth/register - GET 
/api/v1/auth/users/me @@ -141,16 +203,65 @@ Base API prefix: /api/v1 - PATCH /api/v1/auth/users/{id} - DELETE /api/v1/auth/users/{id} -### LLM +Login note: + +- Endpoint `/api/v1/auth/jwt/login` uses `application/x-www-form-urlencoded`. +- In form field `username`, pass only `AUTH_USERNAME`. +- Login response returns `access_token` in body and sets HttpOnly refresh cookie. +- Use `/api/v1/auth/jwt/refresh` to get a new access token and refresh cookie. +- `/api/v1/auth/jwt/logout` clears refresh cookie on client side. +- Swagger `Authorize` value must contain only raw JWT token (without `Bearer ` prefix). + +### Feature Spec -- POST /api/v1/llm/generate +- POST /api/v1/feature-spec/generate +- GET /api/v1/feature-spec/tasks/{task_id} +- GET /api/v1/feature-spec/history?limit=10 -Request body example: +Generate request: ```json { - "prompt": "Generate API specification for user profile module", - "response_format": "sections" + "feature_idea": "payment for premium posts" +} +``` + +Generate response: + +```json +{ + "task_id": "3b44daff-1e83-4328-925f-62c22a9163d2", + "status": "processing" +} +``` + +Task status response examples: + +```json +{ + "task_id": "3b44daff-1e83-4328-925f-62c22a9163d2", + "status": "PENDING" +} +``` + +```json +{ + "task_id": "3b44daff-1e83-4328-925f-62c22a9163d2", + "status": "SUCCESS", + "result": { + "run_id": 10, + "status": "success", + "feature_idea": "payment for premium posts", + "feature_summary": { + "user_stories": [], + "acceptance_criteria": [], + "db_models_and_api_endpoints": { + "db_models": [], + "api_endpoints": [] + }, + "risk_assessment": [] + } + } } ``` @@ -159,7 +270,7 @@ Request body example: Run linter: ```bash -python -m flake8 +python -m flake8 . 
def _import_all_models() -> None:
    """Import ``<name>.models`` for every direct child of ``app.modules``.

    The imports are performed purely for their side effect of executing the
    model modules before Alembic reads ``Base.metadata``.
    NOTE(review): presumably each ``models`` module declares its tables on the
    shared ``Base`` -- confirm against the modules themselves.

    Children that simply have no ``models`` submodule are skipped.

    Raises:
        ModuleNotFoundError: if a ``models`` module exists but one of *its own*
            imports cannot be resolved. Swallowing that case would silently
            drop the module's tables from the migration metadata.
    """
    modules_pkg = import_module("app.modules")
    for module in pkgutil.iter_modules(modules_pkg.__path__, "app.modules."):
        model_module_name = f"{module.name}.models"
        try:
            import_module(model_module_name)
        except ModuleNotFoundError as exc:
            # Only "this child has no models module" is an expected miss;
            # any other missing name is a genuine broken import.
            if exc.name != model_module_name:
                raise
character_maximum_length - FROM information_schema.columns - WHERE table_schema = current_schema() - AND table_name = 'alembic_version' - AND column_name = 'version_num' - """ - ) - ).scalar_one_or_none() - - if result is None: - return - - if result < 255: - logger.warning( - "Expanding alembic_version.version_num from %s to 255 chars for revision-id safety", - result, - ) - connection.execute( - text( - "ALTER TABLE alembic_version ALTER COLUMN version_num TYPE VARCHAR(255)" - ) - ) - - def run_migrations_offline() -> None: url = config.get_main_option("sqlalchemy.url") context.configure( @@ -73,7 +56,6 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - _ensure_version_num_capacity(connection) context.configure( connection=connection, target_metadata=target_metadata, diff --git a/alembic/versions/2e5bd3a1eb88_initial.py b/alembic/versions/2e5bd3a1eb88_initial.py deleted file mode 100644 index 9b120c8..0000000 --- a/alembic/versions/2e5bd3a1eb88_initial.py +++ /dev/null @@ -1,45 +0,0 @@ -"""initial - -Revision ID: 2e5bd3a1eb88 -Revises: -Create Date: 2026-04-11 10:30:05.432174 -""" - -import sqlalchemy as sa - -from alembic import op - -revision = "2e5bd3a1eb88" -down_revision = None -branch_labels = None -depends_on = None - - -def upgrade() -> None: - op.create_table( - "users", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("username", sa.String(length=100), nullable=False), - sa.Column("hashed_password", sa.String(length=1024), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.Column("is_superuser", sa.Boolean(), nullable=False), - sa.Column("is_verified", sa.Boolean(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.text("now()"), - nullable=False, - ), - sa.Column("email", sa.String(length=320), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True) - 
class AdminAuthBackend(AuthenticationBackend):
    """Session-based authentication backend for the sqladmin UI.

    Only users that are both active and superusers are let in. A successful
    login stores the user's id in the session under ``admin_user_id``; every
    later request re-checks that user against the database.
    """

    def __init__(self) -> None:
        super().__init__(secret_key=settings.SECRET_KEY)

    async def login(self, request: Request) -> bool:
        """Validate the submitted login form and open an admin session.

        Returns:
            ``True`` when the credentials belong to an active superuser,
            ``False`` otherwise (including empty username or password).
        """
        form = await request.form()
        username = str(form.get("username", "")).strip()
        # The password is passed through verbatim: stripping it would make
        # this path disagree with the API login, which verifies the raw
        # value, and would lock out secrets with surrounding whitespace.
        password = str(form.get("password", ""))
        if not username or not password:
            return False

        credentials = OAuth2PasswordRequestForm(
            username=username,
            password=password,
            scope="",
            client_id=None,
            client_secret=None,
        )

        async with AsyncSessionLocal() as session:
            user_db = UsernameAwareUserDatabase(session, User)
            user_manager = UserManager(user_db)
            user = await user_manager.authenticate(credentials)

        if user is None or not user.is_active or not user.is_superuser:
            return False

        request.session.update({"admin_user_id": user.id})
        return True

    async def logout(self, request: Request) -> bool:
        """Drop everything stored in the session; always returns ``True``."""
        request.session.clear()
        return True

    async def authenticate(self, request: Request) -> bool:
        """Check that the session still maps to an active superuser.

        The session is cleared whenever the stored id is unusable or the user
        no longer qualifies, so a stale cookie cannot keep being retried.
        """
        user_id = request.session.get("admin_user_id")
        if user_id is None:
            return False

        try:
            admin_id = int(user_id)
        except (TypeError, ValueError):
            # A malformed id is treated as "not logged in" rather than a 500.
            request.session.clear()
            return False

        async with AsyncSessionLocal() as session:
            statement = select(User).where(User.id == admin_id)
            user = await session.scalar(statement)

        if user is None or not user.is_active or not user.is_superuser:
            request.session.clear()
            return False
        return True
class RefreshTokenAdmin(ModelView, model=RefreshToken):
    """Admin view over stored refresh tokens, with creation and editing disabled."""

    name = "Refresh Token"
    name_plural = "Refresh Tokens"
    icon = "fa-solid fa-key"

    column_list = [
        RefreshToken.id,
        RefreshToken.user_id,
        RefreshToken.expires_at,
        RefreshToken.created_at,
        RefreshToken.revoked_at,
    ]
    # The hash is already left out of the list and the forms; excluding it
    # from the details page as well keeps it out of the admin UI entirely.
    column_details_exclude_list = [RefreshToken.token_hash]
    column_searchable_list = [RefreshToken.user_id]
    column_filters = [
        RefreshToken.expires_at,
        RefreshToken.created_at,
        RefreshToken.revoked_at,
    ]
    column_sortable_list = [
        RefreshToken.id,
        RefreshToken.user_id,
        RefreshToken.expires_at,
    ]

    form_excluded_columns = [RefreshToken.token_hash, RefreshToken.created_at]
    can_create = False
    can_edit = False
def configure_redoc_route(app: FastAPI) -> None:
    """Register a ``/redoc`` page on *app* that loads a pinned ReDoc 2.1.5 bundle.

    The route is kept out of the generated OpenAPI schema.
    """
    pinned_bundle_url = (
        "https://cdn.jsdelivr.net/npm/redoc@2.1.5/bundles/redoc.standalone.js"
    )

    async def redoc_html():
        # Title and schema URL are read from the app at request time.
        page_title = f"{app.title} - ReDoc"
        return get_redoc_html(
            openapi_url=app.openapi_url,
            title=page_title,
            redoc_js_url=pinned_bundle_url,
        )

    # Explicit form of the ``@app.get(...)`` decorator.
    app.get("/redoc", include_in_schema=False)(redoc_html)
settings.AUTH_USERNAME: - raise RuntimeError( - "AUTH_EMAIL and AUTH_USERNAME must be set when bootstrap is enabled" - ) - - existing_user = ( - db.query(User).filter(User.email == settings.AUTH_EMAIL).one_or_none() - ) - if existing_user is not None: - try: - if not existing_user.hashed_password: - existing_user.hashed_password = _resolve_bootstrap_hashed_password() - db.add(existing_user) - db.commit() - except Exception: - logger.exception( - "Failed to ensure bootstrap state for existing user: email=%s", - settings.AUTH_EMAIL, - ) - db.rollback() - raise - return - - try: - user = User( - username=settings.AUTH_USERNAME, - email=settings.AUTH_EMAIL, - hashed_password=_resolve_bootstrap_hashed_password(), - is_active=True, - is_superuser=settings.AUTH_BOOTSTRAP_SUPERUSER, - is_verified=True, - ) - db.add(user) - db.commit() - except Exception: - logger.exception( - "Failed to create bootstrap user: email=%s", - settings.AUTH_EMAIL, - ) - db.rollback() - raise - - -def ensure_default_user(db: Session) -> None: - bootstrap_auth(db) diff --git a/app/core/settings.py b/app/core/settings.py index 613ae2e..ab3fe72 100644 --- a/app/core/settings.py +++ b/app/core/settings.py @@ -27,13 +27,21 @@ class Settings(BaseSettings): AUTH_EMAIL: str | None = None AUTH_PASSWORD: str | None = None AUTH_PASSWORD_HASH: str | None = None + AUTH_USERNAME_MIN_LENGTH: int = 3 + AUTH_USERNAME_MAX_LENGTH: int = 100 + AUTH_HASHED_PASSWORD_MAX_LENGTH: int = 1024 + AUTH_REFRESH_TOKEN_HASH_LENGTH: int = 128 + AUTH_DAY_IN_SECONDS: int = 24 * 60 * 60 + AUTH_REFRESH_TOKEN_EXPIRE_DAYS: int = 14 + AUTH_REFRESH_TOKEN_EXPIRE_SECONDS: int = 14 * 24 * 60 * 60 + AUTH_REFRESH_COOKIE_NAME: str = "refresh_token" + AUTH_REFRESH_COOKIE_SECURE: bool = True + AUTH_REFRESH_COOKIE_SAMESITE: str = "lax" ALLOWED_ORIGINS: list[str] = ["*"] SECURITY_TRUSTED_HOSTS: list[str] = ["*"] SECURITY_ENABLE_HTTPS_REDIRECT: bool = False - SECURITY_CSP: str = ( - "default-src 'self'; frame-ancestors 'none'; base-uri 'self'; 
form-action 'self'" - ) + SECURITY_CSP: str = "" SECURITY_REFERRER_POLICY: str = "strict-origin-when-cross-origin" SECURITY_CORS_ALLOW_METHODS: list[str] = [ "GET", @@ -60,13 +68,23 @@ class Settings(BaseSettings): OLLAMA_BASE_URL: str = "http://localhost:11434" OLLAMA_MODEL: str = "mistral" - OLLAMA_TIMEOUT: int = 120 + OLLAMA_TIMEOUT: int = 180 + OLLAMA_CONNECT_TIMEOUT: int = 10 + OLLAMA_MAX_RETRIES: int = 2 + OLLAMA_RETRY_BACKOFF_SECONDS: float = 1.0 OLLAMA_SYSTEM_PROMPT: str = ( "You are a helpful assistant that generates software specifications." ) LLM_PREFIX: str = "/llm" LLM_GENERATE_PATH: str = "/generate" LLM_TAG: str = "llm" + LLM_PROMPT_MAX_LENGTH: int = 8000 + FEATURE_SPEC_HISTORY_DEFAULT_LIMIT: int = 10 + + CELERY_BROKER_URL: str = "redis://redis:6379/0" + CELERY_RESULT_BACKEND: str = "redis://redis:6379/1" + CELERY_TASK_MAX_RETRIES: int = 3 + CELERY_TASK_RETRY_BASE_SECONDS: int = 2 model_config = SettingsConfigDict( env_file=".env", diff --git a/app/infrastructure/__init__.py b/app/infrastructure/__init__.py new file mode 100644 index 0000000..2182630 --- /dev/null +++ b/app/infrastructure/__init__.py @@ -0,0 +1 @@ +"""Infrastructure layer modules.""" diff --git a/app/infrastructure/celery_app.py b/app/infrastructure/celery_app.py new file mode 100644 index 0000000..bc5b403 --- /dev/null +++ b/app/infrastructure/celery_app.py @@ -0,0 +1,21 @@ +from celery import Celery + +from app.core.settings import settings + + +celery_app = Celery( + "specification_generator", + broker=settings.CELERY_BROKER_URL, + backend=settings.CELERY_RESULT_BACKEND, + include=["app.infrastructure.tasks.feature_spec_tasks"], +) + +celery_app.conf.update( + broker_connection_retry_on_startup=True, + task_track_started=True, + task_serializer="json", + result_serializer="json", + accept_content=["json"], + timezone="UTC", + enable_utc=True, +) diff --git a/app/infrastructure/ollama_client.py b/app/infrastructure/ollama_client.py new file mode 100644 index 0000000..35b0046 --- 
class OllamaSyncClient:
    """Blocking client for Ollama's ``/api/chat`` endpoint, configured from settings."""

    def __init__(self) -> None:
        self._base_url = settings.OLLAMA_BASE_URL.rstrip("/")
        self._model = settings.OLLAMA_MODEL
        # One overall timeout, with a separate (shorter) connect timeout.
        self._timeout = httpx.Timeout(
            float(settings.OLLAMA_TIMEOUT),
            connect=float(settings.OLLAMA_CONNECT_TIMEOUT),
        )
        self._system_prompt = settings.OLLAMA_SYSTEM_PROMPT

    def _build_payload(self, user_prompt: str) -> dict:
        """Build the non-streaming, JSON-format chat request body.

        A system message is included only when the configured system prompt
        is not blank.
        """
        messages = []
        if self._system_prompt.strip():
            messages.append({"role": "system", "content": self._system_prompt})
        messages.append({"role": "user", "content": user_prompt})
        return {
            "model": self._model,
            "stream": False,
            "format": "json",
            "messages": messages,
        }

    @staticmethod
    def _extract_content(data: object) -> str:
        """Return ``data["message"]["content"]`` as text, or ``""`` when absent.

        Tolerates a non-dict body and a null/non-dict ``message`` instead of
        raising ``AttributeError``, and maps a null ``content`` to ``""``
        rather than the literal string ``"None"``. Any other non-string
        content is converted with ``str()``.
        """
        if not isinstance(data, dict):
            return ""
        message = data.get("message")
        if not isinstance(message, dict):
            return ""
        content = message.get("content")
        if content is None:
            return ""
        return content if isinstance(content, str) else str(content)

    def generate(self, user_prompt: str) -> str:
        """Send *user_prompt* to the model and return the reply text.

        Raises:
            httpx.HTTPStatusError: on a non-success HTTP status.
            httpx.RequestError: on transport failures, including timeouts.
        """
        url = f"{self._base_url}/api/chat"
        payload = self._build_payload(user_prompt)
        # The client-level timeout already applies to every request it sends.
        with httpx.Client(timeout=self._timeout) as client:
            response = client.post(url, json=payload)
            response.raise_for_status()
            data = response.json()
        return self._extract_content(data)
def _ensure_auth_models_loaded() -> None:
    """Import the auth models module and touch ``User`` so it is loaded.

    NOTE(review): presumably required so ORM relationships from
    ``FeatureSpecRun`` to ``User`` resolve inside the worker process -- confirm.
    """
    from app.modules.auth import models as auth_models

    _ = auth_models.User


def _serialize_run(run) -> dict:
    """Return the task result payload for a successfully completed run."""
    return {
        "run_id": run.id,
        "status": run.status,
        "feature_idea": run.feature_idea,
        "feature_summary": run.response_json,
    }


def _mark_run_failed(db, run, message: str) -> None:
    """Discard pending changes and persist *message* as the run's error state."""
    db.rollback()
    run.status = "error"
    run.error_message = message
    db.add(run)
    db.commit()


@celery_app.task(bind=True, name="feature_spec.generate")
def generate_feature_spec_task(self, run_id: int, feature_idea: str, user_id: int) -> dict:
    """Generate a feature specification for an existing run and persist the outcome.

    Args:
        run_id: Primary key of the ``FeatureSpecRun`` row to fill in.
        feature_idea: Free-text idea used to build the LLM prompt.
        user_id: Accepted for interface compatibility; not used in the body.

    Returns:
        A dict with ``run_id``/``status``/``feature_idea``/``feature_summary``
        on success (also when the run was already successful), or a dict with
        ``status == "error"`` when the run row does not exist.

    Raises:
        celery.exceptions.Retry: when an httpx failure is rescheduled.
        RuntimeError: when the LLM call fails after all retries, or when any
            other processing step fails. The run is marked ``error`` first.
    """
    _ensure_auth_models_loaded()
    db = SessionLocal()
    try:
        run = db.get(FeatureSpecRun, run_id)
        if run is None:
            return {"run_id": run_id, "status": "error", "message": "Run not found"}

        # A redelivered task must not regenerate an already stored result.
        if run.status == "success" and run.response_json is not None:
            return _serialize_run(run)

        try:
            prompt = build_feature_summary_prompt_from_db(feature_idea, db)
            feature_summary_raw = ollama_sync_client.generate(prompt)
            feature_summary_json = parse_feature_summary_response(feature_summary_raw)
            feature_summary_typed = FeatureSummaryResult.model_validate(feature_summary_json)

            run.status = "success"
            run.response_json = feature_summary_typed.model_dump(mode="json")
            run.error_message = None
            db.add(run)
            db.commit()

            return _serialize_run(run)
        except (
            httpx.TimeoutException,
            httpx.RequestError,
            httpx.HTTPStatusError,
        ) as exc:
            retry_number = int(self.request.retries) + 1
            if retry_number <= settings.CELERY_TASK_MAX_RETRIES:
                # Exponential backoff: base ** attempt seconds.
                delay_seconds = settings.CELERY_TASK_RETRY_BASE_SECONDS**retry_number
                try:
                    # Pass the configured limit explicitly. Otherwise Celery
                    # applies its own default and, because ``exc`` is given,
                    # re-raises the httpx error once that default is exceeded,
                    # skipping the error bookkeeping below.
                    raise self.retry(
                        exc=exc,
                        countdown=delay_seconds,
                        max_retries=settings.CELERY_TASK_MAX_RETRIES,
                    )
                except MaxRetriesExceededError:
                    # Defensive: fall through and record the failure.
                    pass

            _mark_run_failed(db, run, "LLM provider request failed")
            raise RuntimeError("LLM task failed after retries") from exc
        except Exception as exc:
            # Sibling handler: it does not intercept the Retry raised above.
            _mark_run_failed(db, run, "Failed to process feature specification")
            raise RuntimeError("Feature specification task failed") from exc
    finally:
        db.close()
send_with_security_headers(message: Message) -> None: headers.setdefault("X-Frame-Options", "DENY") headers.setdefault("Referrer-Policy", settings.SECURITY_REFERRER_POLICY) - csp = settings.SECURITY_CSP - if "{nonce}" in csp: - nonce = uuid.uuid4().hex - scope["csp_nonce"] = nonce - csp = csp.replace("{nonce}", nonce) - headers.setdefault("Content-Security-Policy", csp) - if settings.ENV.lower() == "production": headers.setdefault( "Strict-Transport-Security", diff --git a/app/middlewares/security/setup.py b/app/middlewares/security/setup.py index 8ff2ace..2916aa8 100644 --- a/app/middlewares/security/setup.py +++ b/app/middlewares/security/setup.py @@ -11,6 +11,7 @@ def configure_security_middlewares(app: FastAPI) -> None: + if settings.SECURITY_ENABLE_HTTPS_REDIRECT: app.add_middleware(HTTPSRedirectMiddleware) diff --git a/app/modules/auth/infrastructure/fastapi_users_adapter.py b/app/modules/auth/infrastructure/fastapi_users_adapter.py index d862c79..17102f6 100644 --- a/app/modules/auth/infrastructure/fastapi_users_adapter.py +++ b/app/modules/auth/infrastructure/fastapi_users_adapter.py @@ -1,10 +1,13 @@ from collections.abc import AsyncGenerator +from typing import Any from fastapi import Depends +from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import BaseUserManager, FastAPIUsers, IntegerIDMixin from fastapi_users.authentication import (AuthenticationBackend, BearerTransport, JWTStrategy) from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import get_async_db @@ -12,15 +15,45 @@ from app.modules.auth.models import User +class UsernameAwareUserDatabase(SQLAlchemyUserDatabase): + async def get_by_username(self, username: str) -> User | None: + statement = select(User).where(User.username == username) + return await self._get_user(statement) + + class UserManager(IntegerIDMixin, BaseUserManager[User, int]): 
def get_refresh_jwt_strategy() -> JWTStrategy:
    """Build the JWT strategy that signs and validates refresh tokens.

    Refresh tokens are signed with ``settings.SECRET_KEY`` and live for
    ``settings.AUTH_REFRESH_TOKEN_EXPIRE_SECONDS``. They carry their own
    audience so that a refresh token is rejected where an access token is
    expected, and vice versa.

    NOTE(review): this assumes the access-token strategy keeps the library
    default audience; confirm against ``get_jwt_strategy``. Refresh cookies
    issued before this change will no longer validate.
    """
    return JWTStrategy(
        secret=settings.SECRET_KEY,
        lifetime_seconds=settings.AUTH_REFRESH_TOKEN_EXPIRE_SECONDS,
        token_audience=["fastapi-users:refresh"],
    )
def _clear_refresh_cookie(response: Response) -> None:
    """Expire the refresh-token cookie on *response*.

    The deletion cookie repeats the path, ``Secure``, ``HttpOnly`` and
    ``SameSite`` attributes used by ``_set_refresh_cookie``. A browser only
    replaces a cookie when name and path match, and it rejects a
    ``SameSite=None`` cookie that is not also ``Secure``.
    """
    response.delete_cookie(
        key=settings.AUTH_REFRESH_COOKIE_NAME,
        path=_build_refresh_cookie_path(),
        secure=settings.AUTH_REFRESH_COOKIE_SECURE,
        httponly=True,
        samesite=settings.AUTH_REFRESH_COOKIE_SAMESITE,
    )
auth_session_service.refresh(refresh_token) + if tokens is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Unauthorized", + ) + + response = JSONResponse( + {"access_token": tokens.access_token, "token_type": "bearer"} + ) + _set_refresh_cookie(response, tokens.refresh_token) + return response + + +@router.post("/logout", name="auth:jwt.logout", status_code=status.HTTP_204_NO_CONTENT) +async def logout( + request: Request, + auth_session_service: AuthSessionService = Depends(get_auth_session_service), +) -> Response: + await auth_session_service.logout(request.cookies.get(settings.AUTH_REFRESH_COOKIE_NAME)) + + response = Response(status_code=status.HTTP_204_NO_CONTENT) + _clear_refresh_cookie(response) + return response diff --git a/app/modules/auth/models.py b/app/modules/auth/models.py index 03c0222..52c3197 100644 --- a/app/modules/auth/models.py +++ b/app/modules/auth/models.py @@ -1,10 +1,11 @@ from datetime import datetime from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTable -from sqlalchemy import Boolean, DateTime, Integer, String, func +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, func from sqlalchemy.orm import Mapped, mapped_column from app.core.database import Base +from app.core.settings import settings class User(SQLAlchemyBaseUserTable[int], Base): @@ -12,14 +13,48 @@ class User(SQLAlchemyBaseUserTable[int], Base): id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) username: Mapped[str] = mapped_column( - String(100), unique=True, index=True, nullable=False + String(settings.AUTH_USERNAME_MAX_LENGTH), + unique=True, + index=True, + nullable=False, + ) + hashed_password: Mapped[str] = mapped_column( + String(settings.AUTH_HASHED_PASSWORD_MAX_LENGTH), + nullable=False, ) - hashed_password: Mapped[str] = mapped_column(String(1024), nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) is_superuser: Mapped[bool] = 
mapped_column(Boolean, default=False, nullable=False) - is_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + is_verified: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + ) + + +class RefreshToken(Base): + __tablename__ = "refresh_tokens" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) + user_id: Mapped[int] = mapped_column( + Integer, + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + token_hash: Mapped[str] = mapped_column( + String(settings.AUTH_REFRESH_TOKEN_HASH_LENGTH), + index=True, + unique=True, + nullable=False, + ) + expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False, ) + revoked_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), + nullable=True, + ) diff --git a/app/modules/auth/router.py b/app/modules/auth/router.py index 29fe2fa..e52765a 100644 --- a/app/modules/auth/router.py +++ b/app/modules/auth/router.py @@ -2,12 +2,14 @@ from app.core.settings import settings from app.modules.auth.infrastructure.fastapi_users_adapter import ( - auth_backend, fastapi_users) + fastapi_users, +) +from app.modules.auth.jwt_router import router as jwt_router from app.modules.auth.schemas import UserCreate, UserRead, UserUpdate router = APIRouter(prefix=settings.AUTH_PREFIX, tags=[settings.AUTH_TAG]) -router.include_router(fastapi_users.get_auth_router(auth_backend), prefix="/jwt") +router.include_router(jwt_router) router.include_router( fastapi_users.get_register_router(UserRead, UserCreate), prefix="" ) diff --git a/app/modules/auth/schemas.py b/app/modules/auth/schemas.py index 2273fb5..96a4f36 100644 --- a/app/modules/auth/schemas.py +++ b/app/modules/auth/schemas.py 
@@ -1,5 +1,8 @@ + from fastapi_users import schemas -from pydantic import Field +from pydantic import ConfigDict, Field, model_validator + +from app.core.settings import settings class UserRead(schemas.BaseUser[int]): @@ -7,8 +10,35 @@ class UserRead(schemas.BaseUser[int]): class UserCreate(schemas.BaseUserCreate): - username: str = Field(min_length=3, max_length=100) + username: str = Field( + min_length=settings.AUTH_USERNAME_MIN_LENGTH, + max_length=settings.AUTH_USERNAME_MAX_LENGTH, + ) + is_superuser: bool = Field(default=False, exclude=True) + + model_config = ConfigDict(extra="forbid") + + @model_validator(mode="after") + def validate_superuser_registration(self): + if self.is_superuser: + raise ValueError("Setting is_superuser via public registration is forbidden") + return self + + @classmethod + def model_json_schema(cls, *args, **kwargs): + schema = super().model_json_schema(*args, **kwargs) + properties = schema.get("properties", {}) + properties.pop("is_superuser", None) + + required = schema.get("required") + if isinstance(required, list) and "is_superuser" in required: + required.remove("is_superuser") + return schema class UserUpdate(schemas.BaseUserUpdate): - username: str | None = Field(default=None, min_length=3, max_length=100) + username: str | None = Field( + default=None, + min_length=settings.AUTH_USERNAME_MIN_LENGTH, + max_length=settings.AUTH_USERNAME_MAX_LENGTH, + ) diff --git a/app/modules/auth/service.py b/app/modules/auth/service.py index 184038a..0590802 100644 --- a/app/modules/auth/service.py +++ b/app/modules/auth/service.py @@ -1,8 +1,35 @@ +from dataclasses import dataclass +from typing import Protocol + +from fastapi.security import OAuth2PasswordRequestForm + from app.modules.auth.errors import PermissionDeniedError from app.modules.auth.models import User from app.modules.auth.repository.user_repository import UserRepositoryPort +class UserAuthenticationPort(Protocol): + async def authenticate(self, credentials: 
OAuth2PasswordRequestForm) -> User | None: ... + + async def get(self, id): ... + + +class AccessTokenStrategyPort(Protocol): + async def write_token(self, user: User) -> str: ... + + +class RefreshTokenStrategyPort(Protocol): + async def write_token(self, user: User) -> str: ... + + async def read_token(self, token: str, user_manager: UserAuthenticationPort) -> User | None: ... + + +@dataclass(frozen=True) +class AuthTokens: + access_token: str + refresh_token: str + + class AuthorizationService: def ensure_superuser(self, user: User) -> User: if not user.is_superuser: @@ -40,3 +67,38 @@ def __init__(self, repository: UserRepositoryPort) -> None: async def user_exists_by_email(self, email: str) -> bool: return await self._repository.get_by_email(email) is not None + + +class AuthSessionService: + def __init__( + self, + user_authentication: UserAuthenticationPort, + access_token_strategy: AccessTokenStrategyPort, + refresh_token_strategy: RefreshTokenStrategyPort, + ) -> None: + self._user_authentication = user_authentication + self._access_token_strategy = access_token_strategy + self._refresh_token_strategy = refresh_token_strategy + + async def login(self, credentials: OAuth2PasswordRequestForm) -> AuthTokens | None: + user = await self._user_authentication.authenticate(credentials) + if user is None or not user.is_active: + return None + + access_token = await self._access_token_strategy.write_token(user) + refresh_token = await self._refresh_token_strategy.write_token(user) + return AuthTokens(access_token=access_token, refresh_token=refresh_token) + + async def refresh(self, refresh_token: str) -> AuthTokens | None: + user = await self._refresh_token_strategy.read_token( + refresh_token, self._user_authentication + ) + if user is None or not user.is_active: + return None + + access_token = await self._access_token_strategy.write_token(user) + new_refresh_token = await self._refresh_token_strategy.write_token(user) + return 
AuthTokens(access_token=access_token, refresh_token=new_refresh_token) + + async def logout(self, refresh_token: str | None) -> None: + return None diff --git a/app/modules/feature_spec/models.py b/app/modules/feature_spec/models.py new file mode 100644 index 0000000..b8eeca5 --- /dev/null +++ b/app/modules/feature_spec/models.py @@ -0,0 +1,61 @@ +from datetime import datetime + +from sqlalchemy import ( + JSON, + Boolean, + CheckConstraint, + DateTime, + ForeignKey, + Integer, + String, + Text, + func, +) +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base + + +class PromptTemplate(Base): + __tablename__ = "prompt_templates" + __table_args__ = ( + CheckConstraint("id = 1", name="ck_prompt_templates_single_row"), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + feature_to_feature_summary: Mapped[str] = mapped_column(Text, nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + +class FeatureSpecRun(Base): + __tablename__ = "feature_spec_runs" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + user_id: Mapped[int] = mapped_column( + Integer, + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + feature_idea: Mapped[str] = mapped_column(Text, nullable=False) + status: Mapped[str] = mapped_column(String(32), nullable=False, index=True) + response_json: Mapped[dict | list | None] = mapped_column(JSON, nullable=True) + error_message: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + index=True, + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) diff 
class FeatureSpecOrchestrator:
    """Run the feature-idea to feature-summary flow and record it as a run."""

    def __init__(self, llm_client) -> None:
        # Any object exposing ``async generate(prompt) -> str``.
        self.llm = llm_client

    async def generate(
        self,
        idea: str,
        db: Session,
        user_id: int,
    ) -> FeatureSpecGenerateResponse:
        """Generate a feature summary for *idea* and persist the outcome.

        A ``FeatureSpecRun`` row is created as ``"pending"`` and then updated
        to ``"success"`` with the validated summary, or to ``"error"`` with
        the exception text.

        Raises:
            Exception: Whatever prompt building, the LLM call, parsing or
                validation raised, re-raised after the run is marked failed.
        """
        run = FeatureSpecRun(
            user_id=user_id,
            feature_idea=idea,
            status="pending",
        )
        db.add(run)
        db.commit()
        db.refresh(run)

        try:
            # Built inside the try so a missing template also marks the run
            # as failed instead of leaving it "pending".
            feature_summary_prompt = build_feature_summary_prompt_from_db(idea, db)
            feature_summary_raw = await self.llm.generate(feature_summary_prompt)
            feature_summary_json = parse_feature_summary_response(feature_summary_raw)
            feature_summary_typed = FeatureSummaryResult.model_validate(
                feature_summary_json
            )

            run.status = "success"
            run.response_json = feature_summary_typed.model_dump(mode="json")
            run.error_message = None
            db.add(run)
            db.commit()
        except Exception as exc:
            # Discard any half-applied state before recording the failure;
            # needed when the success commit itself is what failed.
            db.rollback()
            run.status = "error"
            run.error_message = str(exc)
            db.add(run)
            db.commit()
            raise

        return FeatureSpecGenerateResponse(
            feature_idea=idea,
            feature_summary=feature_summary_typed,
        )
db: Session, + user_id: int, +) -> FeatureSpecGenerateResponse: + return await orchestrator.generate(feature_idea, db, user_id) + + +def get_feature_spec_history( + db: Session, + user_id: int, + limit: int = settings.FEATURE_SPEC_HISTORY_DEFAULT_LIMIT, +) -> FeatureSpecHistoryResponse: + safe_limit = max(1, min(limit, 100)) + statement = ( + select(FeatureSpecRun) + .where(FeatureSpecRun.user_id == user_id) + .order_by(FeatureSpecRun.created_at.desc()) + .limit(safe_limit) + ) + rows = db.execute(statement).scalars().all() + + items = [ + FeatureSpecHistoryItem( + id=row.id, + feature_idea=row.feature_idea, + status=row.status, + response_json=( + FeatureSummaryResult.model_validate( + normalize_feature_summary_payload(row.response_json) + ) + if row.response_json is not None + else None + ), + error_message=row.error_message, + created_at=row.created_at, + ) + for row in rows + ] + return FeatureSpecHistoryResponse(items=items) diff --git a/app/modules/feature_spec/parser.py b/app/modules/feature_spec/parser.py new file mode 100644 index 0000000..5b71088 --- /dev/null +++ b/app/modules/feature_spec/parser.py @@ -0,0 +1,281 @@ +import json +import logging +import re +from typing import Any + + +logger = logging.getLogger(__name__) + + +def _is_missing(value: Any) -> bool: + if value is None: + return True + if isinstance(value, str): + return not value.strip() + if isinstance(value, (list, dict)): + return not value + return False + + +def _normalize_acceptance_criteria(value: Any, *, strict: bool = False) -> list[str] | None: + if value is None: + return None + + items = value if isinstance(value, list) else [value] + normalized_items: list[str] = [] + + for item in items: + if isinstance(item, str): + text = item.strip() + elif isinstance(item, dict): + text = next( + ( + candidate.strip() + for key in ( + "title", + "description", + "details", + "criterion", + "text", + "message", + ) + if isinstance((candidate := item.get(key)), str) + and candidate.strip() + 
), + str(item).strip(), + ) + else: + text = str(item).strip() + + if text: + normalized_items.append(text) + elif strict: + raise ValueError("Invalid acceptance_criteria item") + + return normalized_items or None + + +def _normalize_risk_assessment(value: Any, *, strict: bool = False) -> list[str] | None: + if value is None: + return None + + items = value if isinstance(value, list) else [value] + normalized_items: list[str] = [] + + for item in items: + if isinstance(item, str): + text = item.strip() + elif isinstance(item, dict): + text = next( + ( + candidate.strip() + for key in ("title", "risk", "description", "details", "message") + if isinstance((candidate := item.get(key)), str) + and candidate.strip() + ), + str(item).strip(), + ) + else: + text = str(item).strip() + + if text: + normalized_items.append(text) + elif strict: + raise ValueError("Invalid risk_assessment item") + + return normalized_items or None + + +def _normalize_mixed_items(value: Any, *, strict: bool = False) -> list[str | dict] | None: + if value is None: + return None + + items = value if isinstance(value, list) else [value] + normalized_items: list[str | dict] = [] + + for item in items: + if isinstance(item, dict): + normalized_items.append(item) + continue + + text = str(item).strip() + if text: + normalized_items.append(text) + elif strict: + raise ValueError("Invalid db/api item") + + return normalized_items or None + + +def _first_non_empty(source: dict[str, Any], *keys: str) -> str | None: + for key in keys: + value = source.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + +def _normalize_user_stories(value: Any, *, strict: bool = False) -> list[dict[str, str]] | None: + if value is None: + return None + + items = value if isinstance(value, list) else [value] + normalized_items: list[dict] = [] + + for index, item in enumerate(items, start=1): + if isinstance(item, dict): + title = _first_non_empty(item, "title", "name") + as_a = 
_first_non_empty(item, "as_a", "as", "actor", "user") + i_want = _first_non_empty(item, "i_want", "want", "goal", "objective") + so_that = _first_non_empty(item, "so_that", "benefit", "because", "outcome") + if not i_want: + i_want = _first_non_empty(item, "description", "details", "text") + if not so_that: + so_that = i_want + + if not all([title, as_a, i_want, so_that]): + if strict: + raise ValueError(f"Invalid user story at index {index}") + logger.warning("Skipping invalid user story item at index %s", index) + continue + + normalized_items.append( + { + "title": title, + "as_a": as_a, + "i_want": i_want, + "so_that": so_that, + } + ) + continue + + if isinstance(item, str) and item.strip(): + text = item.strip() + if strict: + raise ValueError( + f"Invalid user story string item at index {index}; object expected" + ) + normalized_items.append( + { + "title": text, + "as_a": "User", + "i_want": text, + "so_that": text, + } + ) + elif strict: + raise ValueError(f"Invalid user story item type at index {index}") + + return normalized_items or None + + +def normalize_feature_summary_payload( + data: dict[str, Any] | list, + *, + strict: bool = False, +) -> dict[str, Any] | list: + if not isinstance(data, dict): + return data + + normalized = dict(data) + + if isinstance(normalized.get("feature_summary"), dict): + nested = normalized["feature_summary"] + logger.info("Using nested feature_summary payload") + for key in ( + "user_stories", + "feature_summary_items", + "acceptance_criteria", + "acceptance", + "db_models_and_api_endpoints", + "db_models", + "api_endpoints", + "risk_assessment", + "risks", + ): + if key not in normalized and key in nested: + normalized[key] = nested[key] + + user_stories_source = normalized.get("user_stories") + if _is_missing(user_stories_source): + user_stories_source = normalized.get("feature_summary_items") + + normalized_user_stories = _normalize_user_stories(user_stories_source, strict=strict) + if normalized_user_stories is not 
None: + normalized["user_stories"] = normalized_user_stories + + acceptance_source = normalized.get("acceptance_criteria") + if _is_missing(acceptance_source): + logger.info("Using legacy acceptance field") + acceptance_source = normalized.get("acceptance") + + normalized_acceptance = _normalize_acceptance_criteria( + acceptance_source, + strict=strict, + ) + if normalized_acceptance is not None: + normalized["acceptance_criteria"] = normalized_acceptance + + db_api_source = normalized.get("db_models_and_api_endpoints") + db_source = normalized.get("db_models") + api_source = normalized.get("api_endpoints") + + if isinstance(db_api_source, dict): + db_source = db_api_source.get("db_models", db_source) + api_source = db_api_source.get("api_endpoints", api_source) + + normalized["db_models_and_api_endpoints"] = { + "db_models": _normalize_mixed_items(db_source, strict=strict) or [], + "api_endpoints": _normalize_mixed_items(api_source, strict=strict) or [], + } + + risk_source = normalized.get("risk_assessment") + if _is_missing(risk_source): + logger.info("Using legacy risks field") + risk_source = normalized.get("risks") + + normalized_risks = _normalize_risk_assessment(risk_source, strict=strict) + if normalized_risks is not None: + normalized["risk_assessment"] = normalized_risks + + return normalized + + +def extract_json(text: str) -> dict | list: + stripped = text.strip() + if stripped: + try: + parsed = json.loads(stripped) + if isinstance(parsed, (dict, list)): + return parsed + except json.JSONDecodeError: + pass + + fenced_match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL) + if fenced_match: + candidate = fenced_match.group(1).strip() + try: + parsed = json.loads(candidate) + if isinstance(parsed, (dict, list)): + return parsed + except json.JSONDecodeError: + pass + + for candidate in re.findall(r"\{[\s\S]*?\}|\[[\s\S]*?\]", text): + try: + parsed = json.loads(candidate) + except json.JSONDecodeError: + continue + if isinstance(parsed, 
(dict, list)): + return parsed + + raise ValueError("No JSON found in LLM response") + + +def strip_markdown(text: str) -> str: + text = re.sub(r"```[\w]*\n?", "", text) + return text.strip() + + +def normalize_whitespace(text: str) -> str: + return re.sub(r"\n{3,}", "\n\n", text).strip() diff --git a/app/modules/feature_spec/prompts/__init__.py b/app/modules/feature_spec/prompts/__init__.py new file mode 100644 index 0000000..891cfab --- /dev/null +++ b/app/modules/feature_spec/prompts/__init__.py @@ -0,0 +1,13 @@ +from app.modules.feature_spec.prompts.feature_summary import ( + build_feature_summary_prompt_from_db, + build_feature_summary_prompt_from_template, + load_feature_summary_template, + parse_feature_summary_response, +) + +__all__ = [ + "load_feature_summary_template", + "build_feature_summary_prompt_from_template", + "build_feature_summary_prompt_from_db", + "parse_feature_summary_response", +] diff --git a/app/modules/feature_spec/prompts/feature_summary.py b/app/modules/feature_spec/prompts/feature_summary.py new file mode 100644 index 0000000..9deb419 --- /dev/null +++ b/app/modules/feature_spec/prompts/feature_summary.py @@ -0,0 +1,46 @@ +import html + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.modules.feature_spec.models import PromptTemplate +from app.modules.feature_spec.parser import ( + extract_json, + normalize_feature_summary_payload, + normalize_whitespace, + strip_markdown, +) + + +def build_feature_summary_prompt_from_template(template: str, feature_idea: str) -> str: + safe_template = template.strip() + escaped_feature_idea = html.escape(feature_idea.strip(), quote=True) + rendered = safe_template.replace("{feature_idea}", escaped_feature_idea) + rendered = rendered.replace("{input}", escaped_feature_idea) + return rendered + + +def load_feature_summary_template(db: Session) -> str: + statement = ( + select(PromptTemplate.feature_to_feature_summary) + .where(PromptTemplate.is_active.is_(True)) + 
class OllamaClient:
    """Async client for the Ollama ``/api/chat`` endpoint with bounded retries.

    Base URL, model, timeouts, retry limits and the optional system prompt
    are read from ``settings`` once, at construction time.
    """

    def __init__(self) -> None:
        self._base_url = settings.OLLAMA_BASE_URL.rstrip("/")
        self._model = settings.OLLAMA_MODEL
        self._timeout = settings.OLLAMA_TIMEOUT
        self._connect_timeout = settings.OLLAMA_CONNECT_TIMEOUT
        self._max_retries = settings.OLLAMA_MAX_RETRIES
        self._retry_backoff_seconds = settings.OLLAMA_RETRY_BACKOFF_SECONDS
        self._system_prompt = settings.OLLAMA_SYSTEM_PROMPT

    def _build_payload(self, user_prompt: str, stream: bool = False) -> dict:
        """Build the chat request body; a blank system prompt is omitted."""
        messages = []
        if self._system_prompt.strip():
            messages.append({"role": "system", "content": self._system_prompt})
        messages.append({"role": "user", "content": user_prompt})
        return {
            "model": self._model,
            "stream": stream,
            "format": "json",
            "messages": messages,
        }

    def _http_timeout(self) -> httpx.Timeout:
        """Return the overall timeout with a separate connect timeout."""
        return httpx.Timeout(float(self._timeout), connect=float(self._connect_timeout))

    def _should_retry_status(self, status_code: int) -> bool:
        """Return True for 5xx responses; other statuses are not retried."""
        return 500 <= status_code <= 599

    async def _backoff(self, attempt: int) -> None:
        """Sleep for a delay that grows linearly with the 0-based attempt."""
        await asyncio.sleep(self._retry_backoff_seconds * (attempt + 1))

    async def generate(self, user_prompt: str) -> str:
        """Send *user_prompt* and return the assistant message content.

        Timeouts, transport errors, 5xx responses and malformed payloads are
        retried up to the configured limit.

        Raises:
            RuntimeError: When the request still fails after all retries, or
                fails with a non-retryable HTTP status.
        """
        url = f"{self._base_url}/api/chat"
        payload = self._build_payload(user_prompt, stream=False)

        async with httpx.AsyncClient(timeout=self._http_timeout()) as client:
            for attempt in range(self._max_retries + 1):
                try:
                    response = await client.post(url, json=payload)
                    response.raise_for_status()
                    data = response.json()
                    # Valid JSON of the wrong shape is handled like invalid
                    # JSON instead of crashing with AttributeError.
                    message = data.get("message", {}) if isinstance(data, dict) else None
                    if not isinstance(message, dict):
                        raise ValueError("Ollama response has no message object")
                    content = message.get("content", "")
                    return content if isinstance(content, str) else str(content)
                except (json.JSONDecodeError, ValueError) as exc:
                    if attempt < self._max_retries:
                        await self._backoff(attempt)
                        continue
                    raise RuntimeError("Ollama returned invalid JSON response") from exc
                except httpx.TimeoutException as exc:
                    if attempt < self._max_retries:
                        await self._backoff(attempt)
                        continue
                    raise RuntimeError(
                        f"Ollama request timed out after {self._timeout}s"
                    ) from exc
                except httpx.HTTPStatusError as exc:
                    status_code = exc.response.status_code
                    if (
                        self._should_retry_status(status_code)
                        and attempt < self._max_retries
                    ):
                        await self._backoff(attempt)
                        continue
                    raise RuntimeError(
                        f"Ollama returned HTTP {status_code}: {exc.response.text}"
                    ) from exc
                except httpx.RequestError as exc:
                    if attempt < self._max_retries:
                        await self._backoff(attempt)
                        continue
                    raise RuntimeError("Ollama request failed") from exc

        raise RuntimeError("Ollama request failed after retries")

    async def generate_stream(self, user_prompt: str) -> AsyncGenerator[str, None]:
        """Stream assistant tokens for *user_prompt* as they arrive.

        A failed attempt is retried only while nothing has been yielded, so
        a consumer never receives a restarted answer.

        Raises:
            RuntimeError: On a timeout, transport error or error status that
                is not retried, or when the stream ends without a ``done``
                chunk.
        """
        url = f"{self._base_url}/api/chat"
        payload = self._build_payload(user_prompt, stream=True)

        yielded_any = False
        async with httpx.AsyncClient(timeout=self._http_timeout()) as client:
            for attempt in range(self._max_retries + 1):
                try:
                    async with client.stream("POST", url, json=payload) as response:
                        if not 200 <= response.status_code < 300:
                            # Buffer the error body while the stream is open.
                            # Reading ``.text`` after it closes would raise
                            # httpx.ResponseNotRead in the handler below.
                            await response.aread()
                        response.raise_for_status()
                        async for line in response.aiter_lines():
                            if not line:
                                continue
                            try:
                                chunk: Any = json.loads(line)
                            except ValueError:
                                logger.warning("Skipping invalid Ollama stream chunk")
                                continue
                            if not isinstance(chunk, dict):
                                logger.warning("Skipping invalid Ollama stream chunk")
                                continue
                            message = chunk.get("message")
                            token = (
                                message.get("content", "")
                                if isinstance(message, dict)
                                else ""
                            )
                            if token:
                                yielded_any = True
                                yield token
                            if chunk.get("done"):
                                return
                        # Reached only when no chunk reported ``done``.
                        raise RuntimeError("Ollama stream ended before completion")
                except httpx.TimeoutException as exc:
                    if not yielded_any and attempt < self._max_retries:
                        await self._backoff(attempt)
                        continue
                    raise RuntimeError(
                        f"Ollama stream timed out after {self._timeout}s"
                    ) from exc
                except httpx.HTTPStatusError as exc:
                    status_code = exc.response.status_code
                    if (
                        not yielded_any
                        and self._should_retry_status(status_code)
                        and attempt < self._max_retries
                    ):
                        await self._backoff(attempt)
                        continue
                    raise RuntimeError(
                        f"Ollama stream returned HTTP {status_code}: {exc.response.text}"
                    ) from exc
                except httpx.RequestError as exc:
                    if not yielded_any and attempt < self._max_retries:
                        await self._backoff(attempt)
                        continue
                    raise RuntimeError("Ollama stream request failed") from exc

        raise RuntimeError("Ollama stream failed after retries")
router = APIRouter(prefix="/feature-spec", tags=["feature-spec"])


@router.post("/generate", response_model=FeatureSpecTaskSubmitResponse)
async def generate_feature_spec_endpoint(
    payload: FeatureSpecGenerateRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> FeatureSpecTaskSubmitResponse:
    """Queue generation of a feature spec for the authenticated user."""
    return submit_feature_spec_generation(payload.feature_idea, db, current_user.id)


@router.get("/tasks/{task_id}", response_model=FeatureSpecTaskStatusResponse)
async def get_feature_spec_task_status_endpoint(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> FeatureSpecTaskStatusResponse:
    """Return the state (and result, once finished) of a generation task.

    The ``current_user`` dependency exists only to require authentication,
    matching the other endpoints of this router: the response can carry the
    full generated spec, so it must not be readable anonymously.
    """
    # NOTE(review): this checks that the caller is logged in, not that the
    # task belongs to them — no task-to-user mapping is visible from here.
    return get_feature_spec_task_status(task_id)


@router.get("/history", response_model=FeatureSpecHistoryResponse)
async def get_feature_spec_history_endpoint(
    limit: int = settings.FEATURE_SPEC_HISTORY_DEFAULT_LIMIT,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> FeatureSpecHistoryResponse:
    """Return up to ``limit`` past generation runs of the authenticated user."""
    return get_feature_spec_history(db, current_user.id, limit)
# Comments are used instead of class docstrings on purpose: a docstring on a
# pydantic model is emitted as the schema "description".


# Request body for submitting a feature idea.
class FeatureSpecGenerateRequest(BaseModel):
    feature_idea: str = Field(min_length=1, max_length=settings.LLM_PROMPT_MAX_LENGTH)

    @field_validator("feature_idea")
    @classmethod
    def validate_feature_idea(cls, value: str) -> str:
        # Surrounding whitespace is dropped; a whitespace-only idea is rejected.
        stripped = value.strip()
        if stripped:
            return stripped
        raise ValueError("Feature idea must not be empty")


# One user story in "as a / I want / so that" form.
class FeatureSummaryItem(BaseModel):
    title: str
    as_a: str
    i_want: str
    so_that: str


# Data-model and API parts of a summary; each entry is free text or an object.
class DbModelsAndApiEndpoints(BaseModel):
    db_models: list[str | dict[str, Any]]
    api_endpoints: list[str | dict[str, Any]]


# Complete structured summary produced for one feature idea.
class FeatureSummaryResult(BaseModel):
    user_stories: list[FeatureSummaryItem]
    acceptance_criteria: list[str]
    db_models_and_api_endpoints: DbModelsAndApiEndpoints
    risk_assessment: list[str]


# A feature idea paired with its generated summary.
class FeatureSpecGenerateResponse(BaseModel):
    feature_idea: str
    feature_summary: FeatureSummaryResult


# Acknowledgement returned when a generation task has been submitted.
class FeatureSpecTaskSubmitResponse(BaseModel):
    task_id: str
    status: Literal["processing"]


# Polling response for a submitted task.
class FeatureSpecTaskStatusResponse(BaseModel):
    task_id: str
    status: Literal["PENDING", "STARTED", "SUCCESS", "FAILURE"]
    result: dict[str, Any] | None = None
    error: str | None = None


# One stored generation run as shown in the history listing.
class FeatureSpecHistoryItem(BaseModel):
    id: int
    feature_idea: str
    status: Literal["pending", "success", "error"]
    response_json: FeatureSummaryResult | None = None
    error_message: str | None = None
    created_at: datetime


# Envelope for the history listing.
class FeatureSpecHistoryResponse(BaseModel):
    items: list[FeatureSpecHistoryItem]
TERMINAL_STATES = {"SUCCESS", "FAILURE"}


def submit_feature_spec_generation(
    feature_idea: str,
    db: Session,
    user_id: int,
) -> FeatureSpecTaskSubmitResponse:
    """Persist a pending run and enqueue the task that will process it.

    Args:
        feature_idea: Validated feature idea text.
        db: Open SQLAlchemy session; the run row is committed through it.
        user_id: Owner of the run.

    Returns:
        The Celery task id with status ``"processing"``.

    Raises:
        Exception: whatever ``.delay()`` raised when the task could not be
            enqueued; the run is flagged as failed first.
    """
    run = FeatureSpecRun(
        user_id=user_id,
        feature_idea=feature_idea,
        status="pending",
    )
    db.add(run)
    db.commit()
    db.refresh(run)

    try:
        task = generate_feature_spec_task.delay(run.id, feature_idea, user_id)
    except Exception:
        # The row is already committed; without this it would stay "pending"
        # forever because no worker will ever pick the run up.
        # NOTE(review): "error" is the failure value the history schema
        # accepts for run status — confirm against FeatureSpecRun.
        run.status = "error"
        try:
            db.commit()
        except Exception:
            # Best effort only: the enqueue failure below is the error the
            # caller needs to see.
            db.rollback()
        raise

    return FeatureSpecTaskSubmitResponse(task_id=task.id, status="processing")


def get_feature_spec_task_status(task_id: str) -> FeatureSpecTaskStatusResponse:
    """Translate the Celery state of ``task_id`` into the API status model.

    SUCCESS carries the task result when it is a dict, FAILURE carries a
    generic error text, STARTED is passed through, and every other state is
    reported as PENDING.
    """
    async_result = AsyncResult(task_id, app=celery_app)
    state = async_result.state

    if state == "SUCCESS":
        result = async_result.result
        return FeatureSpecTaskStatusResponse(
            task_id=task_id,
            status="SUCCESS",
            result=result if isinstance(result, dict) else None,
        )

    if state == "FAILURE":
        # The underlying exception is deliberately not exposed to the client.
        return FeatureSpecTaskStatusResponse(
            task_id=task_id,
            status="FAILURE",
            error="Task execution failed",
        )

    if state == "STARTED":
        return FeatureSpecTaskStatusResponse(task_id=task_id, status="STARTED")

    return FeatureSpecTaskStatusResponse(task_id=task_id, status="PENDING")
- return text.strip() - - -def extract_sections(text: str) -> dict[str, str]: - """ - Split a Markdown-formatted LLM response into {heading: content} sections. - - Example input: - ## Overview - This is a spec... - ## Requirements - 1. ... - """ - sections: dict[str, str] = {} - current_heading = "_preamble" - current_lines: list[str] = [] - - for line in text.splitlines(): - heading_match = re.match(r"^#{1,3}\s+(.+)", line) - if heading_match: - sections[current_heading] = "\n".join(current_lines).strip() - current_heading = heading_match.group(1).strip() - current_lines = [] - else: - current_lines.append(line) - - sections[current_heading] = "\n".join(current_lines).strip() - sections.pop("_preamble", None) - return sections - - -def normalize_whitespace(text: str) -> str: - """Collapse multiple blank lines to a single blank line.""" - return re.sub(r"\n{3,}", "\n\n", text).strip() diff --git a/app/modules/llm/providers/ollama.py b/app/modules/llm/providers/ollama.py deleted file mode 100644 index 743c6d4..0000000 --- a/app/modules/llm/providers/ollama.py +++ /dev/null @@ -1,70 +0,0 @@ -import logging -from typing import AsyncGenerator - -import httpx - -from app.core.settings import settings - -logger = logging.getLogger(__name__) - - -class OllamaClient: - def __init__(self) -> None: - self._base_url = settings.OLLAMA_BASE_URL.rstrip("/") - self._model = settings.OLLAMA_MODEL - self._timeout = settings.OLLAMA_TIMEOUT - self._system_prompt = settings.OLLAMA_SYSTEM_PROMPT - - def _build_payload(self, user_prompt: str, stream: bool = False) -> dict: - return { - "model": self._model, - "stream": stream, - "messages": [ - {"role": "system", "content": self._system_prompt}, - {"role": "user", "content": user_prompt}, - ], - } - - async def generate(self, user_prompt: str) -> str: - url = f"{self._base_url}/api/chat" - payload = self._build_payload(user_prompt, stream=False) - - async with httpx.AsyncClient(timeout=self._timeout) as client: - try: - response = await 
client.post(url, json=payload) - response.raise_for_status() - except httpx.TimeoutException as exc: - logger.exception("Ollama request timed out") - raise RuntimeError( - f"Ollama request timed out after {self._timeout}s" - ) from exc - except httpx.HTTPStatusError as exc: - logger.exception("Ollama returned error response") - raise RuntimeError( - f"Ollama returned HTTP {exc.response.status_code}: {exc.response.text}" - ) from exc - - data = response.json() - return data["message"]["content"] - - async def generate_stream(self, user_prompt: str) -> AsyncGenerator[str, None]: - url = f"{self._base_url}/api/chat" - payload = self._build_payload(user_prompt, stream=True) - - async with httpx.AsyncClient(timeout=self._timeout) as client: - async with client.stream("POST", url, json=payload) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if not line: - continue - import json as _json - - chunk = _json.loads(line) - token = chunk.get("message", {}).get("content", "") - if token: - yield token - if chunk.get("done"): - break - - -ollama_client = OllamaClient() diff --git a/app/modules/llm/router.py b/app/modules/llm/router.py deleted file mode 100644 index e868011..0000000 --- a/app/modules/llm/router.py +++ /dev/null @@ -1,23 +0,0 @@ -from fastapi import APIRouter, HTTPException, status - -from app.core.settings import settings -from app.modules.llm.schemas import LlmGenerateRequest, LlmGenerateResponse -from app.modules.llm.service import generate_completion - -router = APIRouter(prefix=settings.LLM_PREFIX, tags=[settings.LLM_TAG]) - - -@router.post(settings.LLM_GENERATE_PATH, response_model=LlmGenerateResponse) -async def generate(payload: LlmGenerateRequest) -> LlmGenerateResponse: - try: - return await generate_completion(payload) - except ValueError as exc: - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail=str(exc), - ) from exc - except RuntimeError as exc: - raise HTTPException( - 
status_code=status.HTTP_502_BAD_GATEWAY, - detail=str(exc), - ) from exc diff --git a/app/modules/llm/schemas.py b/app/modules/llm/schemas.py deleted file mode 100644 index 0d31e9f..0000000 --- a/app/modules/llm/schemas.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel, Field - - -class LlmGenerateRequest(BaseModel): - prompt: str = Field(min_length=1) - response_format: Literal["text", "sections", "json"] = "text" - - -class LlmGenerateResponse(BaseModel): - raw_content: str - content: str | None = None - sections: dict[str, str] | None = None - data: dict | list | None = None diff --git a/app/modules/llm/service.py b/app/modules/llm/service.py deleted file mode 100644 index e8c347b..0000000 --- a/app/modules/llm/service.py +++ /dev/null @@ -1,26 +0,0 @@ -from app.modules.llm.parser import (extract_json, extract_sections, - normalize_whitespace, strip_markdown) -from app.modules.llm.providers.ollama import ollama_client -from app.modules.llm.schemas import LlmGenerateRequest, LlmGenerateResponse - - -async def generate_completion(payload: LlmGenerateRequest) -> LlmGenerateResponse: - raw_content = await ollama_client.generate(payload.prompt) - normalized_content = normalize_whitespace(strip_markdown(raw_content)) - - if payload.response_format == "json": - return LlmGenerateResponse( - raw_content=raw_content, - data=extract_json(normalized_content), - ) - - if payload.response_format == "sections": - return LlmGenerateResponse( - raw_content=raw_content, - sections=extract_sections(normalized_content), - ) - - return LlmGenerateResponse( - raw_content=raw_content, - content=normalized_content, - ) diff --git a/app/scripts/bootstrap_admin.py b/app/scripts/bootstrap_admin.py index 9c36452..785782d 100644 --- a/app/scripts/bootstrap_admin.py +++ b/app/scripts/bootstrap_admin.py @@ -1,15 +1,79 @@ import logging -from app.core.bootstrap import bootstrap_auth +from fastapi_users.password import PasswordHelper + from 
def _resolve_bootstrap_hashed_password() -> str:
    """Hash ``settings.AUTH_PASSWORD`` for the bootstrap user.

    Raises:
        RuntimeError: if ``AUTH_PASSWORD`` is not set.
    """
    if not settings.AUTH_PASSWORD:
        raise RuntimeError(
            "AUTH_PASSWORD must be set for admin bootstrap user creation"
        )
    return PasswordHelper().hash(settings.AUTH_PASSWORD)


def _is_bootstrap_enabled() -> bool:
    """Decide whether the bootstrap should run.

    In production only the explicit ``AUTH_BOOTSTRAP_ENABLED`` flag counts;
    in any other environment a configured ``AUTH_EMAIL`` enables it as well.
    """
    if settings.ENV.lower() == "production":
        return settings.AUTH_BOOTSTRAP_ENABLED
    return settings.AUTH_BOOTSTRAP_ENABLED or bool(settings.AUTH_EMAIL)


def bootstrap_admin_user(db) -> None:
    """Create the configured bootstrap user, or repair an existing one.

    An existing user (matched by e-mail) is only touched when it has no
    password hash; otherwise a new active, verified user is inserted.

    Raises:
        RuntimeError: if bootstrap is enabled but ``AUTH_EMAIL``,
            ``AUTH_USERNAME`` or ``AUTH_PASSWORD`` is missing.
        Exception: database errors are logged, rolled back and re-raised.
    """
    if not _is_bootstrap_enabled():
        logger.info("Auth bootstrap is disabled")
        return

    if not settings.AUTH_EMAIL or not settings.AUTH_USERNAME:
        raise RuntimeError(
            "AUTH_EMAIL and AUTH_USERNAME must be set when bootstrap is enabled"
        )

    existing_user = db.query(User).filter(User.email == settings.AUTH_EMAIL).one_or_none()
    if existing_user is not None:
        try:
            if not existing_user.hashed_password:
                existing_user.hashed_password = _resolve_bootstrap_hashed_password()
                db.add(existing_user)
                db.commit()
        except Exception:
            logger.exception(
                "Failed to ensure bootstrap state for existing user: email=%s",
                settings.AUTH_EMAIL,
            )
            db.rollback()
            raise
        return

    try:
        user = User(
            username=settings.AUTH_USERNAME,
            email=settings.AUTH_EMAIL,
            hashed_password=_resolve_bootstrap_hashed_password(),
            is_active=True,
            is_superuser=settings.AUTH_BOOTSTRAP_SUPERUSER,
            is_verified=True,
        )
        db.add(user)
        db.commit()
    except Exception:
        logger.exception(
            "Failed to create bootstrap user: email=%s",
            settings.AUTH_EMAIL,
        )
        db.rollback()
        raise


def main() -> None:
    """Script entry point: run the bootstrap inside its own session."""
    # Without a configured handler the INFO messages of this script are
    # dropped. basicConfig is a no-op when logging is already configured.
    logging.basicConfig(level=logging.INFO, format="[admin-bootstrap] %(message)s")
    db = SessionLocal()
    try:
        bootstrap_admin_user(db)
        logger.info("Admin bootstrap script completed")
    finally:
        db.close()
new file mode 100644 index 0000000..68626b0 --- /dev/null +++ b/app/scripts/bootstrap_prompt_template.py @@ -0,0 +1,162 @@ +import logging + +from app.core.database import SessionLocal +from app.modules.feature_spec.models import PromptTemplate + +logger = logging.getLogger(__name__) + + +DEFAULT_FEATURE_SUMMARY_PROMPT_TEMPLATE = """You are a senior staff-level product engineer +and system architect working on production-grade backend systems. + +Your task is to convert a high-level feature idea into a precise, +implementation-ready technical specification. + +You must think in terms of real backend development: APIs, database schema, +data consistency, edge cases, and scalability. + +--- + +## Core rules: + +- Be strictly structured and deterministic +- Do NOT invent product requirements beyond the given feature idea +- If information is missing, explicitly list it under "assumptions" +- Prefer simple and realistic solutions over over-engineering +- Every part of the output must be implementation-ready +- Avoid vague phrases like "etc", "additional fields", "and so on" +- If something is required — define it explicitly + +--- + +## Hard requirements: + +- Database design MUST include full schema: + - each table must have 4–8 fields + - include types and constraints (PK, FK, unique, nullable) + - include relationships + +- API design MUST: + - include at least 4 endpoints (not just 1–2) + - include request validation + - include access control logic + - include error responses + +- Acceptance criteria MUST: + - include positive and negative cases + - include authorization behavior + - be testable (QA-ready, no vague language) + +- Edge cases & risks MUST include: + - concurrency issues + - data consistency problems + - security concerns + - failure scenarios + +- If the feature involves payments: + - include payment provider interaction (e.g. 
def bootstrap_prompt_template(db) -> None:
    """Ensure prompt template row ``id=1`` holds a usable template.

    Creates the row with the default template when it is missing, fills it
    when the stored template is empty, and leaves a non-empty one untouched.

    Raises:
        Exception: database errors are logged, rolled back and re-raised.
    """
    existing_template = db.get(PromptTemplate, 1)
    default_template = DEFAULT_FEATURE_SUMMARY_PROMPT_TEMPLATE

    if existing_template is None:
        try:
            db.add(
                PromptTemplate(
                    id=1,
                    feature_to_feature_summary=default_template,
                    is_active=True,
                )
            )
            db.commit()
            logger.info("Created default feature summary prompt template")
        except Exception:
            logger.exception("Failed to create default feature summary prompt template")
            db.rollback()
            raise
        return

    # A NULL column is "empty" too; calling .strip() on None would crash in
    # exactly the case this function exists to repair.
    existing_value = (existing_template.feature_to_feature_summary or "").strip()
    if existing_value:
        return

    try:
        existing_template.feature_to_feature_summary = default_template
        db.add(existing_template)
        db.commit()
        logger.info("Filled empty feature summary prompt template with default value")
    except Exception:
        logger.exception("Failed to fill default feature summary prompt template")
        db.rollback()
        raise


def main() -> None:
    """Script entry point: run the bootstrap inside its own session."""
    # Without a configured handler the INFO messages of this script are
    # dropped. basicConfig is a no-op when logging is already configured.
    logging.basicConfig(level=logging.INFO, format="[prompt-bootstrap] %(message)s")
    db = SessionLocal()
    try:
        bootstrap_prompt_template(db)
        logger.info("Prompt template bootstrap script completed")
    finally:
        db.close()


if __name__ == "__main__":
    main()
def _env_int(name: str, default: int) -> int:
    """Read integer environment variable ``name``.

    An unset or blank variable yields ``default``; a blank value is what an
    env file line such as ``OLLAMA_TIMEOUT=`` produces.

    Raises:
        RuntimeError: if the value is set but is not an integer.
    """
    raw_value = os.getenv(name)
    if raw_value is None or not raw_value.strip():
        return default
    try:
        return int(raw_value)
    except ValueError as exc:
        raise RuntimeError(f"{name} must be an integer") from exc


def _canonical_model_name(name: str) -> str:
    """Append the implicit ``:latest`` tag when ``name`` carries no tag."""
    # Only the last path segment can hold a tag; a colon earlier in the name
    # would belong to a registry "host:port" prefix.
    last_segment = name.rsplit("/", 1)[-1]
    return name if ":" in last_segment else f"{name}:latest"


def _is_model_available(client: httpx.Client, model: str) -> bool:
    """Return True if ``model`` appears in the ``/api/tags`` listing.

    Names are compared after tag normalization: the listing reports tagged
    names such as ``mistral:latest``, so a verbatim comparison with an
    untagged ``mistral`` would never match.

    Raises:
        httpx.HTTPStatusError: if the listing request returns an error status.
    """
    response = client.get("/api/tags")
    response.raise_for_status()
    models = response.json().get("models") or []
    wanted = _canonical_model_name(model)
    for item in models:
        if not isinstance(item, dict):
            continue
        for key in ("name", "model"):
            listed = item.get(key)
            if isinstance(listed, str) and _canonical_model_name(listed) == wanted:
                return True
    return False


def wait_for_ollama(client: httpx.Client, wait_seconds: int) -> None:
    """Poll ``/api/tags`` every two seconds until it answers successfully.

    Raises:
        RuntimeError: if no successful answer arrives within ``wait_seconds``.
    """
    deadline = time.monotonic() + wait_seconds
    while time.monotonic() < deadline:
        try:
            client.get("/api/tags").raise_for_status()
            logger.info("Ollama is ready")
            return
        except httpx.HTTPError:
            logger.info("Waiting for Ollama to become ready...")
            time.sleep(2)

    raise RuntimeError(f"Ollama did not become ready within {wait_seconds}s")


def ensure_model(client: httpx.Client, model: str, pull_timeout: int) -> None:
    """Pull ``model`` through ``/api/pull`` unless it is already listed.

    Raises:
        httpx.HTTPStatusError: if listing or pulling returns an error status.
        RuntimeError: if the model is still not listed after the pull.
    """
    if _is_model_available(client, model):
        logger.info("Ollama model '%s' already available", model)
        return

    logger.info("Pulling missing Ollama model '%s'", model)
    response = client.post(
        "/api/pull",
        json={"model": model, "stream": False},
        timeout=pull_timeout,
    )
    response.raise_for_status()

    if not _is_model_available(client, model):
        raise RuntimeError(f"Ollama model '{model}' is still unavailable after pull")

    logger.info("Ollama model '%s' is ready", model)
def _project_root() -> Path:
    """Return the directory two levels above this file's directory."""
    return Path(__file__).resolve().parents[2]


def _run_alembic(*args: str) -> str:
    """Run ``python -m alembic <args>`` from the project root and return stdout.

    Raises:
        RuntimeError: if the command exits non-zero; the message carries
            stderr, or stdout when stderr is empty.
    """
    command = [sys.executable, "-m", "alembic", *args]
    result = subprocess.run(command, cwd=_project_root(), capture_output=True, text=True)
    if result.returncode != 0:
        stderr = (result.stderr or "").strip()
        stdout = (result.stdout or "").strip()
        details = stderr or stdout or "unknown error"
        raise RuntimeError(f"Alembic command failed: {' '.join(args)}; {details}")
    return result.stdout


def _extract_revisions(raw_output: str) -> set[str]:
    """Collect the leading lowercase-hex token of every line that starts with one."""
    revisions: set[str] = set()
    for line in raw_output.splitlines():
        match = re.match(r"^([0-9a-f]+)\b", line.strip())
        if match:
            revisions.add(match.group(1))
    return revisions


def _has_migration_files(versions_dir: Path | None = None) -> bool:
    """Return True if the versions directory holds at least one migration script.

    Args:
        versions_dir: Directory to inspect; defaults to ``alembic/versions``
            under the project root.

    ``__init__.py`` is ignored: a package marker is not a migration, and
    counting it would skip the metadata bootstrap and then fail the
    post-upgrade check with no revision in the database.
    """
    if versions_dir is None:
        versions_dir = _project_root() / "alembic" / "versions"
    return any(path.name != "__init__.py" for path in versions_dir.glob("*.py"))


def _bootstrap_schema_without_migrations() -> None:
    """Create all tables from SQLAlchemy metadata when no migrations exist.

    Raises:
        RuntimeError: if no models are registered on ``Base.metadata``.
    """
    # Imported for their side effect of registering the models on Base.
    importlib.import_module("app.modules.auth.models")
    importlib.import_module("app.modules.feature_spec.models")
    from app.core.database import Base, engine

    logger.warning(
        "No Alembic migration files found; creating schema via SQLAlchemy metadata"
    )
    if not Base.metadata.tables:
        raise RuntimeError(
            "No SQLAlchemy models are registered; cannot bootstrap schema"
        )
    Base.metadata.create_all(bind=engine)
    logger.info(
        "Schema created via SQLAlchemy metadata (tables=%s)",
        sorted(Base.metadata.tables.keys()),
    )


def migrate_and_check() -> None:
    """Upgrade the database to head and verify that it actually got there.

    Falls back to a metadata-based schema bootstrap when the project has no
    migration files.

    Raises:
        RuntimeError: if an Alembic command fails, the database reports no
            revision after the upgrade, or its revisions differ from the heads.
    """
    if not _has_migration_files():
        _bootstrap_schema_without_migrations()
        return

    logger.info("Running Alembic upgrade to head")
    _run_alembic("upgrade", "head")

    current_heads = _extract_revisions(_run_alembic("current"))
    expected_heads = _extract_revisions(_run_alembic("heads"))

    if not current_heads:
        raise RuntimeError("No migration state found in DB after alembic upgrade")

    if current_heads != expected_heads:
        raise RuntimeError(
            "Alembic DB revision is not at head: "
            f"current={sorted(current_heads)} expected={sorted(expected_heads)}"
        )

    logger.info("Alembic migrations are applied and DB is in sync")


def main() -> None:
    """Script entry point: configure logging, then migrate and verify."""
    logging.basicConfig(
        level=logging.INFO,
        format="[migrations] %(message)s",
    )
    migrate_and_check()


if __name__ == "__main__":
    main()
urllib.request; urllib.request.urlopen('http://127.0.0.1:8001/health', timeout=2)", ] interval: 30s timeout: 5s @@ -23,3 +30,69 @@ services: start_period: 20s stop_grace_period: 20s restart: unless-stopped + + celery-worker: + build: + context: . + dockerfile: Dockerfile + container_name: specification-generator-celery-worker + init: true + env_file: + - .env + command: ["celery", "-A", "app.infrastructure.celery_app:celery_app", "worker", "--loglevel=INFO"] + depends_on: + redis: + condition: service_healthy + ollama: + condition: service_healthy + healthcheck: + test: + [ + "CMD", + "python", + "-c", + "from app.infrastructure.celery_app import celery_app; import sys; i=celery_app.control.inspect(timeout=2); sys.exit(0 if i and i.ping() else 1)", + ] + interval: 30s + timeout: 5s + retries: 3 + start_period: 20s + restart: unless-stopped + + redis: + image: redis:7-alpine + container_name: specification-generator-redis + init: true + ports: + - "${REDIS_HOST_PORT:-6380}:${REDIS_CONTAINER_PORT:-6379}" + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 3s + retries: 5 + start_period: 10s + restart: unless-stopped + + ollama: + image: ollama/ollama:latest + container_name: specification-generator-ollama + init: true + ports: + - "11434:11434" + volumes: + - ollama:/root/.ollama + healthcheck: + test: + [ + "CMD", + "ollama", + "list", + ] + interval: 30s + timeout: 10s + retries: 5 + start_period: 30s + restart: unless-stopped + +volumes: + ollama: diff --git a/entrypoint.sh b/entrypoint.sh index 893b72d..d7f387a 100644 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,14 +1,20 @@ #!/usr/bin/env sh set -eu -echo "[start] Running Alembic migrations..." -python -m alembic upgrade head +echo "[start] Running migrations and schema sync check..." +python -m app.scripts.migrate_and_check echo "[start] Running admin bootstrap script..." python -m app.scripts.bootstrap_admin +echo "[start] Running prompt template bootstrap script..." 
+python -m app.scripts.bootstrap_prompt_template + +echo "[start] Ensuring Ollama model is available..." +python -m app.scripts.ensure_ollama_model + if [ "$#" -eq 0 ]; then - set -- python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 + set -- python -m uvicorn app.main:app --host 0.0.0.0 --port "${PORT:-8001}" fi echo "[start] Starting: $*" diff --git a/requirements-dev.txt b/requirements-dev.txt index 0a9aa75..c8babd2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ pytest==8.2.0 pytest-asyncio==0.23.6 +freezegun==1.5.1 black==24.4.2 isort==5.13.2 flake8==7.0.0 diff --git a/requirements.txt b/requirements.txt index e366f2b..da29dc2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,26 +1,28 @@ -# FastAPI core fastapi==0.111.0 uvicorn[standard]==0.30.0 -# Database sqlalchemy==2.0.30 psycopg2-binary==2.9.9 -# Migrations alembic==1.13.1 -# Data validation pydantic==2.7.1 pydantic-settings==2.2.1 -# Auth / security + python-multipart==0.0.20 fastapi-users==14.0.1 fastapi-users-db-sqlalchemy[sqlalchemy]==7.0.0 asyncpg==0.30.0 -# Environment + +argon2-cffi==23.1.0 + python-dotenv==1.0.1 -# HTTP requests httpx==0.27.0 + +celery[redis]==5.4.0 + +sqladmin==0.17.0 +itsdangerous==2.2.0 diff --git a/tests/modules/api/test_auth_api.py b/tests/modules/api/test_auth_api.py index 32aaa37..d5d4361 100644 --- a/tests/modules/api/test_auth_api.py +++ b/tests/modules/api/test_auth_api.py @@ -1,7 +1,12 @@ +from datetime import datetime, timedelta, timezone + import pytest +from freezegun import freeze_time from fastapi.testclient import TestClient from app.main import app +from app.modules.auth.jwt_router import get_auth_session_service +from app.modules.auth.service import AuthTokens def _get_auth_me_dependency_callable(): @@ -61,3 +66,132 @@ def test_login_validates_form_payload(api_client: TestClient) -> None: response = api_client.post("/api/v1/auth/jwt/login", data={}) assert response.status_code == 422 + + +@pytest.mark.unit +def 
test_login_returns_tokens_and_sets_refresh_cookie(api_client: TestClient) -> None: + class FakeAuthSessionService: + async def login(self, credentials): + return AuthTokens(access_token="access-token", refresh_token="refresh-token") + + app.dependency_overrides[get_auth_session_service] = lambda: FakeAuthSessionService() + try: + response = api_client.post( + "/api/v1/auth/jwt/login", + data={"username": "admin", "password": "!QAZ1qaz"}, + ) + finally: + app.dependency_overrides.pop(get_auth_session_service, None) + + assert response.status_code == 200 + assert response.json()["token_type"] == "bearer" + assert response.json()["access_token"] == "access-token" + assert "refresh_token=" in response.headers.get("set-cookie", "") + + +@pytest.mark.unit +def test_refresh_requires_refresh_cookie(api_client: TestClient) -> None: + response = api_client.post("/api/v1/auth/jwt/refresh") + + assert response.status_code == 401 + assert response.json()["detail"] == "Unauthorized" + + +@pytest.mark.unit +def test_refresh_returns_new_access_token(api_client: TestClient) -> None: + class FakeAuthSessionService: + async def refresh(self, refresh_token): + if refresh_token != "valid-refresh-token": + return None + return AuthTokens( + access_token="new-access-token", + refresh_token="new-refresh-token", + ) + + app.dependency_overrides[get_auth_session_service] = lambda: FakeAuthSessionService() + try: + response = api_client.post( + "/api/v1/auth/jwt/refresh", + cookies={"refresh_token": "valid-refresh-token"}, + ) + finally: + app.dependency_overrides.pop(get_auth_session_service, None) + + assert response.status_code == 200 + assert response.json()["token_type"] == "bearer" + assert response.json()["access_token"] == "new-access-token" + assert "refresh_token=" in response.headers.get("set-cookie", "") + + +@pytest.mark.unit +def test_logout_clears_refresh_cookie(api_client: TestClient) -> None: + class FakeAuthSessionService: + async def logout(self, refresh_token): + return 
None + + app.dependency_overrides[get_auth_session_service] = lambda: FakeAuthSessionService() + try: + response = api_client.post( + "/api/v1/auth/jwt/logout", + cookies={"refresh_token": "any-token"}, + ) + finally: + app.dependency_overrides.pop(get_auth_session_service, None) + + assert response.status_code == 204 + set_cookie = response.headers.get("set-cookie", "") + assert "refresh_token=" in set_cookie + assert "Max-Age=0" in set_cookie or "expires=" in set_cookie.lower() + + +@pytest.mark.unit +def test_refresh_returns_401_when_mocked_token_expired_by_time( + api_client: TestClient, +) -> None: + class TimeAwareFakeAuthSessionService: + def __init__(self) -> None: + self._expires_at: datetime | None = None + + async def login(self, credentials): + self._expires_at = datetime.now(timezone.utc) + timedelta(seconds=60) + return AuthTokens( + access_token="access-token", + refresh_token="refresh-token", + ) + + async def refresh(self, refresh_token): + if self._expires_at is None: + return None + + if datetime.now(timezone.utc) >= self._expires_at: + return None + + return AuthTokens( + access_token="new-access-token", + refresh_token="refresh-token", + ) + + async def logout(self, refresh_token): + return None + + fake_service = TimeAwareFakeAuthSessionService() + app.dependency_overrides[get_auth_session_service] = lambda: fake_service + try: + with freeze_time("2026-04-11 19:00:00"): + login_response = api_client.post( + "/api/v1/auth/jwt/login", + data={"username": "admin", "password": "!QAZ1qaz"}, + ) + + assert login_response.status_code == 200 + + with freeze_time("2026-04-11 19:01:01"): + refresh_response = api_client.post( + "/api/v1/auth/jwt/refresh", + cookies={"refresh_token": "refresh-token"}, + ) + finally: + app.dependency_overrides.pop(get_auth_session_service, None) + + assert refresh_response.status_code == 401 + assert refresh_response.json()["detail"] == "Unauthorized" diff --git a/tests/modules/api/test_llm_api.py 
b/tests/modules/api/test_llm_api.py deleted file mode 100644 index 25dbf58..0000000 --- a/tests/modules/api/test_llm_api.py +++ /dev/null @@ -1,73 +0,0 @@ -from unittest.mock import AsyncMock - -import pytest -from fastapi.testclient import TestClient - - -@pytest.mark.unit -def test_generate_returns_service_result( - monkeypatch: pytest.MonkeyPatch, api_client: TestClient -) -> None: - mocked_generate = AsyncMock( - return_value={ - "raw_content": "# Spec", - "content": "Spec", - "sections": None, - "data": None, - } - ) - monkeypatch.setattr("app.modules.llm.router.generate_completion", mocked_generate) - - response = api_client.post( - "/api/v1/llm/generate", - json={"prompt": "Generate auth spec", "response_format": "text"}, - ) - - assert response.status_code == 200 - assert response.json()["content"] == "Spec" - - -@pytest.mark.unit -def test_generate_maps_value_error_to_502( - monkeypatch: pytest.MonkeyPatch, api_client: TestClient -) -> None: - monkeypatch.setattr( - "app.modules.llm.router.generate_completion", - AsyncMock(side_effect=ValueError("invalid provider response")), - ) - - response = api_client.post( - "/api/v1/llm/generate", - json={"prompt": "Generate auth spec", "response_format": "text"}, - ) - - assert response.status_code == 502 - assert response.json()["detail"] == "invalid provider response" - - -@pytest.mark.unit -def test_generate_maps_runtime_error_to_502( - monkeypatch: pytest.MonkeyPatch, api_client: TestClient -) -> None: - monkeypatch.setattr( - "app.modules.llm.router.generate_completion", - AsyncMock(side_effect=RuntimeError("provider timeout")), - ) - - response = api_client.post( - "/api/v1/llm/generate", - json={"prompt": "Generate auth spec", "response_format": "sections"}, - ) - - assert response.status_code == 502 - assert response.json()["detail"] == "provider timeout" - - -@pytest.mark.unit -def test_generate_validates_request_payload(api_client: TestClient) -> None: - response = api_client.post( - 
"/api/v1/llm/generate", - json={"prompt": "", "response_format": "text"}, - ) - - assert response.status_code == 422 diff --git a/tests/modules/auth/test_auth_schemas.py b/tests/modules/auth/test_auth_schemas.py new file mode 100644 index 0000000..3207c9b --- /dev/null +++ b/tests/modules/auth/test_auth_schemas.py @@ -0,0 +1,25 @@ +import pytest + +from app.modules.auth.schemas import UserCreate + + +@pytest.mark.unit +def test_user_create_rejects_superuser_flag() -> None: + with pytest.raises(ValueError): + UserCreate( + email="user@example.com", + password="StrongPass123!", + username="user", + is_superuser=True, + ) + + +@pytest.mark.unit +def test_user_create_defaults_to_non_superuser() -> None: + user = UserCreate( + email="user@example.com", + password="StrongPass123!", + username="user", + ) + + assert user.is_superuser is False diff --git a/tests/modules/llm/test_ollama_provider.py b/tests/modules/llm/test_ollama_provider.py deleted file mode 100644 index 98b7904..0000000 --- a/tests/modules/llm/test_ollama_provider.py +++ /dev/null @@ -1,100 +0,0 @@ -import httpx -import pytest - -from app.modules.llm.providers.ollama import OllamaClient - - -class FakeResponse: - def __init__(self, *, json_data=None, status_code: int = 200, text: str = ""): - self._json_data = json_data or {} - self.status_code = status_code - self.text = text - - def raise_for_status(self) -> None: - if self.status_code >= 400: - request = httpx.Request("POST", "http://mock/api/chat") - response = httpx.Response(self.status_code, request=request, text=self.text) - raise httpx.HTTPStatusError("error", request=request, response=response) - - def json(self): - return self._json_data - - -class FakeAsyncClient: - def __init__( - self, response: FakeResponse | None = None, error: Exception | None = None - ): - self._response = response - self._error = error - self.last_url = None - self.last_json = None - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, 
tb): - return False - - async def post(self, url, json): - self.last_url = url - self.last_json = json - if self._error: - raise self._error - return self._response - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_generate_returns_content(monkeypatch: pytest.MonkeyPatch) -> None: - fake_client = FakeAsyncClient( - response=FakeResponse( - json_data={"message": {"content": "Generated specification"}} - ) - ) - monkeypatch.setattr( - "app.modules.llm.providers.ollama.httpx.AsyncClient", - lambda *args, **kwargs: fake_client, - ) - - client = OllamaClient() - result = await client.generate("Build API spec") - - assert result == "Generated specification" - assert fake_client.last_url.endswith("/api/chat") - assert fake_client.last_json["messages"][1]["content"] == "Build API spec" - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_generate_maps_timeout_to_runtime_error( - monkeypatch: pytest.MonkeyPatch, -) -> None: - fake_client = FakeAsyncClient(error=httpx.TimeoutException("timeout")) - monkeypatch.setattr( - "app.modules.llm.providers.ollama.httpx.AsyncClient", - lambda *args, **kwargs: fake_client, - ) - - client = OllamaClient() - - with pytest.raises(RuntimeError, match="timed out"): - await client.generate("Build API spec") - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_generate_maps_http_status_error_to_runtime_error( - monkeypatch: pytest.MonkeyPatch, -) -> None: - fake_client = FakeAsyncClient( - response=FakeResponse(status_code=500, text="internal error") - ) - monkeypatch.setattr( - "app.modules.llm.providers.ollama.httpx.AsyncClient", - lambda *args, **kwargs: fake_client, - ) - - client = OllamaClient() - - with pytest.raises(RuntimeError, match="Ollama returned HTTP 500"): - await client.generate("Build API spec")