diff --git a/.env.example b/.env.example
index 57ec2b3..1051262 100644
--- a/.env.example
+++ b/.env.example
@@ -5,6 +5,10 @@
 VERSION=
 HOST=
 PORT=
+FASTAPI_HOST_PORT=
+FASTAPI_CONTAINER_PORT=
+REDIS_HOST_PORT=
+REDIS_CONTAINER_PORT=
 
 DATABASE_URL=postgresql://:@:5432/?sslmode=disable
 SECRET_KEY=
@@ -44,6 +48,10 @@ OLLAMA_RETRY_BACKOFF_SECONDS=
 OLLAMA_SYSTEM_PROMPT=
 LLM_PROMPT_MAX_LENGTH=
 FEATURE_SPEC_HISTORY_DEFAULT_LIMIT=
+CELERY_BROKER_URL=
+CELERY_RESULT_BACKEND=
+CELERY_TASK_MAX_RETRIES=
+CELERY_TASK_RETRY_BASE_SECONDS=
 
 TEST_DB_HOST=
 TEST_DB_PORT=
diff --git a/README.md b/README.md
index 309d479..89da555 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,11 @@
 # Specification Generator API
 
-Production-ready FastAPI backend for authentication, LLM-powered specification generation, health checks, and secure API middleware baseline.
+Production-ready FastAPI backend with async LLM generation via Celery + Redis, JWT auth, and Docker Compose orchestration.
+
+<p align="center">
+  <img alt="FastAPI" src="…"> <img alt="python" src="…">
+</p>
@@ -19,7 +23,9 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g
 
 - JWT authentication based on fastapi-users
 - User registration and user management endpoints
-- Feature specification generation endpoints powered by Ollama
+- Async feature specification generation with Celery tasks
+- Redis as Celery broker/result backend
+- Ollama integration for LLM responses
 - Readiness and health probes for runtime checks
 - Alembic database migrations
 - Security middleware baseline:
@@ -38,10 +44,20 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g
 - SQLAlchemy 2.0
 - Alembic
 - fastapi-users
+- Celery
+- Redis
 - PostgreSQL (via DATABASE_URL)
 - Ollama (LLM provider)
 - Pytest
 
+## Architecture (Async Flow)
+
+1. Client calls `POST /api/v1/feature-spec/generate`.
+2. API stores a run row and enqueues a Celery task to Redis.
+3. API returns immediately with `task_id` and `processing` status.
+4. Celery worker calls Ollama and persists success/error in the DB.
+5. Client polls `GET /api/v1/feature-spec/tasks/{task_id}` for the result.
+
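The async flow added above can be exercised end to end with two calls. A sketch, not part of the diff: it assumes the Compose stack is up on the default host port 8005 and that `$TOKEN` holds a JWT obtained from the auth endpoints.

```bash
# Submit an idea; the API answers immediately with a task id.
curl -s -X POST http://localhost:8005/api/v1/feature-spec/generate \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"feature_idea": "payment for premium posts"}'
# -> {"task_id": "3b44daff-1e83-4328-925f-62c22a9163d2", "status": "processing"}

# Poll until status turns SUCCESS or FAILURE.
curl -s -H "Authorization: Bearer $TOKEN" \
  http://localhost:8005/api/v1/feature-spec/tasks/3b44daff-1e83-4328-925f-62c22a9163d2
```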
 ## Project Structure
 
 - app/main.py: FastAPI app bootstrap and router registration
@@ -49,10 +65,11 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g
 - app/api/: health, readiness, OpenAPI customization
 - app/middlewares/: security middleware composition and implementations
 - app/modules/auth/: auth domain (models, schemas, dependencies, router)
-- app/modules/feature_spec/: feature spec API, schemas, prompts, providers
+- app/modules/feature_spec/: API layer + application services for feature spec
+- app/infrastructure/: Celery app, task workers, Ollama client
 - app/scripts/: utility scripts (admin and prompt/model bootstrap)
 - alembic/: migration config and versions
-- docker-compose.yml: containerized app run
+- docker-compose.yml: app + celery-worker + redis + ollama
 
 ## Setup Guide
 
@@ -60,7 +77,7 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g
 - Python 3.10+
 - PostgreSQL database
-- Optional: Docker + Docker Compose (recommended for VPS)
+- Docker + Docker Compose (recommended)
 
 ### 1) Configure environment
 
@@ -70,6 +87,8 @@ Required minimum:
 
 - DATABASE_URL
 - SECRET_KEY
+- CELERY_BROKER_URL
+- CELERY_RESULT_BACKEND
 
 Recommended auth bootstrap values:
 
@@ -81,10 +100,20 @@ LLM values:
 
 - OLLAMA_BASE_URL
 - OLLAMA_MODEL
+- OLLAMA_TIMEOUT
+
+Compose ports (host -> container):
+
+- FASTAPI_HOST_PORT=8005
+- FASTAPI_CONTAINER_PORT=8001
+- REDIS_HOST_PORT=6380
+- REDIS_CONTAINER_PORT=6379
 
 For Docker Compose in this project use:
 
 - OLLAMA_BASE_URL=http://ollama:11434
+- CELERY_BROKER_URL=redis://redis:6379/0
+- CELERY_RESULT_BACKEND=redis://redis:6379/1
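Pulling the values above together, a minimal `.env` fragment for the Compose setup could look like this (ports are the defaults suggested above; DB and secret values omitted):

```bash
FASTAPI_HOST_PORT=8005
FASTAPI_CONTAINER_PORT=8001
REDIS_HOST_PORT=6380
REDIS_CONTAINER_PORT=6379
OLLAMA_BASE_URL=http://ollama:11434
CELERY_BROKER_URL=redis://redis:6379/0
CELERY_RESULT_BACKEND=redis://redis:6379/1
```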
 
 ### 2A) Run locally
 
@@ -102,7 +131,11 @@ Production-ready FastAPI backend for authentication, LLM-powered specification g
 pip install -r requirements-dev.txt
 python -m alembic upgrade head
 python -m app.scripts.bootstrap_admin
 python -m app.scripts.bootstrap_prompt_template
-python -m uvicorn app.main:app --host 0.0.0.0 --port 8000
+# terminal 1: api
+python -m uvicorn app.main:app --host 0.0.0.0 --port 8005
+
+# terminal 2: worker
+celery -A app.infrastructure.celery_app:celery_app worker --loglevel=INFO
 ```
 
 ### 2B) Run with Docker
 
@@ -119,28 +152,32 @@ Notes:
 - prompt template bootstrap script (`python -m app.scripts.bootstrap_prompt_template`)
 - Ollama model bootstrap (`python -m app.scripts.ensure_ollama_model`)
 - uvicorn app startup
+- Celery worker runs in a dedicated container (`celery-worker`).
+- Redis runs only in Docker Compose and is reachable internally via the service name `redis`.
 - On first deploy, startup may take longer while the configured `OLLAMA_MODEL` is downloaded.
 - FastAPI container reaches Ollama via internal Docker network URL: http://ollama:11434
 
-Verify Ollama API:
+Verify services:
 
 ```bash
-curl http://localhost:11434/api/generate -d '{
-  "model": "mistral",
-  "prompt": "hello",
-  "stream": false
-}'
+docker compose ps
+docker compose logs -f app
+docker compose logs -f celery-worker
 ```
 
 ## API Docs
 
-When server is running locally:
+When the server is running locally (custom port):
 
-- Swagger UI: http://localhost:8000/docs
-- ReDoc: http://localhost:8000/redoc (pinned ReDoc 2.x script)
-- OpenAPI JSON: http://localhost:8000/openapi.json
+- Swagger UI: http://localhost:8005/docs
+- ReDoc: http://localhost:8005/redoc (pinned ReDoc 2.x script)
+- OpenAPI JSON: http://localhost:8005/openapi.json
 
-For Docker Compose deployment, use port 8001 instead of 8000.
+For Docker Compose deployment (default host mapping):
+
+- Swagger UI: http://localhost:8005/docs
+- ReDoc: http://localhost:8005/redoc
+- OpenAPI JSON: http://localhost:8005/openapi.json
 
 ## API Endpoints
 
@@ -175,9 +212,10 @@ Login note:
 ### Feature Spec
 
 - POST /api/v1/feature-spec/generate
+- GET /api/v1/feature-spec/tasks/{task_id}
 - GET /api/v1/feature-spec/history?limit=10
 
-Request body example for generation:
+Generate request:
 
 ```json
 {
   "feature_idea": "payment for premium posts"
 }
 ```
@@ -185,6 +223,45 @@
+Generate response:
+
+```json
+{
+  "task_id": "3b44daff-1e83-4328-925f-62c22a9163d2",
+  "status": "processing"
+}
+```
+
+Task status response examples:
+
+```json
+{
+  "task_id": "3b44daff-1e83-4328-925f-62c22a9163d2",
+  "status": "PENDING"
+}
+```
+
+```json
+{
+  "task_id": "3b44daff-1e83-4328-925f-62c22a9163d2",
+  "status": "SUCCESS",
+  "result": {
+    "run_id": 10,
+    "status": "success",
+    "feature_idea": "payment for premium posts",
+    "feature_summary": {
+      "user_stories": [],
+      "acceptance_criteria": [],
+      "db_models_and_api_endpoints": {
+        "db_models": [],
+        "api_endpoints": []
+      },
+      "risk_assessment": []
+    }
+  }
+}
+```
+
 ## Quality Checks
 
 Run linter:
 
@@ -225,8 +302,20 @@ If app cannot connect to DB:
 
 - Verify DATABASE_URL
 - Verify DB network access and sslmode if needed
 
-If LLM requests fail:
+If Celery tasks stay in `PENDING`:
+
+- Check the worker is healthy: `docker compose ps`
+- Check worker logs: `docker compose logs -f celery-worker`
+- Verify the Redis URLs in `.env` point to the `redis` service inside the Docker network
+
+If LLM requests fail or time out:
 
 - Verify OLLAMA_BASE_URL
 - Ensure Ollama is running and the model is available
 - For Docker deployment, ensure OLLAMA_BASE_URL is http://ollama:11434
+- Increase `OLLAMA_TIMEOUT` for long generations
+
+If compose prints `variable is not set` for random token-like names:
+
+- Your `.env` likely has `$` inside secret values
+- Escape `$` as `$$` in `.env` values used by Docker Compose (e.g. `SECRET_KEY=ab$$cd` for a literal `ab$cd`)
diff --git a/app/core/settings.py b/app/core/settings.py
index 56e29d6..ab3fe72 100644
--- a/app/core/settings.py
+++ b/app/core/settings.py
@@ -68,7 +68,7 @@ class Settings(BaseSettings):
     OLLAMA_BASE_URL: str = "http://localhost:11434"
     OLLAMA_MODEL: str = "mistral"
-    OLLAMA_TIMEOUT: int = 120
+    OLLAMA_TIMEOUT: int = 180
     OLLAMA_CONNECT_TIMEOUT: int = 10
     OLLAMA_MAX_RETRIES: int = 2
     OLLAMA_RETRY_BACKOFF_SECONDS: float = 1.0
@@ -81,6 +81,11 @@ class Settings(BaseSettings):
     LLM_PROMPT_MAX_LENGTH: int = 8000
     FEATURE_SPEC_HISTORY_DEFAULT_LIMIT: int = 10
 
+    CELERY_BROKER_URL: str = "redis://redis:6379/0"
+    CELERY_RESULT_BACKEND: str = "redis://redis:6379/1"
+    CELERY_TASK_MAX_RETRIES: int = 3
+    CELERY_TASK_RETRY_BASE_SECONDS: int = 2
+
     model_config = SettingsConfigDict(
         env_file=".env",
         case_sensitive=True,
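The two `CELERY_TASK_*` settings drive the worker's exponential backoff (used in `feature_spec_tasks.py` below): the delay is `base ** retry_number`. A quick illustration of the schedule under the defaults:

```python
# Retry schedule implied by the defaults above (illustration only).
base = 2         # CELERY_TASK_RETRY_BASE_SECONDS
max_retries = 3  # CELERY_TASK_MAX_RETRIES

for retry_number in range(1, max_retries + 1):
    print(f"retry {retry_number} scheduled after {base**retry_number}s")
# retry 1 scheduled after 2s
# retry 2 scheduled after 4s
# retry 3 scheduled after 8s
```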
diff --git a/app/infrastructure/__init__.py b/app/infrastructure/__init__.py
new file mode 100644
index 0000000..2182630
--- /dev/null
+++ b/app/infrastructure/__init__.py
@@ -0,0 +1 @@
+"""Infrastructure layer modules."""
diff --git a/app/infrastructure/celery_app.py b/app/infrastructure/celery_app.py
new file mode 100644
index 0000000..bc5b403
--- /dev/null
+++ b/app/infrastructure/celery_app.py
@@ -0,0 +1,21 @@
+from celery import Celery
+
+from app.core.settings import settings
+
+
+celery_app = Celery(
+    "specification_generator",
+    broker=settings.CELERY_BROKER_URL,
+    backend=settings.CELERY_RESULT_BACKEND,
+    include=["app.infrastructure.tasks.feature_spec_tasks"],
+)
+
+celery_app.conf.update(
+    broker_connection_retry_on_startup=True,
+    task_track_started=True,
+    task_serializer="json",
+    result_serializer="json",
+    accept_content=["json"],
+    timezone="UTC",
+    enable_utc=True,
+)
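To sanity-check the Celery wiring without the full stack, the task registry can be inspected after importing the task module (a sketch; assumes dependencies are installed and `.env` is readable):

```python
# Importing the task module runs the @celery_app.task decorator,
# which registers the task under the name used by the API layer.
from app.infrastructure.celery_app import celery_app
import app.infrastructure.tasks.feature_spec_tasks  # noqa: F401

print("feature_spec.generate" in celery_app.tasks)  # expected: True
```

The worker itself starts with the same command the Compose file uses: `celery -A app.infrastructure.celery_app:celery_app worker --loglevel=INFO`.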
diff --git a/app/infrastructure/ollama_client.py b/app/infrastructure/ollama_client.py
new file mode 100644
index 0000000..35b0046
--- /dev/null
+++ b/app/infrastructure/ollama_client.py
@@ -0,0 +1,39 @@
+import httpx
+
+from app.core.settings import settings
+
+
+class OllamaSyncClient:
+    def __init__(self) -> None:
+        self._base_url = settings.OLLAMA_BASE_URL.rstrip("/")
+        self._model = settings.OLLAMA_MODEL
+        self._timeout = httpx.Timeout(
+            float(settings.OLLAMA_TIMEOUT),
+            connect=float(settings.OLLAMA_CONNECT_TIMEOUT),
+        )
+        self._system_prompt = settings.OLLAMA_SYSTEM_PROMPT
+
+    def _build_payload(self, user_prompt: str) -> dict:
+        messages = []
+        if self._system_prompt.strip():
+            messages.append({"role": "system", "content": self._system_prompt})
+        messages.append({"role": "user", "content": user_prompt})
+        return {
+            "model": self._model,
+            "stream": False,
+            "format": "json",
+            "messages": messages,
+        }
+
+    def generate(self, user_prompt: str) -> str:
+        url = f"{self._base_url}/api/chat"
+        payload = self._build_payload(user_prompt)
+        with httpx.Client(timeout=self._timeout) as client:
+            response = client.post(url, json=payload, timeout=self._timeout)
+            response.raise_for_status()
+            data = response.json()
+            content = data.get("message", {}).get("content", "")
+            return content if isinstance(content, str) else str(content)
+
+
+ollama_sync_client = OllamaSyncClient()
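A minimal usage sketch for the synchronous client (Celery tasks are plain sync functions, hence no async here); it assumes Ollama is reachable at `OLLAMA_BASE_URL` and the configured model is pulled:

```python
from app.infrastructure.ollama_client import ollama_sync_client

# Because the payload sets format="json", Ollama is asked to put a JSON
# document in message.content; the caller still has to parse it.
raw = ollama_sync_client.generate("Feature idea: payment for premium posts")
print(raw[:200])
```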
diff --git a/app/infrastructure/tasks/__init__.py b/app/infrastructure/tasks/__init__.py
new file mode 100644
index 0000000..27a2a9f
--- /dev/null
+++ b/app/infrastructure/tasks/__init__.py
@@ -0,0 +1 @@
+"""Celery task modules."""
diff --git a/app/infrastructure/tasks/feature_spec_tasks.py b/app/infrastructure/tasks/feature_spec_tasks.py
new file mode 100644
index 0000000..bd39ff1
--- /dev/null
+++ b/app/infrastructure/tasks/feature_spec_tasks.py
@@ -0,0 +1,84 @@
+import httpx
+from celery.exceptions import MaxRetriesExceededError
+
+from app.core.database import SessionLocal
+from app.core.settings import settings
+from app.infrastructure.celery_app import celery_app
+from app.infrastructure.ollama_client import ollama_sync_client
+from app.modules.feature_spec.models import FeatureSpecRun
+from app.modules.feature_spec.prompts.feature_summary import (
+    build_feature_summary_prompt_from_db,
+    parse_feature_summary_response,
+)
+from app.modules.feature_spec.schemas import FeatureSummaryResult
+
+
+def _ensure_auth_models_loaded() -> None:
+    # Imported for its side effect: SQLAlchemy mappers that reference User
+    # must be configured before the worker queries FeatureSpecRun.
+    from app.modules.auth import models as auth_models
+
+    _ = auth_models.User
+
+
+@celery_app.task(bind=True, name="feature_spec.generate")
+def generate_feature_spec_task(self, run_id: int, feature_idea: str, user_id: int) -> dict:
+    _ensure_auth_models_loaded()
+    db = SessionLocal()
+    try:
+        run = db.get(FeatureSpecRun, run_id)
+        if run is None:
+            return {"run_id": run_id, "status": "error", "message": "Run not found"}
+
+        if run.status == "success" and run.response_json is not None:
+            # Idempotency guard: a redelivered task returns the stored result.
+            return {
+                "run_id": run.id,
+                "status": run.status,
+                "feature_idea": run.feature_idea,
+                "feature_summary": run.response_json,
+            }
+
+        try:
+            prompt = build_feature_summary_prompt_from_db(feature_idea, db)
+            feature_summary_raw = ollama_sync_client.generate(prompt)
+            feature_summary_json = parse_feature_summary_response(feature_summary_raw)
+            feature_summary_typed = FeatureSummaryResult.model_validate(feature_summary_json)
+
+            run.status = "success"
+            run.response_json = feature_summary_typed.model_dump(mode="json")
+            run.error_message = None
+            db.add(run)
+            db.commit()
+
+            return {
+                "run_id": run.id,
+                "status": run.status,
+                "feature_idea": run.feature_idea,
+                "feature_summary": run.response_json,
+            }
+        except (
+            httpx.TimeoutException,
+            httpx.RequestError,
+            httpx.HTTPStatusError,
+        ) as exc:
+            retry_number = int(self.request.retries) + 1
+            if retry_number <= settings.CELERY_TASK_MAX_RETRIES:
+                delay_seconds = settings.CELERY_TASK_RETRY_BASE_SECONDS**retry_number
+                try:
+                    # self.retry raises celery.exceptions.Retry on success;
+                    # MaxRetriesExceededError means retries are exhausted and
+                    # execution falls through to record the failure.
+                    raise self.retry(exc=exc, countdown=delay_seconds)
+                except MaxRetriesExceededError:
+                    pass
+
+            db.rollback()
+            run.status = "error"
+            run.error_message = "LLM provider request failed"
+            db.add(run)
+            db.commit()
+            raise RuntimeError("LLM task failed after retries") from exc
+        except Exception as exc:
+            db.rollback()
+            run.status = "error"
+            run.error_message = "Failed to process feature specification"
+            db.add(run)
+            db.commit()
+            raise RuntimeError("Feature specification task failed") from exc
+    finally:
+        db.close()
diff --git a/app/modules/feature_spec/__init__.py b/app/modules/feature_spec/__init__.py
deleted file mode 100644
index c57be5d..0000000
--- a/app/modules/feature_spec/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from app.modules.feature_spec.router import router
-
-__all__ = ["router"]
diff --git a/app/modules/feature_spec/orchestrator.py b/app/modules/feature_spec/orchestrator.py
index 79f7f8b..b3f2a30 100644
--- a/app/modules/feature_spec/orchestrator.py
+++ b/app/modules/feature_spec/orchestrator.py
@@ -3,6 +3,7 @@
 from app.core.settings import settings
 from app.modules.feature_spec.models import FeatureSpecRun
+from app.modules.feature_spec.parser import normalize_feature_summary_payload
 from app.modules.feature_spec.prompts.feature_summary import (
     build_feature_summary_prompt_from_db,
     parse_feature_summary_response,
 )
@@ -92,7 +93,9 @@ def get_feature_spec_history(
             feature_idea=row.feature_idea,
             status=row.status,
             response_json=(
-                FeatureSummaryResult.model_validate(row.response_json)
+                FeatureSummaryResult.model_validate(
+                    normalize_feature_summary_payload(row.response_json)
+                )
                 if row.response_json is not None
                 else None
             ),
diff --git a/app/modules/feature_spec/parser.py b/app/modules/feature_spec/parser.py
index de7a156..5b71088 100644
--- a/app/modules/feature_spec/parser.py
+++ b/app/modules/feature_spec/parser.py
@@ -1,18 +1,274 @@
 import json
+import logging
 import re
+from typing import Any
 
-
-def extract_json(text: str) -> dict | list:
-    decoder = json.JSONDecoder()
-    for index, char in enumerate(text):
-        if char not in "[{":
-            continue
-        try:
-            parsed, _ = decoder.raw_decode(text[index:])
-        except json.JSONDecodeError:
-            continue
-        if isinstance(parsed, (dict, list)):
-            return parsed
-    raise ValueError("No JSON found in LLM response")
 
+logger = logging.getLogger(__name__)
+
+
+def _is_missing(value: Any) -> bool:
+    if value is None:
+        return True
+    if isinstance(value, str):
+        return not value.strip()
+    if isinstance(value, (list, dict)):
+        return not value
+    return False
+
+
+def _normalize_acceptance_criteria(value: Any, *, strict: bool = False) -> list[str] | None:
+    if value is None:
+        return None
+
+    items = value if isinstance(value, list) else [value]
+    normalized_items: list[str] = []
+
+    for item in items:
+        if isinstance(item, str):
+            text = item.strip()
+        elif isinstance(item, dict):
+            text = next(
+                (
+                    candidate.strip()
+                    for key in (
+                        "title",
+                        "description",
+                        "details",
+                        "criterion",
+                        "text",
+                        "message",
+                    )
+                    if isinstance((candidate := item.get(key)), str)
+                    and candidate.strip()
+                ),
+                str(item).strip(),
+            )
+        else:
+            text = str(item).strip()
+
+        if text:
+            normalized_items.append(text)
+        elif strict:
+            raise ValueError("Invalid acceptance_criteria item")
+
+    return normalized_items or None
+
+
+def _normalize_risk_assessment(value: Any, *, strict: bool = False) -> list[str] | None:
+    if value is None:
+        return None
+
+    items = value if isinstance(value, list) else [value]
+    normalized_items: list[str] = []
+
+    for item in items:
+        if isinstance(item, str):
+            text = item.strip()
+        elif isinstance(item, dict):
+            text = next(
+                (
+                    candidate.strip()
+                    for key in ("title", "risk", "description", "details", "message")
+                    if isinstance((candidate := item.get(key)), str)
+                    and candidate.strip()
+                ),
+                str(item).strip(),
+            )
+        else:
+            text = str(item).strip()
+
+        if text:
+            normalized_items.append(text)
+        elif strict:
+            raise ValueError("Invalid risk_assessment item")
+
+    return normalized_items or None
+
+
+def _normalize_mixed_items(value: Any, *, strict: bool = False) -> list[str | dict] | None:
+    if value is None:
+        return None
+
+    items = value if isinstance(value, list) else [value]
+    normalized_items: list[str | dict] = []
+
+    for item in items:
+        if isinstance(item, dict):
+            normalized_items.append(item)
+            continue
+
+        text = str(item).strip()
+        if text:
+            normalized_items.append(text)
+        elif strict:
+            raise ValueError("Invalid db/api item")
+
+    return normalized_items or None
+
+
+def _first_non_empty(source: dict[str, Any], *keys: str) -> str | None:
+    for key in keys:
+        value = source.get(key)
+        if isinstance(value, str) and value.strip():
+            return value.strip()
+    return None
+
+
+def _normalize_user_stories(value: Any, *, strict: bool = False) -> list[dict[str, str]] | None:
+    if value is None:
+        return None
+
+    items = value if isinstance(value, list) else [value]
+    normalized_items: list[dict] = []
+
+    for index, item in enumerate(items, start=1):
+        if isinstance(item, dict):
+            title = _first_non_empty(item, "title", "name")
+            as_a = _first_non_empty(item, "as_a", "as", "actor", "user")
+            i_want = _first_non_empty(item, "i_want", "want", "goal", "objective")
+            so_that = _first_non_empty(item, "so_that", "benefit", "because", "outcome")
+            if not i_want:
+                i_want = _first_non_empty(item, "description", "details", "text")
+            if not so_that:
+                so_that = i_want
+
+            if not all([title, as_a, i_want, so_that]):
+                if strict:
+                    raise ValueError(f"Invalid user story at index {index}")
+                logger.warning("Skipping invalid user story item at index %s", index)
+                continue
+
+            normalized_items.append(
+                {
+                    "title": title,
+                    "as_a": as_a,
+                    "i_want": i_want,
+                    "so_that": so_that,
+                }
+            )
+            continue
+
+        if isinstance(item, str) and item.strip():
+            text = item.strip()
+            if strict:
+                raise ValueError(
+                    f"Invalid user story string item at index {index}; object expected"
+                )
+            normalized_items.append(
+                {
+                    "title": text,
+                    "as_a": "User",
+                    "i_want": text,
+                    "so_that": text,
+                }
+            )
+        elif strict:
+            raise ValueError(f"Invalid user story item type at index {index}")
+
+    return normalized_items or None
+
+
+def normalize_feature_summary_payload(
+    data: dict[str, Any] | list,
+    *,
+    strict: bool = False,
+) -> dict[str, Any] | list:
+    if not isinstance(data, dict):
+        return data
+
+    normalized = dict(data)
+
+    if isinstance(normalized.get("feature_summary"), dict):
+        nested = normalized["feature_summary"]
+        logger.info("Using nested feature_summary payload")
+        for key in (
+            "user_stories",
+            "feature_summary_items",
+            "acceptance_criteria",
+            "acceptance",
+            "db_models_and_api_endpoints",
+            "db_models",
+            "api_endpoints",
+            "risk_assessment",
+            "risks",
+        ):
+            if key not in normalized and key in nested:
+                normalized[key] = nested[key]
+
+    user_stories_source = normalized.get("user_stories")
+    if _is_missing(user_stories_source):
+        user_stories_source = normalized.get("feature_summary_items")
+
+    normalized_user_stories = _normalize_user_stories(user_stories_source, strict=strict)
+    if normalized_user_stories is not None:
+        normalized["user_stories"] = normalized_user_stories
+
+    acceptance_source = normalized.get("acceptance_criteria")
+    if _is_missing(acceptance_source):
+        logger.info("Using legacy acceptance field")
+        acceptance_source = normalized.get("acceptance")
+
+    normalized_acceptance = _normalize_acceptance_criteria(
+        acceptance_source,
+        strict=strict,
+    )
+    if normalized_acceptance is not None:
+        normalized["acceptance_criteria"] = normalized_acceptance
+
+    db_api_source = normalized.get("db_models_and_api_endpoints")
+    db_source = normalized.get("db_models")
+    api_source = normalized.get("api_endpoints")
+
+    if isinstance(db_api_source, dict):
+        db_source = db_api_source.get("db_models", db_source)
+        api_source = db_api_source.get("api_endpoints", api_source)
+
+    normalized["db_models_and_api_endpoints"] = {
+        "db_models": _normalize_mixed_items(db_source, strict=strict) or [],
+        "api_endpoints": _normalize_mixed_items(api_source, strict=strict) or [],
+    }
+
+    risk_source = normalized.get("risk_assessment")
+    if _is_missing(risk_source):
+        logger.info("Using legacy risks field")
+        risk_source = normalized.get("risks")
+
+    normalized_risks = _normalize_risk_assessment(risk_source, strict=strict)
+    if normalized_risks is not None:
+        normalized["risk_assessment"] = normalized_risks
+
+    return normalized
+
+
+def extract_json(text: str) -> dict | list:
+    stripped = text.strip()
+    if stripped:
+        try:
+            parsed = json.loads(stripped)
+            if isinstance(parsed, (dict, list)):
+                return parsed
+        except json.JSONDecodeError:
+            pass
+
+    fenced_match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
+    if fenced_match:
+        candidate = fenced_match.group(1).strip()
+        try:
+            parsed = json.loads(candidate)
+            if isinstance(parsed, (dict, list)):
+                return parsed
+        except json.JSONDecodeError:
+            pass
+
+    # Fall back to scanning for the first parseable JSON value; raw_decode
+    # handles nested braces that a non-greedy regex match would truncate.
+    decoder = json.JSONDecoder()
+    for index, char in enumerate(text):
+        if char not in "[{":
+            continue
+        try:
+            parsed, _ = decoder.raw_decode(text[index:])
+        except json.JSONDecodeError:
+            continue
+        if isinstance(parsed, (dict, list)):
+            return parsed
+
+    raise ValueError("No JSON found in LLM response")
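To make the normalizer's behavior concrete, here is a sketch of a messy, legacy-keyed payload and what `normalize_feature_summary_payload` does with it in non-strict mode (values are illustrative):

```python
from app.modules.feature_spec.parser import normalize_feature_summary_payload

messy = {
    "user_stories": ["Pay for a premium post"],     # bare string, object expected
    "acceptance": [{"title": "Payment succeeds"}],  # legacy key, dict item
    "db_models": ["Payment"],                       # flat legacy keys
    "api_endpoints": ["POST /payments"],
    "risks": ["chargebacks"],                       # legacy key
}

normalized = normalize_feature_summary_payload(messy)
# user_stories[0] -> {"title": "Pay for a premium post", "as_a": "User",
#                     "i_want": "Pay for a premium post",
#                     "so_that": "Pay for a premium post"}
# acceptance_criteria -> ["Payment succeeds"]
# db_models_and_api_endpoints -> {"db_models": ["Payment"],
#                                 "api_endpoints": ["POST /payments"]}
# risk_assessment -> ["chargebacks"]
```

With `strict=True` the same input raises `ValueError` on the string user story instead of coercing it.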
diff --git a/app/modules/feature_spec/prompts/feature_summary.py b/app/modules/feature_spec/prompts/feature_summary.py
index 895889e..9deb419 100644
--- a/app/modules/feature_spec/prompts/feature_summary.py
+++ b/app/modules/feature_spec/prompts/feature_summary.py
@@ -4,7 +4,12 @@
 from sqlalchemy.orm import Session
 
 from app.modules.feature_spec.models import PromptTemplate
-from app.modules.feature_spec.parser import extract_json, normalize_whitespace, strip_markdown
+from app.modules.feature_spec.parser import (
+    extract_json,
+    normalize_feature_summary_payload,
+    normalize_whitespace,
+    strip_markdown,
+)
 
 
 def build_feature_summary_prompt_from_template(template: str, feature_idea: str) -> str:
@@ -37,4 +42,5 @@ def build_feature_summary_prompt_from_db(feature_idea: str, db: Session) -> str:
 
 def parse_feature_summary_response(raw_response: str) -> dict | list:
     normalized_content = normalize_whitespace(strip_markdown(raw_response))
-    return extract_json(normalized_content)
+    parsed = extract_json(normalized_content)
+    return normalize_feature_summary_payload(parsed)
diff --git a/app/modules/feature_spec/providers/ollama.py b/app/modules/feature_spec/providers/ollama.py
index 08ff0d3..8e0fcca 100644
--- a/app/modules/feature_spec/providers/ollama.py
+++ b/app/modules/feature_spec/providers/ollama.py
@@ -54,6 +54,11 @@ async def generate(self, user_prompt: str) -> str:
                 message = data.get("message", {})
                 content = message.get("content", "")
                 return content if isinstance(content, str) else str(content)
+            except (json.JSONDecodeError, ValueError) as exc:
+                if attempt < self._max_retries:
+                    await self._backoff(attempt)
+                    continue
+                raise RuntimeError("Ollama returned invalid JSON response") from exc
             except httpx.TimeoutException as exc:
                 if attempt < self._max_retries:
                     await self._backoff(attempt)
                     continue
@@ -90,6 +95,7 @@ async def generate_stream(self, user_prompt: str) -> AsyncGenerator[str, None]:
             try:
                 async with client.stream("POST", url, json=payload) as response:
                     response.raise_for_status()
+                    stream_done = False
                     async for line in response.aiter_lines():
                         if not line:
                             continue
@@ -103,8 +109,10 @@ async def generate_stream(self, user_prompt: str) -> AsyncGenerator[str, None]:
                             yielded_any = True
                             yield token
                         if chunk.get("done"):
+                            stream_done = True
                             return
-                    return
+                    if not stream_done:
+                        raise RuntimeError("Ollama stream ended before completion")
             except httpx.TimeoutException as exc:
                 if not yielded_any and attempt < self._max_retries:
                     await self._backoff(attempt)
                     continue
diff --git a/app/modules/feature_spec/router.py b/app/modules/feature_spec/router.py
index de4d668..49beb22 100644
--- a/app/modules/feature_spec/router.py
+++ b/app/modules/feature_spec/router.py
@@ -1,41 +1,37 @@
-from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi import APIRouter, Depends
 from sqlalchemy.orm import Session
 
 from app.core.database import get_db
 from app.core.settings import settings
 from app.modules.auth.dependencies import get_current_user
 from app.modules.auth.models import User
-from app.modules.feature_spec.orchestrator import (
-    generate_feature_spec,
-    get_feature_spec_history,
-)
+from app.modules.feature_spec.orchestrator import get_feature_spec_history
 from app.modules.feature_spec.schemas import (
     FeatureSpecGenerateRequest,
-    FeatureSpecGenerateResponse,
     FeatureSpecHistoryResponse,
+    FeatureSpecTaskStatusResponse,
+    FeatureSpecTaskSubmitResponse,
+)
+from app.modules.feature_spec.service import (
+    get_feature_spec_task_status,
+    submit_feature_spec_generation,
 )
 
 router = APIRouter(prefix="/feature-spec", tags=["feature-spec"])
 
 
-@router.post("/generate", response_model=FeatureSpecGenerateResponse)
+@router.post("/generate", response_model=FeatureSpecTaskSubmitResponse)
 async def generate_feature_spec_endpoint(
     payload: FeatureSpecGenerateRequest,
     current_user: User = Depends(get_current_user),
     db: Session = Depends(get_db),
-) -> FeatureSpecGenerateResponse:
-    try:
-        return await generate_feature_spec(payload.feature_idea, db, current_user.id)
-    except ValueError as exc:
-        raise HTTPException(
-            status_code=status.HTTP_502_BAD_GATEWAY,
-            detail=str(exc),
-        ) from exc
-    except RuntimeError as exc:
-        raise HTTPException(
-            status_code=status.HTTP_502_BAD_GATEWAY,
-            detail=str(exc),
-        ) from exc
+) -> FeatureSpecTaskSubmitResponse:
+    return submit_feature_spec_generation(payload.feature_idea, db, current_user.id)
+
+
+@router.get("/tasks/{task_id}", response_model=FeatureSpecTaskStatusResponse)
+async def get_feature_spec_task_status_endpoint(task_id: str) -> FeatureSpecTaskStatusResponse:
+    return get_feature_spec_task_status(task_id)
 
 
 @router.get("/history", response_model=FeatureSpecHistoryResponse)
diff --git a/app/modules/feature_spec/schemas.py b/app/modules/feature_spec/schemas.py
index 79af927..43eb23f 100644
--- a/app/modules/feature_spec/schemas.py
+++ b/app/modules/feature_spec/schemas.py
@@ -1,7 +1,7 @@
 from datetime import datetime
 from typing import Any, Literal
 
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, Field, field_validator
 
 from app.core.settings import settings
 
@@ -36,38 +36,22 @@ class FeatureSummaryResult(BaseModel):
     db_models_and_api_endpoints: DbModelsAndApiEndpoints
     risk_assessment: list[str]
 
-    @model_validator(mode="before")
-    @classmethod
-    def normalize_legacy_user_stories(cls, data):
-        if not isinstance(data, dict):
-            return data
-
-        normalized = dict(data)
-
-        if "user_stories" not in normalized and "feature_summary_items" in normalized:
-            normalized["user_stories"] = normalized["feature_summary_items"]
-        if "acceptance_criteria" not in normalized:
-            if "acceptance" in normalized:
-                normalized["acceptance_criteria"] = normalized["acceptance"]
-
-        if "db_models_and_api_endpoints" not in normalized:
-            db_models = normalized.get("db_models", [])
-            api_endpoints = normalized.get("api_endpoints", [])
-            normalized["db_models_and_api_endpoints"] = {
-                "db_models": db_models,
-                "api_endpoints": api_endpoints,
-            }
+class FeatureSpecGenerateResponse(BaseModel):
+    feature_idea: str
+    feature_summary: FeatureSummaryResult
 
-        if "risk_assessment" not in normalized and "risks" in normalized:
-            normalized["risk_assessment"] = normalized["risks"]
 
-        return normalized
+class FeatureSpecTaskSubmitResponse(BaseModel):
+    task_id: str
+    status: Literal["processing"]
 
 
-class FeatureSpecGenerateResponse(BaseModel):
-    feature_idea: str
-    feature_summary: FeatureSummaryResult
+class FeatureSpecTaskStatusResponse(BaseModel):
+    task_id: str
+    status: Literal["PENDING", "STARTED", "SUCCESS", "FAILURE"]
+    result: dict[str, Any] | None = None
+    error: str | None = None
 
 
 class FeatureSpecHistoryItem(BaseModel):
diff --git a/app/modules/feature_spec/service.py b/app/modules/feature_spec/service.py
new file mode 100644
index 0000000..ff13d09
--- /dev/null
+++ b/app/modules/feature_spec/service.py
@@ -0,0 +1,56 @@
+from celery.result import AsyncResult
+from sqlalchemy.orm import Session
+
+from app.infrastructure.celery_app import celery_app
+from app.infrastructure.tasks.feature_spec_tasks import generate_feature_spec_task
+from app.modules.feature_spec.models import FeatureSpecRun
+from app.modules.feature_spec.schemas import (
+    FeatureSpecTaskStatusResponse,
+    FeatureSpecTaskSubmitResponse,
+)
+
+
+TERMINAL_STATES = {"SUCCESS", "FAILURE"}
+
+
+def submit_feature_spec_generation(
+    feature_idea: str,
+    db: Session,
+    user_id: int,
+) -> FeatureSpecTaskSubmitResponse:
+    run = FeatureSpecRun(
+        user_id=user_id,
+        feature_idea=feature_idea,
+        status="pending",
+    )
+    db.add(run)
+    db.commit()
+    db.refresh(run)
+
+    task = generate_feature_spec_task.delay(run.id, feature_idea, user_id)
+    return FeatureSpecTaskSubmitResponse(task_id=task.id, status="processing")
+
+
+def get_feature_spec_task_status(task_id: str) -> FeatureSpecTaskStatusResponse:
+    async_result = AsyncResult(task_id, app=celery_app)
+    state = async_result.state
+
+    if state == "SUCCESS":
+        payload = async_result.result if isinstance(async_result.result, dict) else None
+        return FeatureSpecTaskStatusResponse(
+            task_id=task_id,
+            status="SUCCESS",
+            result=payload,
+        )
+
+    if state == "FAILURE":
+        return FeatureSpecTaskStatusResponse(
+            task_id=task_id,
+            status="FAILURE",
+            error="Task execution failed",
+        )
+
+    if state == "STARTED":
+        return FeatureSpecTaskStatusResponse(task_id=task_id, status="STARTED")
+
+    return FeatureSpecTaskStatusResponse(task_id=task_id, status="PENDING")
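The status endpoint is a thin wrapper over Celery's result API. One nuance worth knowing: Celery reports unknown task ids as `PENDING`, so clients should treat `PENDING` as "queued or unknown" rather than proof the task exists. A sketch of the underlying calls:

```python
from celery.result import AsyncResult

from app.infrastructure.celery_app import celery_app

async_result = AsyncResult("3b44daff-1e83-4328-925f-62c22a9163d2", app=celery_app)
print(async_result.state)  # PENDING / STARTED / SUCCESS / FAILURE
if async_result.state == "SUCCESS":
    # The worker returns the dict built in generate_feature_spec_task.
    print(async_result.result["feature_summary"])
```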
diff --git a/docker-compose.yml b/docker-compose.yml
index 906423e..ef1c199 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -8,10 +8,14 @@ services:
     env_file:
       - .env
     ports:
-      - "8001:8001"
+      - "${FASTAPI_HOST_PORT:-8005}:${FASTAPI_CONTAINER_PORT:-8001}"
     depends_on:
+      redis:
+        condition: service_healthy
       ollama:
         condition: service_healthy
+      celery-worker:
+        condition: service_healthy
     healthcheck:
       test:
         [
@@ -27,6 +31,48 @@ services:
     stop_grace_period: 20s
     restart: unless-stopped
 
+  celery-worker:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    container_name: specification-generator-celery-worker
+    init: true
+    env_file:
+      - .env
+    command: ["celery", "-A", "app.infrastructure.celery_app:celery_app", "worker", "--loglevel=INFO"]
+    depends_on:
+      redis:
+        condition: service_healthy
+      ollama:
+        condition: service_healthy
+    healthcheck:
+      test:
+        [
+          "CMD",
+          "python",
+          "-c",
+          "from app.infrastructure.celery_app import celery_app; import sys; i=celery_app.control.inspect(timeout=2); sys.exit(0 if i and i.ping() else 1)",
+        ]
+      interval: 30s
+      timeout: 5s
+      retries: 3
+      start_period: 20s
+    restart: unless-stopped
+
+  redis:
+    image: redis:7-alpine
+    container_name: specification-generator-redis
+    init: true
+    ports:
+      - "${REDIS_HOST_PORT:-6380}:${REDIS_CONTAINER_PORT:-6379}"
+    healthcheck:
+      test: ["CMD", "redis-cli", "ping"]
+      interval: 10s
+      timeout: 3s
+      retries: 5
+      start_period: 10s
+    restart: unless-stopped
+
   ollama:
     image: ollama/ollama:latest
     container_name: specification-generator-ollama
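A quick health pass over the new services once the stack is up (commands assume the default service names from the file above):

```bash
docker compose ps                         # app, celery-worker, redis, ollama should report healthy
docker compose exec redis redis-cli ping  # expect PONG
docker compose logs -f celery-worker      # watch tasks being picked up
```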
diff --git a/entrypoint.sh b/entrypoint.sh
index b6729c1..d7f387a 100644
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -14,7 +14,7 @@ echo "[start] Ensuring Ollama model is available..."
 python -m app.scripts.ensure_ollama_model
 
 if [ "$#" -eq 0 ]; then
-  set -- python -m uvicorn app.main:app --host 0.0.0.0 --port 8001
+  set -- python -m uvicorn app.main:app --host 0.0.0.0 --port "${PORT:-8001}"
 fi
 
 echo "[start] Starting: $*"
diff --git a/requirements.txt b/requirements.txt
index 89bbdb0..da29dc2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,20 +1,15 @@
-# FastAPI core
 fastapi==0.111.0
 uvicorn[standard]==0.30.0
 
-# Database
 sqlalchemy==2.0.30
 psycopg2-binary==2.9.9
 
-# Migrations
 alembic==1.13.1
 
-# Data validation
 pydantic==2.7.1
 pydantic-settings==2.2.1
 
-# Auth / security
 python-multipart==0.0.20
 fastapi-users==14.0.1
 fastapi-users-db-sqlalchemy[sqlalchemy]==7.0.0
@@ -23,12 +18,11 @@ asyncpg==0.30.0
 argon2-cffi==23.1.0
 
-# Environment
 python-dotenv==1.0.1
 
-# HTTP requests
 httpx==0.27.0
 
-# Admin panel
+celery[redis]==5.4.0
+
 sqladmin==0.17.0
 itsdangerous==2.2.0